xref: /llvm-project/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (revision ccc02563f4d620d4d29a1cbd2c463871cc54745b)
1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinAttributes.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "llvm/ADT/StringSet.h"
15 
16 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17 
18 using namespace mlir;
19 
isRowMajorMatmul(ArrayAttr indexingMaps)20 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21   if (indexingMaps.size() != 3)
22     return false;
23 
24   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27 
28   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31     return false;
32   }
33 
34   // Extract dimensions for MxK * KxN -> MxN
35   AffineExpr m = map2.getResult(0);
36   AffineExpr n = map2.getResult(1);
37   AffineExpr k = map0.getResult(1);
38   auto *context = indexingMaps.getContext();
39   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43   return indexingMaps == maps;
44 }
45 
isColumnMajorMatmul(ArrayAttr indexingMaps)46 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47   if (indexingMaps.size() != 3)
48     return false;
49 
50   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53 
54   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57     return false;
58   }
59 
60   // Extract dimensions for KxM * NxK -> NxM
61   AffineExpr n = map2.getResult(0);
62   AffineExpr m = map2.getResult(1);
63   AffineExpr k = map0.getResult(0);
64   auto *context = indexingMaps.getContext();
65   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69   return indexingMaps == maps;
70 }
71 
isRowMajorBatchMatmul(ArrayAttr indexingMaps)72 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73   if (indexingMaps.size() != 3)
74     return false;
75 
76   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79 
80   if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81       map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82       map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83     return false;
84   }
85 
86   // Extract dimensions for BxMxK * BxKxN -> BxMxN
87   AffineExpr b = map2.getResult(0);
88   AffineExpr m = map2.getResult(1);
89   AffineExpr n = map2.getResult(2);
90   AffineExpr k = map0.getResult(2);
91   auto *context = indexingMaps.getContext();
92   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96   return indexingMaps == maps;
97 }
98 
isVecmat(ArrayAttr indexingMaps)99 bool mlir::isVecmat(ArrayAttr indexingMaps) {
100   if (indexingMaps.size() != 3)
101     return false;
102   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105 
106   if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107       map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108       map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109     return false;
110   }
111 
112   // Extract dimensions for K * KxN -> N
113   AffineExpr k = map0.getResult(0);
114   AffineExpr n = map2.getResult(0);
115   auto *context = indexingMaps.getContext();
116   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120   return indexingMaps == maps;
121 }
122 
isBatchVecmat(ArrayAttr indexingMaps)123 bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124   if (indexingMaps.size() != 3)
125     return false;
126   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129 
130   if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133     return false;
134   }
135 
136   // Extract dimensions for B*K * B*K*N -> B*N
137   AffineExpr b = map0.getResult(0);
138   AffineExpr k = map0.getResult(1);
139   AffineExpr n = map2.getResult(1);
140   auto *context = indexingMaps.getContext();
141   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145   return indexingMaps == maps;
146 }
147 
isMatvec(ArrayAttr indexingMaps)148 bool mlir::isMatvec(ArrayAttr indexingMaps) {
149   if (indexingMaps.size() != 3)
150     return false;
151   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
152   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
153   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
154 
155   if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
156       map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
157       map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
158     return false;
159   }
160 
161   // Extract dimensions for N*K * K -> N
162   AffineExpr k = map1.getResult(0);
163   AffineExpr n = map2.getResult(0);
164   auto *context = indexingMaps.getContext();
165   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
166   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
167   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
168   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169   return indexingMaps == maps;
170 }
171 
isBatchMatvec(ArrayAttr indexingMaps)172 bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
173   if (indexingMaps.size() != 3)
174     return false;
175   AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
176   AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
177   AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
178 
179   if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
180       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
181       map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
182     return false;
183   }
184 
185   // Extract dimensions for B*N*K * B*K -> B*N
186   AffineExpr b = map0.getResult(0);
187   AffineExpr k = map1.getResult(1);
188   AffineExpr n = map2.getResult(1);
189   auto *context = indexingMaps.getContext();
190   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
191   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
192   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
193   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
194   return indexingMaps == maps;
195 }
196 
clone(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)197 Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
198                        ValueRange newOperands) {
199   IRMapping bvm;
200   OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201                        op->getAttrs());
202   for (Region &r : op->getRegions()) {
203     Region *newRegion = state.addRegion();
204     b.cloneRegionBefore(r, *newRegion, newRegion->begin(), bvm);
205   }
206   return b.create(state);
207 }
208 
cloneWithoutRegions(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)209 Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
210                                      TypeRange newResultTypes,
211                                      ValueRange newOperands) {
212   OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
213                        op->getAttrs());
214   for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
215     state.addRegion();
216   return b.create(state);
217 }
218 
219 SmallVector<NamedAttribute>
getPrunedAttributeList(Operation * op,ArrayRef<StringRef> elidedAttrs)220 mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
221   llvm::StringSet<> elidedAttrsSet;
222   elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
223   SmallVector<NamedAttribute> attrs;
224   for (auto attr : op->getAttrs()) {
225     if (elidedAttrsSet.count(attr.getName()))
226       continue;
227     attrs.push_back(attr);
228   }
229   return attrs;
230 }
231