//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;

template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty; let zip_equal assert if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}
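// Worked example for the helpers above: for sizes = [2, 3, 4] the suffix
// product is [3 * 4, 4, 1] = [12, 4, 1], i.e. the row-major strides of a
// 2x3x4 array. Linearizing offsets [1, 2, 3] in that basis gives
// 1 * 12 + 2 * 4 + 3 * 1 = 23, and delinearizing 23 with the same strides
// recovers [1, 2, 3].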
//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(0),
                         std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 1;
  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be strictly positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be strictly positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be strictly positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If the sizes do not divide evenly, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we have computed the ratio (in reverse) for the common
  // suffix. Fill in the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get the result back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
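// Worked example for computeShapeRatio: with shape = [8, 16, 32] and
// subShape = [4, 8], the trailing sizes divide evenly (16 / 4 = 4,
// 32 / 8 = 4) and the leading size is carried over unchanged, so the result
// is [8, 4, 4]. With subShape = [5, 8], 16 % 5 != 0 and std::nullopt is
// returned.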
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}
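// Worked example for the AffineExpr overloads: delinearizing an expression
// s0 with strides [4, 1] yields [s0 floordiv 4, s0 mod 4]; the floordiv by 1
// on the innermost dimension folds away during expression simplification.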
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expected inputPerm size to be at least the number of positions to "
         "drop");
  SmallVector<int64_t> res;
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}
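// Worked examples for the permutation helpers above:
//   computePermutationVector(5, /*positions=*/{1, 3},
//   /*desiredPositions=*/{0, 2}) pins value 1 to slot 0 and value 3 to
//   slot 2, then fills the remaining slots with the unused values in
//   ascending order, yielding {1, 0, 3, 2, 4}.
//   dropDims(/*inputPerm=*/{2, 0, 3, 1}, /*dropPositions=*/{1}) drops the
//   entry 1 and renumbers the survivors, yielding {1, 0, 2}.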
// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Get the stride for this dimension.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(getContext(sourceOffset), strides),
      getAsOpFoldResult(ValueRange(indices)));
}
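// Worked example for computeLinearIndex: with sourceOffset `off`, strides
// [str0, str1], and indices [idx0, idx1], the result is the expression
// s0 + s1 * s2 + s3 * s4 paired with the value list
// [off, idx0, str0, idx1, str1], i.e. off + idx0 * str0 + idx1 * str1 once
// the symbols are bound to the values.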
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}
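// Worked example for TileOffsetRangeImpl: tiling shape [4, 8] with tile
// shape [2, 2] and the identity loop order gives shapeRatio = [2, 4], hence
// 8 tiles and sliceStrides = [4, 1]. getStaticTileOffsets(3) delinearizes 3
// into tile coordinates [0, 3] and scales them by the tile shape, yielding
// tile offsets [0, 6].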