xref: /llvm-project/mlir/lib/Dialect/Utils/IndexingUtils.cpp (revision 9cc11b98a76c9b2f39b84f709566aac6f962f07a)
//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;

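/// Compute the suffix products of `sizes`, i.e. the row-major strides of a
/// tensor of shape `sizes`, with `unit` as the innermost stride.
/// E.g. sizes = [2, 3, 4] with unit = 1 yields strides [12, 4, 1].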
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

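/// Multiply `v1` and `v2` elementwise; zip_equal asserts that both vectors
/// have the same size. E.g. [2, 3] * [4, 5] == [8, 15].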
template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early-exit if both are empty; let zip_equal assert if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

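/// Accumulate zero + offsets[0] * basis[0] + ... + offsets[n-1] * basis[n-1].
/// E.g. offsets = [1, 2] with basis = [4, 1] linearizes to 6.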
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size() &&
         "expected offsets and basis of the same size");
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

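/// Decompose `linearIndex` into per-dimension offsets along `strides`, using
/// `divOp` as the division. For in-bounds indices this inverts linearizeImpl:
/// delinearizing 6 along strides [4, 1] yields [1, 2].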
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

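/// Row-major strides of a tensor of shape `sizes`.
/// E.g. computeSuffixProduct({2, 3, 4}) == {12, 4, 1}.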
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

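/// Sum of all basis elements; 0 for an empty basis.
/// E.g. computeSum({2, 3, 4}) == 9.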
int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(0),
                         std::plus<int64_t>());
}

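/// Product of all basis elements; 1 for an empty basis.
/// E.g. computeProduct({2, 3, 4}) == 24.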
int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 1;
  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

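/// E.g. linearize({1, 2}, {4, 1}) == 6, the linear offset of element (1, 2)
/// in a row-major 3x4 layout.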
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

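/// Inverse of linearize for in-bounds indices.
/// E.g. delinearize(6, {4, 1}) == {1, 2}.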
SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

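/// E.g. computeShapeRatio({8, 6, 4}, {3, 2}) == SmallVector<int64_t>{8, 2, 2},
/// while computeShapeRatio({8, 5}, {2}) == std::nullopt since 5 is not a
/// multiple of 2.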
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If the division is not integral, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we have computed the ratio (in reverse) over the common
  // suffix. Fill in the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get the result back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

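// The utilities below mirror the static integer versions above but build
// AffineExpr trees, so sizes, strides and offsets may be symbolic.
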
SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

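/// E.g. invertPermutationVector({1, 2, 0}) == {2, 0, 1}: the i-th result
/// entry is the position at which value i occurs in `permutation`.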
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

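/// Note that only distinctness and non-negativity are checked, so entries
/// need not be smaller than interchange.size(): {1, 2, 0} and {0, 2} both
/// pass, {0, 0, 1} does not.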
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

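/// Build the permutation of size `permSize` that moves each positions[i] to
/// desiredPositions[i] and fills the remaining slots with the unused values
/// in increasing order.
/// E.g. computePermutationVector(5, {1, 3}, {4, 0}) == {3, 0, 2, 4, 1}.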
SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  // Pin the requested positions to their desired slots.
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  // Fill the remaining slots with the unused values in increasing order.
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

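/// Drop the values listed in `dropPositions` from `inputPerm` and renumber
/// the remaining entries so the result is a permutation of the smaller rank.
/// E.g. dropDims({2, 0, 3, 1}, {1}) == {1, 0, 2}.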
SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expected inputPerm size to be >= number of positions to drop");
  SmallVector<int64_t> res;
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      // Every dropped value smaller than this entry shifts it down by one.
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}

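/// E.g. for arrayAttr == [0, 1, 2, 3], dropFront == 1 and dropBack == 1,
/// the result is {1, 2}.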
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

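/// The result expresses offset + sum_i(indices[i] * strides[i]) symbolically:
/// the returned AffineExpr is s0 + s1 * s2 + s3 * s4 + ... over 2n+1 symbols,
/// and the returned values line up as
/// [sourceOffset, indices[0], strides[0], ..., indices[n-1], strides[n-1]].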
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size() &&
         "expected as many strides as indices");
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Retrieve the stride for this dimension.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(getContext(sourceOffset), strides),
      getAsOpFoldResult(ValueRange(indices)));
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
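/// E.g. padTileShapeToSize({2, 3}, 4) == {1, 1, 2, 3}.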
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

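/// Precompute the iteration space: the shape ratio gives the number of tiles
/// along each dimension, and the suffix products of that ratio (permuted by
/// `loopOrder`) give the strides used to map a linear iteration index back to
/// a per-dimension tile coordinate.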
mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

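/// E.g. for shape = {4, 4}, tileShape = {2, 2} and identity loop order,
/// sliceStrides = {2, 1}; linearIndex = 3 delinearizes to tile coordinates
/// {1, 1} and hence to element offsets {2, 2}.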
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}
394