1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/IR/AffineExpr.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "gmock/gmock.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using testing::Not; 17 using testing::Truly; 18 19 namespace { 20 21 TEST(isRowMajorMatmul, Simple) { 22 MLIRContext context; 23 24 AffineExpr m, n, k; 25 bindDims(&context, m, n, k); 26 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 27 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 28 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 29 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 30 31 EXPECT_THAT(maps, Truly(isRowMajorMatmul)); 32 } 33 34 TEST(isRowMajorMatmul, BindingShifted) { 35 MLIRContext context; 36 37 AffineExpr m, n, k; 38 bindDims(&context, k, m, n); // bind in different order 39 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 40 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 41 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 42 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 43 44 EXPECT_THAT(maps, Truly(isRowMajorMatmul)); 45 } 46 47 TEST(isRowMajorMatmul, BindingSwapped) { 48 MLIRContext context; 49 50 AffineExpr m, n, k; 51 bindDims(&context, k, n, m); // bind in different order 52 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 53 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 54 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 55 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 56 57 EXPECT_THAT(maps, Truly(isRowMajorMatmul)); 58 } 59 60 TEST(isRowMajorMatmul, ColumnMajor) { 61 MLIRContext context; 62 63 AffineExpr m, n, k; 64 bindDims(&context, m, n, k); 65 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 66 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 67 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 68 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 69 70 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 71 } 72 73 TEST(isRowMajorMatmul, FirstInputSwapped) { 74 MLIRContext context; 75 76 AffineExpr m, n, k; 77 bindDims(&context, m, n, k); 78 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context)); 79 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 80 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 81 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 82 83 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 84 } 85 86 TEST(isRowMajorMatmul, TooFewMaps) { 87 MLIRContext context; 88 89 AffineExpr m, n, k; 90 bindDims(&context, m, n, k); 91 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 92 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 93 auto maps = ArrayAttr::get(&context, {mapA, mapB}); 94 95 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 96 } 97 98 TEST(isRowMajorMatmul, TooManyMaps) { 99 MLIRContext context; 100 101 AffineExpr m, n, k; 102 bindDims(&context, m, n, k); 103 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 104 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 105 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 106 auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context)); 107 108 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD}); 109 110 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 111 } 112 113 TEST(isRowMajorMatmul, TooFewDims) { 114 MLIRContext context; 115 116 AffineExpr m, n, k; 117 bindDims(&context, m, n, k); 118 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 119 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context)); 120 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 121 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 122 123 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 124 } 125 126 TEST(isRowMajorMatmul, TooFewOutputs) { 127 MLIRContext context; 128 129 AffineExpr m, n, k; 130 bindDims(&context, m, n, k); 131 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context)); 132 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 133 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 134 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 135 136 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); 137 } 138 139 TEST(isColumnMajorMatmul, Simple) { 140 MLIRContext context; 141 142 AffineExpr m, n, k; 143 bindDims(&context, m, n, k); 144 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 145 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 146 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 147 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 148 149 EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); 150 } 151 152 TEST(isColumnMajorMatmul, BindingShifted) { 153 MLIRContext context; 154 155 AffineExpr m, n, k; 156 bindDims(&context, k, m, n); // bind in different order 157 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 158 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 159 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 160 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 161 162 EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); 163 } 164 165 TEST(isColumnMajorMatmul, BindingSwapped) { 166 MLIRContext context; 167 168 AffineExpr m, n, k; 169 bindDims(&context, k, n, m); // bind in different order 170 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 171 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 172 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 173 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 174 175 EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); 176 } 177 178 TEST(isColumnMajorMatmul, RowMajor) { 179 MLIRContext context; 180 181 AffineExpr m, n, k; 182 bindDims(&context, m, n, k); 183 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 184 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); 185 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 186 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 187 188 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul))); 189 } 190 191 TEST(isColumnMajorMatmul, FirstInputSwapped) { 192 MLIRContext context; 193 194 AffineExpr m, n, k; 195 bindDims(&context, m, n, k); 196 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context)); 197 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); 198 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); 199 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 200 201 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul))); 202 } 203 204 TEST(isRowMajorBatchMatmul, Simple) { 205 MLIRContext context; 206 207 AffineExpr batch, m, n, k; 208 bindDims(&context, batch, m, n, k); 209 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); 210 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); 211 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); 212 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 213 214 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); 215 } 216 217 TEST(isRowMajorBatchMatmul, BindingShifted) { 218 MLIRContext context; 219 220 AffineExpr batch, m, n, k; 221 bindDims(&context, k, batch, m, n); // bind in different order 222 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); 223 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); 224 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); 225 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 226 227 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); 228 } 229 230 TEST(isRowMajorBatchMatmul, BindingSwapped) { 231 MLIRContext context; 232 233 AffineExpr batch, m, n, k; 234 bindDims(&context, batch, k, n, m); // bind in different order 235 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); 236 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); 237 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); 238 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 239 240 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); 241 } 242 243 TEST(isRowMajorBatchMatmul, FirstInputSwapped) { 244 MLIRContext context; 245 246 AffineExpr batch, m, n, k; 247 bindDims(&context, batch, m, n, k); 248 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context)); 249 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); 250 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); 251 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); 252 253 EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul))); 254 } 255 256 } // namespace 257