1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinAttributes.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "llvm/ADT/StringSet.h"
15
16 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17
18 using namespace mlir;
19
isRowMajorMatmul(ArrayAttr indexingMaps)20 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21 if (indexingMaps.size() != 3)
22 return false;
23
24 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27
28 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31 return false;
32 }
33
34 // Extract dimensions for MxK * KxN -> MxN
35 AffineExpr m = map2.getResult(0);
36 AffineExpr n = map2.getResult(1);
37 AffineExpr k = map0.getResult(1);
38 auto *context = indexingMaps.getContext();
39 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43 return indexingMaps == maps;
44 }
45
isColumnMajorMatmul(ArrayAttr indexingMaps)46 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47 if (indexingMaps.size() != 3)
48 return false;
49
50 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53
54 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57 return false;
58 }
59
60 // Extract dimensions for KxM * NxK -> NxM
61 AffineExpr n = map2.getResult(0);
62 AffineExpr m = map2.getResult(1);
63 AffineExpr k = map0.getResult(0);
64 auto *context = indexingMaps.getContext();
65 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69 return indexingMaps == maps;
70 }
71
isRowMajorBatchMatmul(ArrayAttr indexingMaps)72 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73 if (indexingMaps.size() != 3)
74 return false;
75
76 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79
80 if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81 map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82 map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83 return false;
84 }
85
86 // Extract dimensions for BxMxK * BxKxN -> BxMxN
87 AffineExpr b = map2.getResult(0);
88 AffineExpr m = map2.getResult(1);
89 AffineExpr n = map2.getResult(2);
90 AffineExpr k = map0.getResult(2);
91 auto *context = indexingMaps.getContext();
92 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96 return indexingMaps == maps;
97 }
98
isVecmat(ArrayAttr indexingMaps)99 bool mlir::isVecmat(ArrayAttr indexingMaps) {
100 if (indexingMaps.size() != 3)
101 return false;
102 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105
106 if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109 return false;
110 }
111
112 // Extract dimensions for K * KxN -> N
113 AffineExpr k = map0.getResult(0);
114 AffineExpr n = map2.getResult(0);
115 auto *context = indexingMaps.getContext();
116 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120 return indexingMaps == maps;
121 }
122
isBatchVecmat(ArrayAttr indexingMaps)123 bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124 if (indexingMaps.size() != 3)
125 return false;
126 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129
130 if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133 return false;
134 }
135
136 // Extract dimensions for B*K * B*K*N -> B*N
137 AffineExpr b = map0.getResult(0);
138 AffineExpr k = map0.getResult(1);
139 AffineExpr n = map2.getResult(1);
140 auto *context = indexingMaps.getContext();
141 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145 return indexingMaps == maps;
146 }
147
isMatvec(ArrayAttr indexingMaps)148 bool mlir::isMatvec(ArrayAttr indexingMaps) {
149 if (indexingMaps.size() != 3)
150 return false;
151 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
152 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
153 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
154
155 if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
156 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
157 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
158 return false;
159 }
160
161 // Extract dimensions for N*K * K -> N
162 AffineExpr k = map1.getResult(0);
163 AffineExpr n = map2.getResult(0);
164 auto *context = indexingMaps.getContext();
165 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
166 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
167 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
168 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169 return indexingMaps == maps;
170 }
171
isBatchMatvec(ArrayAttr indexingMaps)172 bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
173 if (indexingMaps.size() != 3)
174 return false;
175 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
176 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
177 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
178
179 if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
180 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
181 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
182 return false;
183 }
184
185 // Extract dimensions for B*N*K * B*K -> B*N
186 AffineExpr b = map0.getResult(0);
187 AffineExpr k = map1.getResult(1);
188 AffineExpr n = map2.getResult(1);
189 auto *context = indexingMaps.getContext();
190 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
191 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
192 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
193 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
194 return indexingMaps == maps;
195 }
196
clone(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)197 Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
198 ValueRange newOperands) {
199 IRMapping bvm;
200 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201 op->getAttrs());
202 for (Region &r : op->getRegions()) {
203 Region *newRegion = state.addRegion();
204 b.cloneRegionBefore(r, *newRegion, newRegion->begin(), bvm);
205 }
206 return b.create(state);
207 }
208
cloneWithoutRegions(OpBuilder & b,Operation * op,TypeRange newResultTypes,ValueRange newOperands)209 Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
210 TypeRange newResultTypes,
211 ValueRange newOperands) {
212 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
213 op->getAttrs());
214 for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
215 state.addRegion();
216 return b.create(state);
217 }
218
219 SmallVector<NamedAttribute>
getPrunedAttributeList(Operation * op,ArrayRef<StringRef> elidedAttrs)220 mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
221 llvm::StringSet<> elidedAttrsSet;
222 elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
223 SmallVector<NamedAttribute> attrs;
224 for (auto attr : op->getAttrs()) {
225 if (elidedAttrsSet.count(attr.getName()))
226 continue;
227 attrs.push_back(attr);
228 }
229 return attrs;
230 }
231