//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGTILINGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ArrayRef<OpFoldResult> allShapeSizes,
                                  ArrayRef<OpFoldResult> allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  SmallVector<OpFoldResult> shapeSizes =
      makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
  SmallVector<OpFoldResult> tileSizes(allTileSizes);

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
        static_cast<int64_t>(0)) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}
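
// For illustration, assume a 3-loop op whose loop sizes map to {128, 64, 32}
// and tile sizes {8, 0, 4} (hypothetical values). The zero tile size prunes
// loop 1, so makeTiledLoopRanges above returns the ranges
// {offset = 0, size = 128, stride = 8} and {offset = 0, size = 32, stride = 4}
// together with loopIndexToRangeIndex = {0 -> 0, 2 -> 1}.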

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, getAsOpFoldResult(allIvs));
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
    assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            cast<Value>(value), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<StaticContinuousTileSizeSpecification>
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
                                               unsigned dimension,
                                               unsigned targetSize) {

  assert(!op.hasDynamicShape() &&
         "cannot compute static continuous tile sizes for an op with dynamic "
         "shape");
  assert(targetSize > 0 && "target size must be positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");

  StaticContinuousTileSizeSpecification spec;
  int64_t loopRange = op.getStaticLoopRanges()[dimension];
  int64_t tripCount = loopRange / targetSize;

  unsigned tileSize = targetSize;

  spec.tileSizes.push_back(tileSize);
  spec.tripCounts.push_back(tripCount);

  int64_t remainderChunk = loopRange % targetSize;

  while (tileSize > 1 && remainderChunk != 0) {

    uint64_t maxPower = llvm::bit_floor(tileSize);
    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;

    tripCount = remainderChunk / tileSize;

    if (tripCount > 0) {
      spec.tileSizes.push_back(tileSize);
      spec.tripCounts.push_back(tripCount);
    }

    remainderChunk = remainderChunk % tileSize;
  }

  auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
                            SmallVector<int64_t> tripCounts,
                            int64_t range) -> bool {
    int64_t computedRange = 0;
    for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
      computedRange += tileSize * tripCount;
    return range == computedRange;
  };

  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
    return failure();

  return spec;
}
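
// For illustration, for a loop range of 15 and a target size of 8, the
// decomposition above yields tile sizes {8, 4, 2, 1} with trip counts
// {1, 1, 1, 1}; since 8 + 4 + 2 + 1 = 15, the trip-count check passes.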

FailureOr<ContinuousTileSizeSpecification>
mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
                                         unsigned dimension,
                                         OpFoldResult targetSize,
                                         bool emitAssertions) {

  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
  unsigned numLoops = loopRanges.size();

  // Bail out on dimension overflow.
  if (dimension >= numLoops)
    return failure();

  // The code below works only on values.
  Location loc = op->getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  Value loopRange =
      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
  ContinuousTileSizeSpecification spec;

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };

  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});

  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});

  // The assertion emitted above (when `emitAssertions` is set) already checks
  // that `targetSize` is a positive integer.
  uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);

  assert(tileSizeInt > 0 && "target size must be positive");

  spec.tileSizes.push_back(targetSizeValue);
  spec.tripCounts.push_back(tripCountValue);

  while (tileSizeInt > 1) {
    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
    auto constStepOp =
        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    tripCountSize = affine::makeComposedFoldedAffineApply(
        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    // Optimization if tripCount can be determined to be zero.
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
      auto intAttr = cast<IntegerAttr>(attr);
      bool isTripCountZero = intAttr.getValue().isZero();

      if (!isTripCountZero) {
        spec.tileSizes.push_back(constStepOp);
        spec.tripCounts.push_back(tripCountValue);
      }
    } else {
      spec.tileSizes.push_back(constStepOp);
      spec.tripCounts.push_back(tripCountValue);
    }

    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
  }

  return spec;
}
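
// For illustration, if the iteration domain size folds to the constant 20 and
// the target size is 8, the affine expressions built above fold to tile sizes
// {8, 4} with trip counts {2, 1} (8 * 2 + 4 * 1 = 20); the candidate tile
// sizes 2 and 1 are dropped because their trip counts fold to zero.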

FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                          int64_t targetSize, int64_t divisor) {
  assert(!op.hasDynamicShape() &&
         "cannot compute static multi-tile sizes for an op with dynamic shape");
  assert(targetSize > 0 && "target size must be positive");
  assert(divisor > 0 && "divisor must be positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");

  StaticMultiSizeSpecification spec;
  int64_t tripCount = op.getStaticLoopRanges()[dimension];
  int64_t a = tripCount / divisor;
  int64_t t = (targetSize + divisor - 1) / divisor;
  int64_t totalTripCount = (a + t - 1) / t;
  spec.lowTileSize = (a / totalTripCount) * divisor;
  spec.highTileSize = spec.lowTileSize + divisor;
  spec.highTripCount = a % totalTripCount;
  spec.lowTripCount = totalTripCount - spec.highTripCount;
  if (spec.lowTileSize * spec.lowTripCount +
          spec.highTileSize * spec.highTripCount !=
      tripCount) {
    return failure();
  }
  return spec;
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
  Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  SmallVector<OpFoldResult> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<OpFoldResult> loopRanges =
      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
                                               allShapes);
  Value tripCount =
      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a divisor of 8, it
  // is impossible to find two tile sizes, both divisible by 8, that fully
  // cover the original iteration space.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}
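
// For illustration, for a trip count of 15, a target size of 8 and a divisor
// of 1, the formulas above give a = 15, t = 8 and d = 2, i.e. a low tile size
// of 7 (trip count 1) and a high tile size of 8 (trip count 1), covering
// 7 * 1 + 8 * 1 = 15 iterations exactly.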

/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
/// less than `iterationSize`.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
                                           OpFoldResult numThreads,
                                           OpFoldResult iterationSize) {
  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
    return false;
  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}
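
// For illustration, with a constant tile size of 4, 8 threads and an
// iteration size of 32, the maximum tile offset is 4 * (8 - 1) = 28 < 32, so
// every thread starts within bounds and the check can be omitted.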

/// Build an `affine_max` of all the `vals`.
static OpFoldResult buildMax(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMax(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Build an `affine_min` of all the `vals`.
static OpFoldResult buildMin(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMin(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
/// number of threads.
static void calculateTileOffsetsAndSizes(
    RewriterBase &b, Location loc, scf::ForallOp forallOp,
    ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
    bool omitTileOffsetBoundsCheck,
    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    SmallVector<OpFoldResult> &tiledOffsets,
    SmallVector<OpFoldResult> &tiledSizes) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(forallOp.getBody(0));

  SmallVector<Value> threadIds = forallOp.getInductionVars();
  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
  int64_t nLoops = loopRanges.size();
  tiledOffsets.reserve(nLoops);
  tiledSizes.reserve(nLoops);
  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
    bool overflow = loopIdx >= numThreads.size();
    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
    // Degenerate case: take the whole domain.
    if (overflow || isZero) {
      tiledOffsets.push_back(loopRanges[loopIdx].offset);
      tiledSizes.push_back(loopRanges[loopIdx].size);
      continue;
    }

    // Tiled case: compute the offset and size.
    AffineExpr i, j, m, n, o;
    bindDims(b.getContext(), i, j);
    bindSymbols(b.getContext(), m, n, o);
    OpFoldResult size = loopRanges[loopIdx].size;
    OpFoldResult offset = loopRanges[loopIdx].offset;
    OpFoldResult threadId = threadIds[threadIdIdx];
    // Symbolic fixed max size per thread.
    // TODO: floor + 0/1 depending on case for better load-balancing.
    OpFoldResult tileSizePerThread =
        nominalTileSizes.has_value()
            ? (*nominalTileSizes)[loopIdx]
            : makeComposedFoldedAffineApply(
                  b, loc, m.ceilDiv(n),
                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});

    // Dynamic offset shifted by threadId * maxSizePerThread.
    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
        b, loc, i + j * m, {offset, threadId, tileSizePerThread});
    // Dynamic upper-bound depending on the threadId.
    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
        b, loc, i + j * m - n,
        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
    if (!isConstantIntValue(residualTileSize, 0)) {
      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
          b, loc, -i + m, {offsetPerThread, size});
      tileSizePerThread =
          buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
    }

    tiledOffsets.push_back(offsetPerThread);
    // TODO: if tileSizePerThread <= 0 early exit.
    if (!omitTileOffsetBoundsCheck &&
        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
                                        nonZeroNumThreads[threadIdIdx], size))
      tileSizePerThread =
          buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});

    tiledSizes.push_back(tileSizePerThread);
    ++threadIdIdx;
  }
}
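
// For illustration, distributing a loop of size 100 over 7 threads gives a
// nominal tile size of ceildiv(100, 7) = 15; thread t starts at offset 15 * t,
// threads 0..5 receive full tiles of 15, and thread 6 receives
// min(100 - 90, 15) = 10 iterations, covering 6 * 15 + 10 = 100 exactly.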

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);

  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
      })) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  SmallVector<OpFoldResult> allShapeSizes =
      op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Otherwise, build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // Handle distribution. Create a vector of the same size as the loops that
  // are to be tiled.
  SmallVector<linalg::ProcInfo> procInfo;
  if (options.distribution) {
    procInfo.resize(
        iteratorTypes.size(),
        linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
    // Collect the loop ranges of the tiled loops that are parallel.
    SmallVector<Range> parallelLoopRanges;
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    auto returnedProcInfo =
        options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
    unsigned procIdIdx = 0;
    // Update the distribution information for the loops.
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
    }
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty()) {
      for (AffineExpr result : invPermutationMap.getResults())
        interchangedIvs.push_back(
            ivs[cast<AffineDimExpr>(result).getPosition()]);
    } else {
      interchangedIvs.assign(ivs.begin(), ivs.end());
    }

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op->getNumOperands()) &&
           "expected the number of operands to match the number of inputs "
           "and outputs");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    SmallVector<OpFoldResult> sizeBounds =
        makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
                                                 allShapeSizes);
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
        sizeBounds,
        /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = clone(b, op, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, procInfo);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (isa<BlockArgument>(iv)) {
      loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
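
// For illustration, tiling a `linalg.matmul` on tensors with tile sizes
// {4, 8, 0} (hypothetical values) conceptually produces a loop nest of the
// following shape, with the reduction dimension left untiled:
//
//   scf.for %i = %c0 to %M step %c4 iter_args(...) {
//     scf.for %j = %c0 to %N step %c8 iter_args(...) {
//       %lhs = tensor.extract_slice ...   // 4 x K tile of the LHS
//       %rhs = tensor.extract_slice ...   // K x 8 tile of the RHS
//       %out = tensor.extract_slice ...   // 4 x 8 tile of the accumulator
//       %res = linalg.matmul ins(%lhs, %rhs ...) outs(%out ...)
//       ... tensor.insert_slice %res ...
//     }
//   }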

FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
    RewriterBase &b, PartialReductionOpInterface op,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
    std::optional<ArrayAttr> mapping) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(b);

  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  // TODO: proper core mechanism to tie interfaces together.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());

  // Ops implementing PartialReductionOpInterface are not necessarily expected
  // to implement TilingInterface. This cast is unsafe at the moment.
  // TODO: proper core mechanism to tie interfaces together.
  // TODO: this function requires a pair of interfaces.
  auto destinationStyleOp =
      dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!destinationStyleOp)
    return b.notifyMatchFailure(op, "not a destination style op");

  // Currently this only works for Linalg ops.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
  if (!linalgOp)
    return b.notifyMatchFailure(op, "not a linalg op");

  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");

  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);
  if (redDims.size() != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
    return b.notifyMatchFailure(op, "if tile sizes are present they must have "
                                    "as many elements as the number of "
                                    "threads");
  int reductionDim = static_cast<int>(redDims.front());

  if (redDims.front() >= numThreads.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be mapped to threads");

  // 1. Create the initial tensor value.
  FailureOr<SmallVector<Value>> maybeInitTensors =
      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
                                                  reductionDim);
  if (failed(maybeInitTensors))
    return b.notifyMatchFailure(
        op, "failed to create initial tensors for partial reduction");
  SmallVector<Value> &initTensors = maybeInitTensors.value();

  // Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
    return b.notifyMatchFailure(op, "failed to get destination tensors");

  Operation *tiledOp = nullptr;

  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
  SmallVector<Value> materializedNonZeroNumThreads =
      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);

  // 2. Create the ForallOp with an empty region.
  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
      mapping);

  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
  // be nested under `forallOp`.
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
                               /*omitTileOffsetBoundsCheck=*/false,
                               /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                               tiledSizes);

  // 4. Clone the tileable op and update its destination operands to use the
  // output bbArgs of the ForallOp.
  SmallVector<Value> tilingResults;
  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
  {
    // 4.a. RAII guard, inserting within forallOp, before terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<Value> tiledDpsInitOperands;
    for (Value initOperand : destinationStyleOp.getDpsInits()) {
      auto *it = llvm::find(dest, initOperand);
      assert(it != dest.end() && "dest operand not found in dest");
      unsigned destNum = std::distance(dest.begin(), it);
      SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
      SmallVector<OpFoldResult> outOffsets(numThreads.size(),
                                           b.getIndexAttr(0));
      SmallVector<OpFoldResult> sizes = tiledSizes;
      sizes[reductionDim] = b.getIndexAttr(1);
      outOffsets[reductionDim] = forallOp.getInductionVars()[0];
      // TODO: use SubsetExtractOpInterface once it is available.
      tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
          loc, cast<RankedTensorType>(initOperand.getType()),
          destBbArgs[destNum], outOffsets, sizes, strides));
    }

    // 4.b. Clone the op and update init operands.
    // We cannot use an IRMapping here because it can replace
    // different OpOperands with the same value.
    Operation *clonedOp = b.clone(*op.getOperation());
    b.modifyOpInPlace(clonedOp, [&]() {
      for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
               cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
               tiledDpsInitOperands)) {
        initOperandPtr.set(tiledInitValue);
      }
    });

    // 5. Tile the cloned op and delete the clone.
    if (tileSizes.empty()) {
      FailureOr<TilingResult> tilingResult =
          cast<TilingInterface>(clonedOp).getTiledImplementation(
              b, tiledOffsets, tiledSizes);
      if (failed(tilingResult))
        return clonedOp->emitError("failed to tile op");
      if (tilingResult->tiledOps.size() != 1) {
        return clonedOp->emitError("expected a single produced tiled op, got ")
               << tilingResult->tiledOps.size();
      }
      tiledOp = tilingResult->tiledOps.front();
      tilingResults = tilingResult->tiledValues;
    } else {
      LinalgTilingOptions options;
      FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
          b, cast<LinalgOp>(clonedOp), tileSizes, options);
      if (failed(maybeTiled))
        return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");

      SmallVector<Value> ids = forallOp.getInductionVars();
      mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                            materializedNonZeroNumThreads);
      if (maybeTiled->loops.size() != 1) {
        return clonedOp->emitError("expected a single produced loop");
      }
      tiledOp = maybeTiled->op;
      tilingResults = maybeTiled->loops.front()->getResults();
    }

    b.eraseOp(clonedOp);
  }

  // 6. Insert the partial reductions back into a new tensor.
  for (auto [index, result, bbArg] : llvm::zip(
           llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
    // 6.a. Partial subset information is inserted just before the terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(tilingInterfaceOp.getResultTilePosition(
            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
      return op->emitOpError("output offsets couldn't be calculated");
    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
    int64_t offIdx = 0;
    int64_t sizeIdx = 0;
    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
      if (i == reductionDim) {
        resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
        resultSizesRank.push_back(b.getIndexAttr(1));
        continue;
      }
      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
      resultSizesRank.push_back(resultSizes[sizeIdx++]);
    }
    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
                                      b.getIndexAttr(1));

    // 6.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    b.create<tensor::ParallelInsertSliceOp>(
        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
  }

  // 7. Merge the partial reductions.
  b.setInsertionPointAfter(forallOp);
  FailureOr<MergeResult> mergeResult =
      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
  if (failed(mergeResult)) {
    return failure();
  }
  b.replaceOp(op, mergeResult->replacements);

  // 8. Return.
  ForallReductionTilingResult results;
  results.initialValues = initTensors;
  results.loops = forallOp;
  results.parallelTiledOps.push_back(tiledOp);
  results.mergeOps.append(mergeResult->mergeOps);
  return results;
}
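
// For illustration, tiling a 1-D sum reduction across 4 threads conceptually
// produces IR of the following shape:
//
//   %partials = scf.forall (%tid) in (4) shared_outs(%acc = %init) {
//     %chunk = tensor.extract_slice ...          // per-thread input slice
//     %part  = <partial reduction of %chunk>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %part into %acc ...
//     }
//   }
//   %result = <merge of the partial results>     // from mergeReductions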

template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<OpFoldResult> tileSizeVector =
      getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
  if (tileSizeVector.size() < nLoops) {
    tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}
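
// A typical call site, assuming a `RewriterBase &rewriter` and a
// `linalg::LinalgOp op` are in scope (illustrative names), might look like:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0})
//       .setLoopType(LinalgTilingLoopType::Loops);
//   FailureOr<TiledLinalgOp> tiled = tileLinalgOp(rewriter, op, options);
//   if (succeeded(tiled) && !tiled->tensorResults.empty())
//     rewriter.replaceOp(op, tiled->tensorResults);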

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}
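
// A typical use of these patterns, assuming a `funcOp` and its `MLIRContext
// *ctx` are in scope (illustrative names), is to apply them greedily after
// tiling:
//
//   RewritePatternSet patterns =
//       linalg::getLinalgTilingCanonicalizationPatterns(ctx);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));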