xref: /llvm-project/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (revision ccc02563f4d620d4d29a1cbd2c463871cc54745b)
//===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8db011775SGeoffrey Martin-Noble 
9db011775SGeoffrey Martin-Noble #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10db011775SGeoffrey Martin-Noble #include "mlir/IR/AffineMap.h"
11f286af29SAlexander Belyaev #include "mlir/IR/Builders.h"
12db011775SGeoffrey Martin-Noble #include "mlir/IR/BuiltinAttributes.h"
134d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
14c21e88ccSMahesh Ravishankar #include "llvm/ADT/StringSet.h"
15db011775SGeoffrey Martin-Noble 
164f1c1242SOleg Shyshkov #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
174f1c1242SOleg Shyshkov 
18db011775SGeoffrey Martin-Noble using namespace mlir;
19db011775SGeoffrey Martin-Noble 
isRowMajorMatmul(ArrayAttr indexingMaps)20db011775SGeoffrey Martin-Noble bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21db011775SGeoffrey Martin-Noble   if (indexingMaps.size() != 3)
22db011775SGeoffrey Martin-Noble     return false;
23db011775SGeoffrey Martin-Noble 
249f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
259f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
269f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27db011775SGeoffrey Martin-Noble 
28db011775SGeoffrey Martin-Noble   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29db011775SGeoffrey Martin-Noble       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30db011775SGeoffrey Martin-Noble       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31db011775SGeoffrey Martin-Noble     return false;
32db011775SGeoffrey Martin-Noble   }
33db011775SGeoffrey Martin-Noble 
34db011775SGeoffrey Martin-Noble   // Extract dimensions for MxK * KxN -> MxN
35db011775SGeoffrey Martin-Noble   AffineExpr m = map2.getResult(0);
36db011775SGeoffrey Martin-Noble   AffineExpr n = map2.getResult(1);
37db011775SGeoffrey Martin-Noble   AffineExpr k = map0.getResult(1);
38db011775SGeoffrey Martin-Noble   auto *context = indexingMaps.getContext();
39db011775SGeoffrey Martin-Noble   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40db011775SGeoffrey Martin-Noble   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41db011775SGeoffrey Martin-Noble   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42db011775SGeoffrey Martin-Noble   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43db011775SGeoffrey Martin-Noble   return indexingMaps == maps;
44db011775SGeoffrey Martin-Noble }
45db011775SGeoffrey Martin-Noble 
isColumnMajorMatmul(ArrayAttr indexingMaps)46db011775SGeoffrey Martin-Noble bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47db011775SGeoffrey Martin-Noble   if (indexingMaps.size() != 3)
48db011775SGeoffrey Martin-Noble     return false;
49db011775SGeoffrey Martin-Noble 
509f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
519f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
529f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53db011775SGeoffrey Martin-Noble 
54db011775SGeoffrey Martin-Noble   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55db011775SGeoffrey Martin-Noble       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56db011775SGeoffrey Martin-Noble       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57db011775SGeoffrey Martin-Noble     return false;
58db011775SGeoffrey Martin-Noble   }
59db011775SGeoffrey Martin-Noble 
60db011775SGeoffrey Martin-Noble   // Extract dimensions for KxM * NxK -> NxM
61db011775SGeoffrey Martin-Noble   AffineExpr n = map2.getResult(0);
62db011775SGeoffrey Martin-Noble   AffineExpr m = map2.getResult(1);
63db011775SGeoffrey Martin-Noble   AffineExpr k = map0.getResult(0);
64db011775SGeoffrey Martin-Noble   auto *context = indexingMaps.getContext();
65db011775SGeoffrey Martin-Noble   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66db011775SGeoffrey Martin-Noble   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67db011775SGeoffrey Martin-Noble   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68db011775SGeoffrey Martin-Noble   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69db011775SGeoffrey Martin-Noble   return indexingMaps == maps;
70db011775SGeoffrey Martin-Noble }
71db011775SGeoffrey Martin-Noble 
isRowMajorBatchMatmul(ArrayAttr indexingMaps)72db011775SGeoffrey Martin-Noble bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73db011775SGeoffrey Martin-Noble   if (indexingMaps.size() != 3)
74db011775SGeoffrey Martin-Noble     return false;
75db011775SGeoffrey Martin-Noble 
769f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
779f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
789f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79db011775SGeoffrey Martin-Noble 
80db011775SGeoffrey Martin-Noble   if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81db011775SGeoffrey Martin-Noble       map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82db011775SGeoffrey Martin-Noble       map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83db011775SGeoffrey Martin-Noble     return false;
84db011775SGeoffrey Martin-Noble   }
85db011775SGeoffrey Martin-Noble 
86db011775SGeoffrey Martin-Noble   // Extract dimensions for BxMxK * BxKxN -> BxMxN
87db011775SGeoffrey Martin-Noble   AffineExpr b = map2.getResult(0);
88db011775SGeoffrey Martin-Noble   AffineExpr m = map2.getResult(1);
89db011775SGeoffrey Martin-Noble   AffineExpr n = map2.getResult(2);
90db011775SGeoffrey Martin-Noble   AffineExpr k = map0.getResult(2);
91db011775SGeoffrey Martin-Noble   auto *context = indexingMaps.getContext();
92db011775SGeoffrey Martin-Noble   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93db011775SGeoffrey Martin-Noble   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94db011775SGeoffrey Martin-Noble   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95db011775SGeoffrey Martin-Noble   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96db011775SGeoffrey Martin-Noble   return indexingMaps == maps;
97db011775SGeoffrey Martin-Noble }
98f286af29SAlexander Belyaev 
isVecmat(ArrayAttr indexingMaps)999f495098SNatashaKnk bool mlir::isVecmat(ArrayAttr indexingMaps) {
1009f495098SNatashaKnk   if (indexingMaps.size() != 3)
1019f495098SNatashaKnk     return false;
1029f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
1039f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
1049f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
1059f495098SNatashaKnk 
1069f495098SNatashaKnk   if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
1079f495098SNatashaKnk       map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
1089f495098SNatashaKnk       map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
1099f495098SNatashaKnk     return false;
1109f495098SNatashaKnk   }
1119f495098SNatashaKnk 
1129f495098SNatashaKnk   // Extract dimensions for K * KxN -> N
1139f495098SNatashaKnk   AffineExpr k = map0.getResult(0);
1149f495098SNatashaKnk   AffineExpr n = map2.getResult(0);
1159f495098SNatashaKnk   auto *context = indexingMaps.getContext();
1169f495098SNatashaKnk   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
1179f495098SNatashaKnk   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
1189f495098SNatashaKnk   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
1199f495098SNatashaKnk   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
1209f495098SNatashaKnk   return indexingMaps == maps;
1219f495098SNatashaKnk }
1229f495098SNatashaKnk 
isBatchVecmat(ArrayAttr indexingMaps)1238a80e331Sbjacob bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
1248a80e331Sbjacob   if (indexingMaps.size() != 3)
1258a80e331Sbjacob     return false;
1268a80e331Sbjacob   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
1278a80e331Sbjacob   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
1288a80e331Sbjacob   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
1298a80e331Sbjacob 
1308a80e331Sbjacob   if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
1318a80e331Sbjacob       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
1328a80e331Sbjacob       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
1338a80e331Sbjacob     return false;
1348a80e331Sbjacob   }
1358a80e331Sbjacob 
1368a80e331Sbjacob   // Extract dimensions for B*K * B*K*N -> B*N
1378a80e331Sbjacob   AffineExpr b = map0.getResult(0);
1388a80e331Sbjacob   AffineExpr k = map0.getResult(1);
1398a80e331Sbjacob   AffineExpr n = map2.getResult(1);
1408a80e331Sbjacob   auto *context = indexingMaps.getContext();
1418a80e331Sbjacob   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
1428a80e331Sbjacob   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
1438a80e331Sbjacob   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
1448a80e331Sbjacob   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
1458a80e331Sbjacob   return indexingMaps == maps;
1468a80e331Sbjacob }
1478a80e331Sbjacob 
isMatvec(ArrayAttr indexingMaps)1489f495098SNatashaKnk bool mlir::isMatvec(ArrayAttr indexingMaps) {
1499f495098SNatashaKnk   if (indexingMaps.size() != 3)
1509f495098SNatashaKnk     return false;
1519f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
1529f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
1539f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
1549f495098SNatashaKnk 
1559f495098SNatashaKnk   if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
1569f495098SNatashaKnk       map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
1579f495098SNatashaKnk       map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
1589f495098SNatashaKnk     return false;
1599f495098SNatashaKnk   }
1609f495098SNatashaKnk 
1619f495098SNatashaKnk   // Extract dimensions for N*K * K -> N
1629f495098SNatashaKnk   AffineExpr k = map1.getResult(0);
1639f495098SNatashaKnk   AffineExpr n = map2.getResult(0);
1649f495098SNatashaKnk   auto *context = indexingMaps.getContext();
1659f495098SNatashaKnk   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
1669f495098SNatashaKnk   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
1679f495098SNatashaKnk   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
1689f495098SNatashaKnk   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
1699f495098SNatashaKnk   return indexingMaps == maps;
1709f495098SNatashaKnk }
1719f495098SNatashaKnk 
isBatchMatvec(ArrayAttr indexingMaps)1729f495098SNatashaKnk bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
1739f495098SNatashaKnk   if (indexingMaps.size() != 3)
1749f495098SNatashaKnk     return false;
1759f495098SNatashaKnk   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
1769f495098SNatashaKnk   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
1779f495098SNatashaKnk   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
1789f495098SNatashaKnk 
1799f495098SNatashaKnk   if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
1809f495098SNatashaKnk       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
1819f495098SNatashaKnk       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
1829f495098SNatashaKnk     return false;
1839f495098SNatashaKnk   }
1849f495098SNatashaKnk 
1859f495098SNatashaKnk   // Extract dimensions for B*N*K * B*K -> B*N
1869f495098SNatashaKnk   AffineExpr b = map0.getResult(0);
1879f495098SNatashaKnk   AffineExpr k = map1.getResult(1);
1889f495098SNatashaKnk   AffineExpr n = map2.getResult(1);
1899f495098SNatashaKnk   auto *context = indexingMaps.getContext();
1909f495098SNatashaKnk   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
1919f495098SNatashaKnk   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
1929f495098SNatashaKnk   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
1939f495098SNatashaKnk   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
1949f495098SNatashaKnk   return indexingMaps == maps;
1959f495098SNatashaKnk }
1969f495098SNatashaKnk 
clone(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)197f286af29SAlexander Belyaev Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
198f286af29SAlexander Belyaev                        ValueRange newOperands) {
1994d67b278SJeff Niu   IRMapping bvm;
200f286af29SAlexander Belyaev   OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201f286af29SAlexander Belyaev                        op->getAttrs());
202*ccc02563SAviad Cohen   for (Region &r : op->getRegions()) {
203*ccc02563SAviad Cohen     Region *newRegion = state.addRegion();
204*ccc02563SAviad Cohen     b.cloneRegionBefore(r, *newRegion, newRegion->begin(), bvm);
205*ccc02563SAviad Cohen   }
206f286af29SAlexander Belyaev   return b.create(state);
207f286af29SAlexander Belyaev }
208f286af29SAlexander Belyaev 
cloneWithoutRegions(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)209f286af29SAlexander Belyaev Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
210f286af29SAlexander Belyaev                                      TypeRange newResultTypes,
211f286af29SAlexander Belyaev                                      ValueRange newOperands) {
212f286af29SAlexander Belyaev   OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
213f286af29SAlexander Belyaev                        op->getAttrs());
214f286af29SAlexander Belyaev   for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
215f286af29SAlexander Belyaev     state.addRegion();
216f286af29SAlexander Belyaev   return b.create(state);
217f286af29SAlexander Belyaev }
218c21e88ccSMahesh Ravishankar 
219c21e88ccSMahesh Ravishankar SmallVector<NamedAttribute>
getPrunedAttributeList(Operation * op,ArrayRef<StringRef> elidedAttrs)220c21e88ccSMahesh Ravishankar mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
221d0e507f5SMahesh Ravishankar   llvm::StringSet<> elidedAttrsSet;
222c21e88ccSMahesh Ravishankar   elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
223c21e88ccSMahesh Ravishankar   SmallVector<NamedAttribute> attrs;
224c21e88ccSMahesh Ravishankar   for (auto attr : op->getAttrs()) {
225c21e88ccSMahesh Ravishankar     if (elidedAttrsSet.count(attr.getName()))
226c21e88ccSMahesh Ravishankar       continue;
227c21e88ccSMahesh Ravishankar     attrs.push_back(attr);
228c21e88ccSMahesh Ravishankar   }
229c21e88ccSMahesh Ravishankar   return attrs;
230c21e88ccSMahesh Ravishankar }
231