199ef9eebSMatthias Springer //===- IndexingUtils.cpp - Helpers related to index computations ----------===// 299ef9eebSMatthias Springer // 399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 699ef9eebSMatthias Springer // 799ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 899ef9eebSMatthias Springer 999ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h" 10847048f4SDiego Caballero #include "mlir/Dialect/Utils/StaticValueUtils.h" 111b002d27SArnab Dutta #include "mlir/IR/AffineExpr.h" 121b002d27SArnab Dutta #include "mlir/IR/Builders.h" 1399ef9eebSMatthias Springer #include "mlir/IR/BuiltinAttributes.h" 14203fad47SNicolas Vasilache #include "mlir/IR/MLIRContext.h" 15203fad47SNicolas Vasilache #include "llvm/ADT/STLExtras.h" 167a69a9d7SNicolas Vasilache #include <numeric> 17a1fe1f5fSKazu Hirata #include <optional> 187a69a9d7SNicolas Vasilache 197a69a9d7SNicolas Vasilache using namespace mlir; 207a69a9d7SNicolas Vasilache 21203fad47SNicolas Vasilache template <typename ExprType> 22203fad47SNicolas Vasilache SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes, 23203fad47SNicolas Vasilache ExprType unit) { 24203fad47SNicolas Vasilache if (sizes.empty()) 25203fad47SNicolas Vasilache return {}; 26203fad47SNicolas Vasilache SmallVector<ExprType> strides(sizes.size(), unit); 277a69a9d7SNicolas Vasilache for (int64_t r = strides.size() - 2; r >= 0; --r) 287a69a9d7SNicolas Vasilache strides[r] = strides[r + 1] * sizes[r + 1]; 297a69a9d7SNicolas Vasilache return strides; 307a69a9d7SNicolas Vasilache } 317a69a9d7SNicolas Vasilache 32203fad47SNicolas Vasilache template <typename ExprType> 33203fad47SNicolas Vasilache SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1, 34203fad47SNicolas Vasilache ArrayRef<ExprType> v2) { 35203fad47SNicolas Vasilache // Early exit if both are empty, let zip_equal fail if only 1 is empty. 36203fad47SNicolas Vasilache if (v1.empty() && v2.empty()) 37203fad47SNicolas Vasilache return {}; 38203fad47SNicolas Vasilache SmallVector<ExprType> result; 39203fad47SNicolas Vasilache for (auto it : llvm::zip_equal(v1, v2)) 407a69a9d7SNicolas Vasilache result.push_back(std::get<0>(it) * std::get<1>(it)); 417a69a9d7SNicolas Vasilache return result; 427a69a9d7SNicolas Vasilache } 437a69a9d7SNicolas Vasilache 44203fad47SNicolas Vasilache template <typename ExprType> 45203fad47SNicolas Vasilache ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis, 46203fad47SNicolas Vasilache ExprType zero) { 47203fad47SNicolas Vasilache assert(offsets.size() == basis.size()); 48203fad47SNicolas Vasilache ExprType linearIndex = zero; 49203fad47SNicolas Vasilache for (unsigned idx = 0, e = basis.size(); idx < e; ++idx) 50203fad47SNicolas Vasilache linearIndex = linearIndex + offsets[idx] * basis[idx]; 51203fad47SNicolas Vasilache return linearIndex; 52203fad47SNicolas Vasilache } 53203fad47SNicolas Vasilache 54203fad47SNicolas Vasilache template <typename ExprType, typename DivOpTy> 55203fad47SNicolas Vasilache SmallVector<ExprType> delinearizeImpl(ExprType linearIndex, 56203fad47SNicolas Vasilache ArrayRef<ExprType> strides, 57203fad47SNicolas Vasilache DivOpTy divOp) { 58203fad47SNicolas Vasilache int64_t rank = strides.size(); 59203fad47SNicolas Vasilache SmallVector<ExprType> offsets(rank); 60203fad47SNicolas Vasilache for (int64_t r = 0; r < rank; ++r) { 61203fad47SNicolas Vasilache offsets[r] = divOp(linearIndex, strides[r]); 62203fad47SNicolas Vasilache linearIndex = linearIndex % strides[r]; 63203fad47SNicolas Vasilache } 64203fad47SNicolas Vasilache return offsets; 65203fad47SNicolas Vasilache } 66203fad47SNicolas Vasilache 67203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 68203fad47SNicolas Vasilache // Utils that operate on static integer values. 69203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 70203fad47SNicolas Vasilache 71203fad47SNicolas Vasilache SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) { 72c65d8c71SGuray Ozen assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) && 73203fad47SNicolas Vasilache "sizes must be nonnegative"); 74203fad47SNicolas Vasilache int64_t unit = 1; 75203fad47SNicolas Vasilache return ::computeSuffixProductImpl(sizes, unit); 76203fad47SNicolas Vasilache } 77203fad47SNicolas Vasilache 78203fad47SNicolas Vasilache SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1, 79203fad47SNicolas Vasilache ArrayRef<int64_t> v2) { 80203fad47SNicolas Vasilache return computeElementwiseMulImpl(v1, v2); 81203fad47SNicolas Vasilache } 82203fad47SNicolas Vasilache 839e54d5e7SNicolas Vasilache int64_t mlir::computeSum(ArrayRef<int64_t> basis) { 849e54d5e7SNicolas Vasilache assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && 859e54d5e7SNicolas Vasilache "basis must be nonnegative"); 869e54d5e7SNicolas Vasilache if (basis.empty()) 879e54d5e7SNicolas Vasilache return 0; 889e54d5e7SNicolas Vasilache return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>()); 899e54d5e7SNicolas Vasilache } 909e54d5e7SNicolas Vasilache 919e54d5e7SNicolas Vasilache int64_t mlir::computeProduct(ArrayRef<int64_t> basis) { 92203fad47SNicolas Vasilache assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && 93203fad47SNicolas Vasilache "basis must be nonnegative"); 94203fad47SNicolas Vasilache if (basis.empty()) 95a9205c5cSSpenser Bauman return 1; 96203fad47SNicolas Vasilache return std::accumulate(basis.begin(), basis.end(), 1, 97203fad47SNicolas Vasilache std::multiplies<int64_t>()); 98203fad47SNicolas Vasilache } 99203fad47SNicolas Vasilache 100203fad47SNicolas Vasilache int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) { 101203fad47SNicolas Vasilache assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && 102203fad47SNicolas Vasilache "basis must be nonnegative"); 103203fad47SNicolas Vasilache int64_t zero = 0; 104203fad47SNicolas Vasilache return linearizeImpl(offsets, basis, zero); 105203fad47SNicolas Vasilache } 106203fad47SNicolas Vasilache 107203fad47SNicolas Vasilache SmallVector<int64_t> mlir::delinearize(int64_t linearIndex, 108203fad47SNicolas Vasilache ArrayRef<int64_t> strides) { 109203fad47SNicolas Vasilache assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) && 110203fad47SNicolas Vasilache "strides must be nonnegative"); 111203fad47SNicolas Vasilache return delinearizeImpl(linearIndex, strides, 112203fad47SNicolas Vasilache [](int64_t e1, int64_t e2) { return e1 / e2; }); 113203fad47SNicolas Vasilache } 114203fad47SNicolas Vasilache 1150a81ace0SKazu Hirata std::optional<SmallVector<int64_t>> 1167a69a9d7SNicolas Vasilache mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) { 1177a69a9d7SNicolas Vasilache if (shape.size() < subShape.size()) 1181a36588eSKazu Hirata return std::nullopt; 1197a69a9d7SNicolas Vasilache assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) && 1207a69a9d7SNicolas Vasilache "shape must be nonnegative"); 1217a69a9d7SNicolas Vasilache assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) && 1227a69a9d7SNicolas Vasilache "subShape must be nonnegative"); 1237a69a9d7SNicolas Vasilache 1247a69a9d7SNicolas Vasilache // Starting from the end, compute the integer divisors. 1257a69a9d7SNicolas Vasilache std::vector<int64_t> result; 1267a69a9d7SNicolas Vasilache result.reserve(shape.size()); 1277a69a9d7SNicolas Vasilache for (auto [size, subSize] : 1287a69a9d7SNicolas Vasilache llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) { 1297a69a9d7SNicolas Vasilache // If integral division does not occur, return and let the caller decide. 1307a69a9d7SNicolas Vasilache if (size % subSize != 0) 1311a36588eSKazu Hirata return std::nullopt; 1327a69a9d7SNicolas Vasilache result.push_back(size / subSize); 1337a69a9d7SNicolas Vasilache } 1347a69a9d7SNicolas Vasilache // At this point we computed the ratio (in reverse) for the common size. 1357a69a9d7SNicolas Vasilache // Fill with the remaining entries from the shape (still in reverse). 1367a69a9d7SNicolas Vasilache int commonSize = subShape.size(); 1377a69a9d7SNicolas Vasilache std::copy(shape.rbegin() + commonSize, shape.rend(), 1387a69a9d7SNicolas Vasilache std::back_inserter(result)); 1397a69a9d7SNicolas Vasilache // Reverse again to get it back in the proper order and return. 1407a69a9d7SNicolas Vasilache return SmallVector<int64_t>{result.rbegin(), result.rend()}; 1417a69a9d7SNicolas Vasilache } 1427a69a9d7SNicolas Vasilache 143203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 144203fad47SNicolas Vasilache // Utils that operate on AffineExpr. 145203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 146203fad47SNicolas Vasilache 147203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) { 148203fad47SNicolas Vasilache if (sizes.empty()) 149203fad47SNicolas Vasilache return {}; 150203fad47SNicolas Vasilache AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext()); 151203fad47SNicolas Vasilache return ::computeSuffixProductImpl(sizes, unit); 15299ef9eebSMatthias Springer } 15399ef9eebSMatthias Springer 154203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1, 155203fad47SNicolas Vasilache ArrayRef<AffineExpr> v2) { 156203fad47SNicolas Vasilache return computeElementwiseMulImpl(v1, v2); 15799ef9eebSMatthias Springer } 15899ef9eebSMatthias Springer 1599e54d5e7SNicolas Vasilache AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { 1609e54d5e7SNicolas Vasilache if (basis.empty()) 1619e54d5e7SNicolas Vasilache return getAffineConstantExpr(0, ctx); 1629e54d5e7SNicolas Vasilache return std::accumulate(basis.begin(), basis.end(), 163a3cd2eebSNicolas Vasilache getAffineConstantExpr(0, ctx), 1649e54d5e7SNicolas Vasilache std::plus<AffineExpr>()); 1659e54d5e7SNicolas Vasilache } 1669e54d5e7SNicolas Vasilache 1679e54d5e7SNicolas Vasilache AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { 1687a69a9d7SNicolas Vasilache if (basis.empty()) 169a3cd2eebSNicolas Vasilache return getAffineConstantExpr(1, ctx); 170203fad47SNicolas Vasilache return std::accumulate(basis.begin(), basis.end(), 171203fad47SNicolas Vasilache getAffineConstantExpr(1, ctx), 172203fad47SNicolas Vasilache std::multiplies<AffineExpr>()); 1737a69a9d7SNicolas Vasilache } 1747a69a9d7SNicolas Vasilache 175203fad47SNicolas Vasilache AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets, 176203fad47SNicolas Vasilache ArrayRef<AffineExpr> basis) { 177203fad47SNicolas Vasilache AffineExpr zero = getAffineConstantExpr(0, ctx); 178203fad47SNicolas Vasilache return linearizeImpl(offsets, basis, zero); 179203fad47SNicolas Vasilache } 180203fad47SNicolas Vasilache 181203fad47SNicolas Vasilache AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets, 182203fad47SNicolas Vasilache ArrayRef<int64_t> basis) { 183831041beSChristopher Bate 184831041beSChristopher Bate return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx)); 185203fad47SNicolas Vasilache } 186203fad47SNicolas Vasilache 187203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex, 188203fad47SNicolas Vasilache ArrayRef<AffineExpr> strides) { 189203fad47SNicolas Vasilache return delinearizeImpl( 190203fad47SNicolas Vasilache linearIndex, strides, 191203fad47SNicolas Vasilache [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); }); 192203fad47SNicolas Vasilache } 193203fad47SNicolas Vasilache 194203fad47SNicolas Vasilache SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex, 195203fad47SNicolas Vasilache ArrayRef<int64_t> strides) { 196203fad47SNicolas Vasilache MLIRContext *ctx = linearIndex.getContext(); 197831041beSChristopher Bate return delinearize(linearIndex, getAffineConstantExprs(strides, ctx)); 198203fad47SNicolas Vasilache } 199203fad47SNicolas Vasilache 200203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 201203fad47SNicolas Vasilache // Permutation utils. 202203fad47SNicolas Vasilache //===----------------------------------------------------------------------===// 203203fad47SNicolas Vasilache 204203fad47SNicolas Vasilache SmallVector<int64_t> 205b1d3afc9SHanhan Wang mlir::invertPermutationVector(ArrayRef<int64_t> permutation) { 206203fad47SNicolas Vasilache assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) && 207203fad47SNicolas Vasilache "permutation must be non-negative"); 208b1d3afc9SHanhan Wang SmallVector<int64_t> inversion(permutation.size()); 209b1d3afc9SHanhan Wang for (const auto &pos : llvm::enumerate(permutation)) { 210b1d3afc9SHanhan Wang inversion[pos.value()] = pos.index(); 211b1d3afc9SHanhan Wang } 212b1d3afc9SHanhan Wang return inversion; 213b1d3afc9SHanhan Wang } 214b1d3afc9SHanhan Wang 2152472c45bSHan-Chung Wang bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) { 2162472c45bSHan-Chung Wang for (auto i : llvm::seq<int64_t>(0, permutation.size())) 2172472c45bSHan-Chung Wang if (permutation[i] != i) 2182472c45bSHan-Chung Wang return false; 2192472c45bSHan-Chung Wang return true; 2202472c45bSHan-Chung Wang } 2212472c45bSHan-Chung Wang 222b1d3afc9SHanhan Wang bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) { 223203fad47SNicolas Vasilache assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) && 224203fad47SNicolas Vasilache "permutation must be non-negative"); 225b1d3afc9SHanhan Wang llvm::SmallDenseSet<int64_t, 4> seenVals; 226b1d3afc9SHanhan Wang for (auto val : interchange) { 227b1d3afc9SHanhan Wang if (seenVals.count(val)) 228b1d3afc9SHanhan Wang return false; 229b1d3afc9SHanhan Wang seenVals.insert(val); 230b1d3afc9SHanhan Wang } 231b1d3afc9SHanhan Wang return seenVals.size() == interchange.size(); 232b1d3afc9SHanhan Wang } 233b1d3afc9SHanhan Wang 2340bfbecf5SQuentin Colombet SmallVector<int64_t> 2350bfbecf5SQuentin Colombet mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions, 2360bfbecf5SQuentin Colombet ArrayRef<int64_t> desiredPositions) { 2370bfbecf5SQuentin Colombet SmallVector<int64_t> res(permSize, -1); 2380bfbecf5SQuentin Colombet DenseSet<int64_t> seen; 2390bfbecf5SQuentin Colombet for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) { 2400bfbecf5SQuentin Colombet res[desiredPos] = pos; 2410bfbecf5SQuentin Colombet seen.insert(pos); 2420bfbecf5SQuentin Colombet } 2430bfbecf5SQuentin Colombet int64_t nextPos = 0; 2440bfbecf5SQuentin Colombet for (int64_t &entry : res) { 2450bfbecf5SQuentin Colombet if (entry != -1) 2460bfbecf5SQuentin Colombet continue; 2470bfbecf5SQuentin Colombet while (seen.contains(nextPos)) 2480bfbecf5SQuentin Colombet ++nextPos; 2490bfbecf5SQuentin Colombet entry = nextPos; 2500bfbecf5SQuentin Colombet ++nextPos; 2510bfbecf5SQuentin Colombet } 2520bfbecf5SQuentin Colombet return res; 2530bfbecf5SQuentin Colombet } 2540bfbecf5SQuentin Colombet 255*9cc11b98Sdonald chen SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm, 256*9cc11b98Sdonald chen ArrayRef<int64_t> dropPositions) { 257*9cc11b98Sdonald chen assert(inputPerm.size() >= dropPositions.size() && 258*9cc11b98Sdonald chen "expect inputPerm size large than position to drop"); 259*9cc11b98Sdonald chen SmallVector<int64_t> res; 260*9cc11b98Sdonald chen unsigned permSize = inputPerm.size(); 261*9cc11b98Sdonald chen for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) { 262*9cc11b98Sdonald chen int64_t targetIndex = inputPerm[inputIndex]; 263*9cc11b98Sdonald chen bool shouldDrop = false; 264*9cc11b98Sdonald chen unsigned dropSize = dropPositions.size(); 265*9cc11b98Sdonald chen for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) { 266*9cc11b98Sdonald chen if (dropPositions[dropIndex] == inputPerm[inputIndex]) { 267*9cc11b98Sdonald chen shouldDrop = true; 268*9cc11b98Sdonald chen break; 269*9cc11b98Sdonald chen } 270*9cc11b98Sdonald chen if (dropPositions[dropIndex] < inputPerm[inputIndex]) { 271*9cc11b98Sdonald chen targetIndex--; 272*9cc11b98Sdonald chen } 273*9cc11b98Sdonald chen } 274*9cc11b98Sdonald chen if (!shouldDrop) { 275*9cc11b98Sdonald chen res.push_back(targetIndex); 276*9cc11b98Sdonald chen } 277*9cc11b98Sdonald chen } 278*9cc11b98Sdonald chen return res; 279*9cc11b98Sdonald chen } 280*9cc11b98Sdonald chen 281203fad47SNicolas Vasilache SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr, 28299ef9eebSMatthias Springer unsigned dropFront, 28399ef9eebSMatthias Springer unsigned dropBack) { 28499ef9eebSMatthias Springer assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 28599ef9eebSMatthias Springer auto range = arrayAttr.getAsRange<IntegerAttr>(); 2867a69a9d7SNicolas Vasilache SmallVector<int64_t> res; 28799ef9eebSMatthias Springer res.reserve(arrayAttr.size() - dropFront - dropBack); 28899ef9eebSMatthias Springer for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 28999ef9eebSMatthias Springer it != eit; ++it) 29099ef9eebSMatthias Springer res.push_back((*it).getValue().getSExtValue()); 29199ef9eebSMatthias Springer return res; 29299ef9eebSMatthias Springer } 293793ee2bfSIvan Butygin 294793ee2bfSIvan Butygin // TODO: do we have any common utily for this? 295793ee2bfSIvan Butygin static MLIRContext *getContext(OpFoldResult val) { 296793ee2bfSIvan Butygin assert(val && "Invalid value"); 297793ee2bfSIvan Butygin if (auto attr = dyn_cast<Attribute>(val)) { 298793ee2bfSIvan Butygin return attr.getContext(); 299793ee2bfSIvan Butygin } 3008383bf23SMehdi Amini return cast<Value>(val).getContext(); 301793ee2bfSIvan Butygin } 302793ee2bfSIvan Butygin 303793ee2bfSIvan Butygin std::pair<AffineExpr, SmallVector<OpFoldResult>> 304793ee2bfSIvan Butygin mlir::computeLinearIndex(OpFoldResult sourceOffset, 305793ee2bfSIvan Butygin ArrayRef<OpFoldResult> strides, 306793ee2bfSIvan Butygin ArrayRef<OpFoldResult> indices) { 307793ee2bfSIvan Butygin assert(strides.size() == indices.size()); 308793ee2bfSIvan Butygin auto sourceRank = static_cast<unsigned>(strides.size()); 309793ee2bfSIvan Butygin 310793ee2bfSIvan Butygin // Hold the affine symbols and values for the computation of the offset. 311793ee2bfSIvan Butygin SmallVector<OpFoldResult> values(2 * sourceRank + 1); 312793ee2bfSIvan Butygin SmallVector<AffineExpr> symbols(2 * sourceRank + 1); 313793ee2bfSIvan Butygin 314793ee2bfSIvan Butygin bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols}); 315793ee2bfSIvan Butygin AffineExpr expr = symbols.front(); 316793ee2bfSIvan Butygin values[0] = sourceOffset; 317793ee2bfSIvan Butygin 318793ee2bfSIvan Butygin for (unsigned i = 0; i < sourceRank; ++i) { 319793ee2bfSIvan Butygin // Compute the stride. 320793ee2bfSIvan Butygin OpFoldResult origStride = strides[i]; 321793ee2bfSIvan Butygin 322793ee2bfSIvan Butygin // Build up the computation of the offset. 323793ee2bfSIvan Butygin unsigned baseIdxForDim = 1 + 2 * i; 324793ee2bfSIvan Butygin unsigned subOffsetForDim = baseIdxForDim; 325793ee2bfSIvan Butygin unsigned origStrideForDim = baseIdxForDim + 1; 326793ee2bfSIvan Butygin expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim]; 327793ee2bfSIvan Butygin values[subOffsetForDim] = indices[i]; 328793ee2bfSIvan Butygin values[origStrideForDim] = origStride; 329793ee2bfSIvan Butygin } 330793ee2bfSIvan Butygin 331793ee2bfSIvan Butygin return {expr, values}; 332793ee2bfSIvan Butygin } 333831041beSChristopher Bate 334847048f4SDiego Caballero std::pair<AffineExpr, SmallVector<OpFoldResult>> 335847048f4SDiego Caballero mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides, 336847048f4SDiego Caballero ArrayRef<Value> indices) { 337847048f4SDiego Caballero return computeLinearIndex( 338847048f4SDiego Caballero sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides), 339847048f4SDiego Caballero getAsOpFoldResult(ValueRange(indices))); 340847048f4SDiego Caballero } 341847048f4SDiego Caballero 342831041beSChristopher Bate //===----------------------------------------------------------------------===// 343831041beSChristopher Bate // TileOffsetRange 344831041beSChristopher Bate //===----------------------------------------------------------------------===// 345831041beSChristopher Bate 346831041beSChristopher Bate /// Apply left-padding by 1 to the tile shape if required. 347831041beSChristopher Bate static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape, 348831041beSChristopher Bate unsigned paddedSize) { 349831041beSChristopher Bate assert(tileShape.size() <= paddedSize && 350831041beSChristopher Bate "expected tileShape to <= paddedSize"); 351831041beSChristopher Bate if (tileShape.size() == paddedSize) 352831041beSChristopher Bate return to_vector(tileShape); 353831041beSChristopher Bate SmallVector<int64_t> result(paddedSize - tileShape.size(), 1); 354831041beSChristopher Bate llvm::append_range(result, tileShape); 355831041beSChristopher Bate return result; 356831041beSChristopher Bate } 357831041beSChristopher Bate 358831041beSChristopher Bate mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl( 359831041beSChristopher Bate ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape, 360831041beSChristopher Bate ArrayRef<int64_t> loopOrder) 361831041beSChristopher Bate : tileShape(padTileShapeToSize(tileShape, shape.size())), 362831041beSChristopher Bate inverseLoopOrder(invertPermutationVector(loopOrder)), 363831041beSChristopher Bate sliceStrides(shape.size()) { 364831041beSChristopher Bate // Divide the shape by the tile shape. 365831041beSChristopher Bate std::optional<SmallVector<int64_t>> shapeRatio = 366831041beSChristopher Bate mlir::computeShapeRatio(shape, tileShape); 367831041beSChristopher Bate assert(shapeRatio && shapeRatio->size() == shape.size() && 368831041beSChristopher Bate "target shape does not evenly divide the original shape"); 369831041beSChristopher Bate assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() && 370831041beSChristopher Bate "expected loop order to be a permutation of rank equal to outer " 371831041beSChristopher Bate "shape"); 372831041beSChristopher Bate 373831041beSChristopher Bate maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio); 374831041beSChristopher Bate mlir::applyPermutationToVector(*shapeRatio, loopOrder); 375831041beSChristopher Bate sliceStrides = mlir::computeStrides(*shapeRatio); 376831041beSChristopher Bate } 377831041beSChristopher Bate 378831041beSChristopher Bate SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets( 379831041beSChristopher Bate int64_t linearIndex) const { 380831041beSChristopher Bate SmallVector<int64_t> tileCoords = applyPermutation( 381831041beSChristopher Bate delinearize(linearIndex, sliceStrides), inverseLoopOrder); 382831041beSChristopher Bate return computeElementwiseMul(tileCoords, tileShape); 383831041beSChristopher Bate } 384831041beSChristopher Bate 385831041beSChristopher Bate SmallVector<AffineExpr> 386831041beSChristopher Bate mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets( 387831041beSChristopher Bate AffineExpr linearIndex) const { 388831041beSChristopher Bate MLIRContext *ctx = linearIndex.getContext(); 389831041beSChristopher Bate SmallVector<AffineExpr> tileCoords = applyPermutation( 390831041beSChristopher Bate delinearize(linearIndex, sliceStrides), inverseLoopOrder); 391831041beSChristopher Bate return mlir::computeElementwiseMul(tileCoords, 392831041beSChristopher Bate getAffineConstantExprs(tileShape, ctx)); 393831041beSChristopher Bate } 394