xref: /llvm-project/mlir/lib/Dialect/Utils/IndexingUtils.cpp (revision 9cc11b98a76c9b2f39b84f709566aac6f962f07a)
199ef9eebSMatthias Springer //===- IndexingUtils.cpp - Helpers related to index computations ----------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer 
999ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
10847048f4SDiego Caballero #include "mlir/Dialect/Utils/StaticValueUtils.h"
111b002d27SArnab Dutta #include "mlir/IR/AffineExpr.h"
121b002d27SArnab Dutta #include "mlir/IR/Builders.h"
1399ef9eebSMatthias Springer #include "mlir/IR/BuiltinAttributes.h"
14203fad47SNicolas Vasilache #include "mlir/IR/MLIRContext.h"
15203fad47SNicolas Vasilache #include "llvm/ADT/STLExtras.h"
167a69a9d7SNicolas Vasilache #include <numeric>
17a1fe1f5fSKazu Hirata #include <optional>
187a69a9d7SNicolas Vasilache 
197a69a9d7SNicolas Vasilache using namespace mlir;
207a69a9d7SNicolas Vasilache 
21203fad47SNicolas Vasilache template <typename ExprType>
22203fad47SNicolas Vasilache SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
23203fad47SNicolas Vasilache                                                ExprType unit) {
24203fad47SNicolas Vasilache   if (sizes.empty())
25203fad47SNicolas Vasilache     return {};
26203fad47SNicolas Vasilache   SmallVector<ExprType> strides(sizes.size(), unit);
277a69a9d7SNicolas Vasilache   for (int64_t r = strides.size() - 2; r >= 0; --r)
287a69a9d7SNicolas Vasilache     strides[r] = strides[r + 1] * sizes[r + 1];
297a69a9d7SNicolas Vasilache   return strides;
307a69a9d7SNicolas Vasilache }
317a69a9d7SNicolas Vasilache 
32203fad47SNicolas Vasilache template <typename ExprType>
33203fad47SNicolas Vasilache SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
34203fad47SNicolas Vasilache                                                 ArrayRef<ExprType> v2) {
35203fad47SNicolas Vasilache   // Early exit if both are empty, let zip_equal fail if only 1 is empty.
36203fad47SNicolas Vasilache   if (v1.empty() && v2.empty())
37203fad47SNicolas Vasilache     return {};
38203fad47SNicolas Vasilache   SmallVector<ExprType> result;
39203fad47SNicolas Vasilache   for (auto it : llvm::zip_equal(v1, v2))
407a69a9d7SNicolas Vasilache     result.push_back(std::get<0>(it) * std::get<1>(it));
417a69a9d7SNicolas Vasilache   return result;
427a69a9d7SNicolas Vasilache }
437a69a9d7SNicolas Vasilache 
44203fad47SNicolas Vasilache template <typename ExprType>
45203fad47SNicolas Vasilache ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
46203fad47SNicolas Vasilache                        ExprType zero) {
47203fad47SNicolas Vasilache   assert(offsets.size() == basis.size());
48203fad47SNicolas Vasilache   ExprType linearIndex = zero;
49203fad47SNicolas Vasilache   for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
50203fad47SNicolas Vasilache     linearIndex = linearIndex + offsets[idx] * basis[idx];
51203fad47SNicolas Vasilache   return linearIndex;
52203fad47SNicolas Vasilache }
53203fad47SNicolas Vasilache 
54203fad47SNicolas Vasilache template <typename ExprType, typename DivOpTy>
55203fad47SNicolas Vasilache SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
56203fad47SNicolas Vasilache                                       ArrayRef<ExprType> strides,
57203fad47SNicolas Vasilache                                       DivOpTy divOp) {
58203fad47SNicolas Vasilache   int64_t rank = strides.size();
59203fad47SNicolas Vasilache   SmallVector<ExprType> offsets(rank);
60203fad47SNicolas Vasilache   for (int64_t r = 0; r < rank; ++r) {
61203fad47SNicolas Vasilache     offsets[r] = divOp(linearIndex, strides[r]);
62203fad47SNicolas Vasilache     linearIndex = linearIndex % strides[r];
63203fad47SNicolas Vasilache   }
64203fad47SNicolas Vasilache   return offsets;
65203fad47SNicolas Vasilache }
66203fad47SNicolas Vasilache 
67203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
68203fad47SNicolas Vasilache // Utils that operate on static integer values.
69203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
70203fad47SNicolas Vasilache 
71203fad47SNicolas Vasilache SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
72c65d8c71SGuray Ozen   assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
73203fad47SNicolas Vasilache          "sizes must be nonnegative");
74203fad47SNicolas Vasilache   int64_t unit = 1;
75203fad47SNicolas Vasilache   return ::computeSuffixProductImpl(sizes, unit);
76203fad47SNicolas Vasilache }
77203fad47SNicolas Vasilache 
78203fad47SNicolas Vasilache SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
79203fad47SNicolas Vasilache                                                  ArrayRef<int64_t> v2) {
80203fad47SNicolas Vasilache   return computeElementwiseMulImpl(v1, v2);
81203fad47SNicolas Vasilache }
82203fad47SNicolas Vasilache 
839e54d5e7SNicolas Vasilache int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
849e54d5e7SNicolas Vasilache   assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
859e54d5e7SNicolas Vasilache          "basis must be nonnegative");
869e54d5e7SNicolas Vasilache   if (basis.empty())
879e54d5e7SNicolas Vasilache     return 0;
889e54d5e7SNicolas Vasilache   return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>());
899e54d5e7SNicolas Vasilache }
909e54d5e7SNicolas Vasilache 
919e54d5e7SNicolas Vasilache int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
92203fad47SNicolas Vasilache   assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
93203fad47SNicolas Vasilache          "basis must be nonnegative");
94203fad47SNicolas Vasilache   if (basis.empty())
95a9205c5cSSpenser Bauman     return 1;
96203fad47SNicolas Vasilache   return std::accumulate(basis.begin(), basis.end(), 1,
97203fad47SNicolas Vasilache                          std::multiplies<int64_t>());
98203fad47SNicolas Vasilache }
99203fad47SNicolas Vasilache 
100203fad47SNicolas Vasilache int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
101203fad47SNicolas Vasilache   assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
102203fad47SNicolas Vasilache          "basis must be nonnegative");
103203fad47SNicolas Vasilache   int64_t zero = 0;
104203fad47SNicolas Vasilache   return linearizeImpl(offsets, basis, zero);
105203fad47SNicolas Vasilache }
106203fad47SNicolas Vasilache 
107203fad47SNicolas Vasilache SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
108203fad47SNicolas Vasilache                                        ArrayRef<int64_t> strides) {
109203fad47SNicolas Vasilache   assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
110203fad47SNicolas Vasilache          "strides must be nonnegative");
111203fad47SNicolas Vasilache   return delinearizeImpl(linearIndex, strides,
112203fad47SNicolas Vasilache                          [](int64_t e1, int64_t e2) { return e1 / e2; });
113203fad47SNicolas Vasilache }
114203fad47SNicolas Vasilache 
1150a81ace0SKazu Hirata std::optional<SmallVector<int64_t>>
1167a69a9d7SNicolas Vasilache mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
1177a69a9d7SNicolas Vasilache   if (shape.size() < subShape.size())
1181a36588eSKazu Hirata     return std::nullopt;
1197a69a9d7SNicolas Vasilache   assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
1207a69a9d7SNicolas Vasilache          "shape must be nonnegative");
1217a69a9d7SNicolas Vasilache   assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
1227a69a9d7SNicolas Vasilache          "subShape must be nonnegative");
1237a69a9d7SNicolas Vasilache 
1247a69a9d7SNicolas Vasilache   // Starting from the end, compute the integer divisors.
1257a69a9d7SNicolas Vasilache   std::vector<int64_t> result;
1267a69a9d7SNicolas Vasilache   result.reserve(shape.size());
1277a69a9d7SNicolas Vasilache   for (auto [size, subSize] :
1287a69a9d7SNicolas Vasilache        llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
1297a69a9d7SNicolas Vasilache     // If integral division does not occur, return and let the caller decide.
1307a69a9d7SNicolas Vasilache     if (size % subSize != 0)
1311a36588eSKazu Hirata       return std::nullopt;
1327a69a9d7SNicolas Vasilache     result.push_back(size / subSize);
1337a69a9d7SNicolas Vasilache   }
1347a69a9d7SNicolas Vasilache   // At this point we computed the ratio (in reverse) for the common size.
1357a69a9d7SNicolas Vasilache   // Fill with the remaining entries from the shape (still in reverse).
1367a69a9d7SNicolas Vasilache   int commonSize = subShape.size();
1377a69a9d7SNicolas Vasilache   std::copy(shape.rbegin() + commonSize, shape.rend(),
1387a69a9d7SNicolas Vasilache             std::back_inserter(result));
1397a69a9d7SNicolas Vasilache   // Reverse again to get it back in the proper order and return.
1407a69a9d7SNicolas Vasilache   return SmallVector<int64_t>{result.rbegin(), result.rend()};
1417a69a9d7SNicolas Vasilache }
1427a69a9d7SNicolas Vasilache 
143203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
144203fad47SNicolas Vasilache // Utils that operate on AffineExpr.
145203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
146203fad47SNicolas Vasilache 
147203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
148203fad47SNicolas Vasilache   if (sizes.empty())
149203fad47SNicolas Vasilache     return {};
150203fad47SNicolas Vasilache   AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
151203fad47SNicolas Vasilache   return ::computeSuffixProductImpl(sizes, unit);
15299ef9eebSMatthias Springer }
15399ef9eebSMatthias Springer 
154203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
155203fad47SNicolas Vasilache                                                     ArrayRef<AffineExpr> v2) {
156203fad47SNicolas Vasilache   return computeElementwiseMulImpl(v1, v2);
15799ef9eebSMatthias Springer }
15899ef9eebSMatthias Springer 
1599e54d5e7SNicolas Vasilache AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
1609e54d5e7SNicolas Vasilache   if (basis.empty())
1619e54d5e7SNicolas Vasilache     return getAffineConstantExpr(0, ctx);
1629e54d5e7SNicolas Vasilache   return std::accumulate(basis.begin(), basis.end(),
163a3cd2eebSNicolas Vasilache                          getAffineConstantExpr(0, ctx),
1649e54d5e7SNicolas Vasilache                          std::plus<AffineExpr>());
1659e54d5e7SNicolas Vasilache }
1669e54d5e7SNicolas Vasilache 
1679e54d5e7SNicolas Vasilache AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
1687a69a9d7SNicolas Vasilache   if (basis.empty())
169a3cd2eebSNicolas Vasilache     return getAffineConstantExpr(1, ctx);
170203fad47SNicolas Vasilache   return std::accumulate(basis.begin(), basis.end(),
171203fad47SNicolas Vasilache                          getAffineConstantExpr(1, ctx),
172203fad47SNicolas Vasilache                          std::multiplies<AffineExpr>());
1737a69a9d7SNicolas Vasilache }
1747a69a9d7SNicolas Vasilache 
175203fad47SNicolas Vasilache AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
176203fad47SNicolas Vasilache                            ArrayRef<AffineExpr> basis) {
177203fad47SNicolas Vasilache   AffineExpr zero = getAffineConstantExpr(0, ctx);
178203fad47SNicolas Vasilache   return linearizeImpl(offsets, basis, zero);
179203fad47SNicolas Vasilache }
180203fad47SNicolas Vasilache 
181203fad47SNicolas Vasilache AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
182203fad47SNicolas Vasilache                            ArrayRef<int64_t> basis) {
183831041beSChristopher Bate 
184831041beSChristopher Bate   return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
185203fad47SNicolas Vasilache }
186203fad47SNicolas Vasilache 
187203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
188203fad47SNicolas Vasilache                                           ArrayRef<AffineExpr> strides) {
189203fad47SNicolas Vasilache   return delinearizeImpl(
190203fad47SNicolas Vasilache       linearIndex, strides,
191203fad47SNicolas Vasilache       [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
192203fad47SNicolas Vasilache }
193203fad47SNicolas Vasilache 
194203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
195203fad47SNicolas Vasilache                                           ArrayRef<int64_t> strides) {
196203fad47SNicolas Vasilache   MLIRContext *ctx = linearIndex.getContext();
197831041beSChristopher Bate   return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
198203fad47SNicolas Vasilache }
199203fad47SNicolas Vasilache 
200203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
201203fad47SNicolas Vasilache // Permutation utils.
202203fad47SNicolas Vasilache //===----------------------------------------------------------------------===//
203203fad47SNicolas Vasilache 
204203fad47SNicolas Vasilache SmallVector<int64_t>
205b1d3afc9SHanhan Wang mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
206203fad47SNicolas Vasilache   assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
207203fad47SNicolas Vasilache          "permutation must be non-negative");
208b1d3afc9SHanhan Wang   SmallVector<int64_t> inversion(permutation.size());
209b1d3afc9SHanhan Wang   for (const auto &pos : llvm::enumerate(permutation)) {
210b1d3afc9SHanhan Wang     inversion[pos.value()] = pos.index();
211b1d3afc9SHanhan Wang   }
212b1d3afc9SHanhan Wang   return inversion;
213b1d3afc9SHanhan Wang }
214b1d3afc9SHanhan Wang 
2152472c45bSHan-Chung Wang bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
2162472c45bSHan-Chung Wang   for (auto i : llvm::seq<int64_t>(0, permutation.size()))
2172472c45bSHan-Chung Wang     if (permutation[i] != i)
2182472c45bSHan-Chung Wang       return false;
2192472c45bSHan-Chung Wang   return true;
2202472c45bSHan-Chung Wang }
2212472c45bSHan-Chung Wang 
222b1d3afc9SHanhan Wang bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
223203fad47SNicolas Vasilache   assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
224203fad47SNicolas Vasilache          "permutation must be non-negative");
225b1d3afc9SHanhan Wang   llvm::SmallDenseSet<int64_t, 4> seenVals;
226b1d3afc9SHanhan Wang   for (auto val : interchange) {
227b1d3afc9SHanhan Wang     if (seenVals.count(val))
228b1d3afc9SHanhan Wang       return false;
229b1d3afc9SHanhan Wang     seenVals.insert(val);
230b1d3afc9SHanhan Wang   }
231b1d3afc9SHanhan Wang   return seenVals.size() == interchange.size();
232b1d3afc9SHanhan Wang }
233b1d3afc9SHanhan Wang 
2340bfbecf5SQuentin Colombet SmallVector<int64_t>
2350bfbecf5SQuentin Colombet mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
2360bfbecf5SQuentin Colombet                                ArrayRef<int64_t> desiredPositions) {
2370bfbecf5SQuentin Colombet   SmallVector<int64_t> res(permSize, -1);
2380bfbecf5SQuentin Colombet   DenseSet<int64_t> seen;
2390bfbecf5SQuentin Colombet   for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
2400bfbecf5SQuentin Colombet     res[desiredPos] = pos;
2410bfbecf5SQuentin Colombet     seen.insert(pos);
2420bfbecf5SQuentin Colombet   }
2430bfbecf5SQuentin Colombet   int64_t nextPos = 0;
2440bfbecf5SQuentin Colombet   for (int64_t &entry : res) {
2450bfbecf5SQuentin Colombet     if (entry != -1)
2460bfbecf5SQuentin Colombet       continue;
2470bfbecf5SQuentin Colombet     while (seen.contains(nextPos))
2480bfbecf5SQuentin Colombet       ++nextPos;
2490bfbecf5SQuentin Colombet     entry = nextPos;
2500bfbecf5SQuentin Colombet     ++nextPos;
2510bfbecf5SQuentin Colombet   }
2520bfbecf5SQuentin Colombet   return res;
2530bfbecf5SQuentin Colombet }
2540bfbecf5SQuentin Colombet 
255*9cc11b98Sdonald chen SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
256*9cc11b98Sdonald chen                                     ArrayRef<int64_t> dropPositions) {
257*9cc11b98Sdonald chen   assert(inputPerm.size() >= dropPositions.size() &&
258*9cc11b98Sdonald chen          "expect inputPerm size large than position to drop");
259*9cc11b98Sdonald chen   SmallVector<int64_t> res;
260*9cc11b98Sdonald chen   unsigned permSize = inputPerm.size();
261*9cc11b98Sdonald chen   for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
262*9cc11b98Sdonald chen     int64_t targetIndex = inputPerm[inputIndex];
263*9cc11b98Sdonald chen     bool shouldDrop = false;
264*9cc11b98Sdonald chen     unsigned dropSize = dropPositions.size();
265*9cc11b98Sdonald chen     for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
266*9cc11b98Sdonald chen       if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
267*9cc11b98Sdonald chen         shouldDrop = true;
268*9cc11b98Sdonald chen         break;
269*9cc11b98Sdonald chen       }
270*9cc11b98Sdonald chen       if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
271*9cc11b98Sdonald chen         targetIndex--;
272*9cc11b98Sdonald chen       }
273*9cc11b98Sdonald chen     }
274*9cc11b98Sdonald chen     if (!shouldDrop) {
275*9cc11b98Sdonald chen       res.push_back(targetIndex);
276*9cc11b98Sdonald chen     }
277*9cc11b98Sdonald chen   }
278*9cc11b98Sdonald chen   return res;
279*9cc11b98Sdonald chen }
280*9cc11b98Sdonald chen 
281203fad47SNicolas Vasilache SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
28299ef9eebSMatthias Springer                                           unsigned dropFront,
28399ef9eebSMatthias Springer                                           unsigned dropBack) {
28499ef9eebSMatthias Springer   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
28599ef9eebSMatthias Springer   auto range = arrayAttr.getAsRange<IntegerAttr>();
2867a69a9d7SNicolas Vasilache   SmallVector<int64_t> res;
28799ef9eebSMatthias Springer   res.reserve(arrayAttr.size() - dropFront - dropBack);
28899ef9eebSMatthias Springer   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
28999ef9eebSMatthias Springer        it != eit; ++it)
29099ef9eebSMatthias Springer     res.push_back((*it).getValue().getSExtValue());
29199ef9eebSMatthias Springer   return res;
29299ef9eebSMatthias Springer }
293793ee2bfSIvan Butygin 
294793ee2bfSIvan Butygin // TODO: do we have any common utily for this?
295793ee2bfSIvan Butygin static MLIRContext *getContext(OpFoldResult val) {
296793ee2bfSIvan Butygin   assert(val && "Invalid value");
297793ee2bfSIvan Butygin   if (auto attr = dyn_cast<Attribute>(val)) {
298793ee2bfSIvan Butygin     return attr.getContext();
299793ee2bfSIvan Butygin   }
3008383bf23SMehdi Amini   return cast<Value>(val).getContext();
301793ee2bfSIvan Butygin }
302793ee2bfSIvan Butygin 
303793ee2bfSIvan Butygin std::pair<AffineExpr, SmallVector<OpFoldResult>>
304793ee2bfSIvan Butygin mlir::computeLinearIndex(OpFoldResult sourceOffset,
305793ee2bfSIvan Butygin                          ArrayRef<OpFoldResult> strides,
306793ee2bfSIvan Butygin                          ArrayRef<OpFoldResult> indices) {
307793ee2bfSIvan Butygin   assert(strides.size() == indices.size());
308793ee2bfSIvan Butygin   auto sourceRank = static_cast<unsigned>(strides.size());
309793ee2bfSIvan Butygin 
310793ee2bfSIvan Butygin   // Hold the affine symbols and values for the computation of the offset.
311793ee2bfSIvan Butygin   SmallVector<OpFoldResult> values(2 * sourceRank + 1);
312793ee2bfSIvan Butygin   SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
313793ee2bfSIvan Butygin 
314793ee2bfSIvan Butygin   bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
315793ee2bfSIvan Butygin   AffineExpr expr = symbols.front();
316793ee2bfSIvan Butygin   values[0] = sourceOffset;
317793ee2bfSIvan Butygin 
318793ee2bfSIvan Butygin   for (unsigned i = 0; i < sourceRank; ++i) {
319793ee2bfSIvan Butygin     // Compute the stride.
320793ee2bfSIvan Butygin     OpFoldResult origStride = strides[i];
321793ee2bfSIvan Butygin 
322793ee2bfSIvan Butygin     // Build up the computation of the offset.
323793ee2bfSIvan Butygin     unsigned baseIdxForDim = 1 + 2 * i;
324793ee2bfSIvan Butygin     unsigned subOffsetForDim = baseIdxForDim;
325793ee2bfSIvan Butygin     unsigned origStrideForDim = baseIdxForDim + 1;
326793ee2bfSIvan Butygin     expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
327793ee2bfSIvan Butygin     values[subOffsetForDim] = indices[i];
328793ee2bfSIvan Butygin     values[origStrideForDim] = origStride;
329793ee2bfSIvan Butygin   }
330793ee2bfSIvan Butygin 
331793ee2bfSIvan Butygin   return {expr, values};
332793ee2bfSIvan Butygin }
333831041beSChristopher Bate 
334847048f4SDiego Caballero std::pair<AffineExpr, SmallVector<OpFoldResult>>
335847048f4SDiego Caballero mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
336847048f4SDiego Caballero                          ArrayRef<Value> indices) {
337847048f4SDiego Caballero   return computeLinearIndex(
338847048f4SDiego Caballero       sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
339847048f4SDiego Caballero       getAsOpFoldResult(ValueRange(indices)));
340847048f4SDiego Caballero }
341847048f4SDiego Caballero 
342831041beSChristopher Bate //===----------------------------------------------------------------------===//
343831041beSChristopher Bate // TileOffsetRange
344831041beSChristopher Bate //===----------------------------------------------------------------------===//
345831041beSChristopher Bate 
346831041beSChristopher Bate /// Apply left-padding by 1 to the tile shape if required.
347831041beSChristopher Bate static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
348831041beSChristopher Bate                                                unsigned paddedSize) {
349831041beSChristopher Bate   assert(tileShape.size() <= paddedSize &&
350831041beSChristopher Bate          "expected tileShape to <= paddedSize");
351831041beSChristopher Bate   if (tileShape.size() == paddedSize)
352831041beSChristopher Bate     return to_vector(tileShape);
353831041beSChristopher Bate   SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
354831041beSChristopher Bate   llvm::append_range(result, tileShape);
355831041beSChristopher Bate   return result;
356831041beSChristopher Bate }
357831041beSChristopher Bate 
358831041beSChristopher Bate mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
359831041beSChristopher Bate     ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
360831041beSChristopher Bate     ArrayRef<int64_t> loopOrder)
361831041beSChristopher Bate     : tileShape(padTileShapeToSize(tileShape, shape.size())),
362831041beSChristopher Bate       inverseLoopOrder(invertPermutationVector(loopOrder)),
363831041beSChristopher Bate       sliceStrides(shape.size()) {
364831041beSChristopher Bate   // Divide the shape by the tile shape.
365831041beSChristopher Bate   std::optional<SmallVector<int64_t>> shapeRatio =
366831041beSChristopher Bate       mlir::computeShapeRatio(shape, tileShape);
367831041beSChristopher Bate   assert(shapeRatio && shapeRatio->size() == shape.size() &&
368831041beSChristopher Bate          "target shape does not evenly divide the original shape");
369831041beSChristopher Bate   assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
370831041beSChristopher Bate          "expected loop order to be a permutation of rank equal to outer "
371831041beSChristopher Bate          "shape");
372831041beSChristopher Bate 
373831041beSChristopher Bate   maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
374831041beSChristopher Bate   mlir::applyPermutationToVector(*shapeRatio, loopOrder);
375831041beSChristopher Bate   sliceStrides = mlir::computeStrides(*shapeRatio);
376831041beSChristopher Bate }
377831041beSChristopher Bate 
378831041beSChristopher Bate SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
379831041beSChristopher Bate     int64_t linearIndex) const {
380831041beSChristopher Bate   SmallVector<int64_t> tileCoords = applyPermutation(
381831041beSChristopher Bate       delinearize(linearIndex, sliceStrides), inverseLoopOrder);
382831041beSChristopher Bate   return computeElementwiseMul(tileCoords, tileShape);
383831041beSChristopher Bate }
384831041beSChristopher Bate 
385831041beSChristopher Bate SmallVector<AffineExpr>
386831041beSChristopher Bate mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
387831041beSChristopher Bate     AffineExpr linearIndex) const {
388831041beSChristopher Bate   MLIRContext *ctx = linearIndex.getContext();
389831041beSChristopher Bate   SmallVector<AffineExpr> tileCoords = applyPermutation(
390831041beSChristopher Bate       delinearize(linearIndex, sliceStrides), inverseLoopOrder);
391831041beSChristopher Bate   return mlir::computeElementwiseMul(tileCoords,
392831041beSChristopher Bate                                      getAffineConstantExprs(tileShape, ctx));
393831041beSChristopher Bate }
394