//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "tile-using-interface"

using namespace mlir;

/// Set a fixed list of tile sizes on the tiling options. The sizes are
/// captured by value inside `tileSizeComputationFunction`, so the options
/// object remains valid after `ts` goes out of scope. May only be called
/// once: setting sizes when a computation function is already installed is
/// a programming error (asserted below).
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  auto tileSizes = llvm::to_vector(ts);
  // Ignore the builder/op arguments: the sizes are fixed up front.
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    return tileSizes;
  };
  return *this;
}

/// Helper method to adjust the interchange vector to match the iteration
/// domain.
45 static SmallVector<int64_t> 46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 47 size_t iterationDomainSize) { 48 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 49 if (filledVector.size() < iterationDomainSize) { 50 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 51 filledVector.append(range.begin(), range.end()); 52 } 53 if (filledVector.size() > iterationDomainSize) 54 filledVector.resize(iterationDomainSize); 55 return filledVector; 56 } 57 58 /// Convert a list of ops of type `SrcOpTy` to list of `Operation *`. 59 template <typename SrcOpTy> 60 static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) { 61 return llvm::to_vector( 62 llvm::map_range(ops, [](auto op) -> Operation * { return op; })); 63 } 64 template <typename SrcOpTy> 65 static SmallVector<Operation *> 66 getAsOperations(const SmallVector<SrcOpTy> &ops) { 67 return getAsOperations(ArrayRef<SrcOpTy>(ops)); 68 } 69 70 /// Convert a list of `Operation *` to a list of `DstOpTy. 71 template <typename DstOpTy> 72 static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) { 73 return llvm::to_vector( 74 llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); })); 75 } 76 template <typename DstOpTy> 77 static SmallVector<DstOpTy> 78 castToTypedOperations(const SmallVector<Operation *> &ops) { 79 return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops)); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // tileUsingSCFForOp implementation. 84 //===----------------------------------------------------------------------===// 85 86 // Check if `stride` evenly divides the trip count `size - offset`. 
87 static bool tileDividesIterationDomain(Range loopRange) { 88 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 89 if (!offsetAsInt) 90 return false; 91 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 92 if (!sizeAsInt) 93 return false; 94 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 95 if (!strideAsInt) 96 return false; 97 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 98 } 99 100 /// Returns the bounded tile size given the current `iv`, `loopRange` and 101 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 102 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 103 Range loopRange, Value iv, 104 OpFoldResult tileSize) { 105 std::optional<int64_t> ts = getConstantIntValue(tileSize); 106 if (ts && ts.value() == 1) 107 return tileSize; 108 109 if (tileDividesIterationDomain( 110 Range{loopRange.offset, loopRange.size, tileSize})) 111 return tileSize; 112 113 // The tile size to use (to avoid out of bounds access) is minimum of 114 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 115 // loop. 116 AffineExpr s0, s1, d0; 117 bindDims(b.getContext(), d0); 118 bindSymbols(b.getContext(), s0, s1); 119 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 120 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 121 return affine::makeComposedFoldedAffineMin( 122 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 123 } 124 125 /// Clones the operation and updates the destination if the operation 126 /// implements the `DestinationStyleOpInterface`. 
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange newDestArgs) {
  Operation *clonedOp = rewriter.clone(*op);
  // Nothing to rewire when the caller supplies no destinations.
  if (newDestArgs.empty())
    return clonedOp;
  // Only destination-style ops carry "inits" to replace; any other op is
  // returned as a plain clone.
  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
    destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
  return clonedOp;
}

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
///   the tile processed within the inner most loop.
/// Note that this methods adds `scf.yield` operation for all but the innermost
/// loop. These yield the value returned by the immediately inner loop. The
/// caller is expected to add the scf.yield operation for the innermost loop.
static SmallVector<scf::ForOp> generateTileLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
  if (loopRanges.empty())
    return {};
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  // Every entry of `offsets`/`sizes` is written below, either with the full
  // range (untiled dim) or with the per-tile offset/size.
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  for (auto loopRange : llvm::enumerate(loopRanges)) {
    Value offset =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
    Value size =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
    Value tileSize = getValueOrCreateConstantIndexOp(
        builder, loc, tileSizes[loopRange.index()]);
    // No loops if tile size is zero. Set offset and size to the loop
    // offset and size.
    if (matchPattern(tileSize, m_Zero())) {
      offsets[loopRange.index()] = offset;
      sizes[loopRange.index()] = size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, offset, size, tileSize, destinationTensors,
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          // Clamp the size so the final (partial) tile stays in bounds.
          sizes[loopRange.index()] =
              getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
                                 getAsOpFoldResult(tileSize));
        });
    // The induction variable is this dimension's tile offset.
    offsets[loopRange.index()] = loop.getInductionVar();
    loops.push_back(loop);
    builder.setInsertionPointToEnd(loop.getBody());
    // Thread the destination-passing chain: inner loops use the region
    // iter_args of the loop just created as their destinations.
    destinationTensors = loop.getRegionIterArgs();
  }

  // Add the scf.yield operations for all the outer loops: each outer loop
  // yields the results of its immediately inner loop.
  if (!loops.empty()) {
    for (auto [outerLoop, innerLoop] :
         llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                         MutableArrayRef(loops).drop_front())) {
      builder.setInsertionPointToEnd(outerLoop.getBody());
      builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
    }
  }
  return loops;
}

/// Method to add new init values to a loop nest. Updates `loops` in-place with
/// new loops that use the `newInitValues`.
/// The outer-loops are updated to yield the new result values of the inner
/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
/// the additional values to yield from the innermost loop.
205 static void addInitOperandsToLoopNest( 206 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops, 207 ValueRange newInitValues, 208 llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv, 209 ValueRange newRegionIterArgs)> 210 getNewYieldValsFn) { 211 SmallVector<scf::ForOp> newLoops; 212 if (loops.empty()) 213 return; 214 OpBuilder::InsertionGuard g(rewriter); 215 rewriter.setInsertionPoint(loops.front()); 216 for (auto &loop : loops) { 217 rewriter.setInsertionPoint(loop); 218 219 // Create a new loop with the new init values for this loop. 220 SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs()); 221 newInits.append(newInitValues.begin(), newInitValues.end()); 222 auto newLoop = rewriter.create<scf::ForOp>( 223 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), 224 loop.getStep(), newInits, 225 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 226 227 // Merge the body of the new loop with the body of the old loops. 228 SmallVector<Value> sourceBlockArgs; 229 sourceBlockArgs.push_back(newLoop.getInductionVar()); 230 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 231 sourceBlockArgs.append( 232 newRegionIterArgs.begin(), 233 std::next(newRegionIterArgs.begin(), loop.getNumResults())); 234 rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs); 235 rewriter.replaceOp(loop, 236 newLoop.getResults().take_front(loop.getNumResults())); 237 loop = newLoop; 238 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 239 } 240 241 // Update the loop body of the innermost loop to get new yield values. 
242 scf::ForOp innerMostLoop = loops.back(); 243 auto innerMostYieldOp = 244 cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator()); 245 rewriter.setInsertionPoint(innerMostYieldOp); 246 SmallVector<Value> newYieldVals = 247 getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(), 248 innerMostLoop.getRegionIterArgs()); 249 SmallVector<Value> newYieldOperands = 250 llvm::to_vector(innerMostYieldOp->getOperands()); 251 newYieldOperands.append(newYieldVals); 252 rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands); 253 254 // Make all other loops except the innermost loops yield the values returned 255 // by the inner loop. 256 for (auto [outerLoop, innerLoop] : 257 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 258 auto outerLoopYield = 259 cast<scf::YieldOp>(outerLoop.getBody()->getTerminator()); 260 SmallVector<Value> newYields = 261 llvm::to_vector(outerLoopYield.getOperands()); 262 ValueRange additionalYields = 263 innerLoop.getResults().take_back(newInitValues.size()); 264 newYields.append(additionalYields.begin(), additionalYields.end()); 265 rewriter.setInsertionPoint(outerLoopYield); 266 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 267 } 268 } 269 270 /// Implementation of tiling transformation of `op` that implements the 271 /// `TilingInterface` using `scf.for` to iterate over the tiles. 272 FailureOr<scf::SCFTilingResult> 273 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 274 const scf::SCFTilingOptions &options) { 275 OpBuilder::InsertionGuard guard(rewriter); 276 rewriter.setInsertionPointAfter(op); 277 278 if (!options.tileSizeComputationFunction) { 279 return rewriter.notifyMatchFailure( 280 op, "missing tile size computation function"); 281 } 282 283 // 1. Get the range of the loops that are represented by the operation. 
284 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 285 size_t numLoops = iterationDomain.size(); 286 if (numLoops == 0) { 287 return rewriter.notifyMatchFailure( 288 op, "unable to tile op with no iteration domain"); 289 } 290 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 291 // skips tiling a particular dimension. This convention is significantly 292 // simpler to handle instead of adjusting affine maps to account for missing 293 // dimensions. 294 SmallVector<OpFoldResult> tileSizeVector = 295 options.tileSizeComputationFunction(rewriter, op); 296 if (tileSizeVector.size() < iterationDomain.size()) { 297 auto zero = rewriter.getIndexAttr(0); 298 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 299 } 300 301 // 3. Find the destination tensors to use for the operation. 302 SmallVector<Value> destinationTensors; 303 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 304 destinationTensors))) { 305 return rewriter.notifyMatchFailure(op, 306 "unable to create destination tensors"); 307 } 308 309 SmallVector<OpFoldResult> offsets, sizes; 310 SmallVector<scf::ForOp> forLoops; 311 { 312 // If there is an interchange specified, permute the iteration domain and 313 // the tile sizes. 314 SmallVector<int64_t> interchangeVector; 315 if (!options.interchangeVector.empty()) { 316 interchangeVector = fillInterchangeVector(options.interchangeVector, 317 iterationDomain.size()); 318 } 319 if (!interchangeVector.empty()) { 320 if (!isPermutationVector(interchangeVector)) { 321 return rewriter.notifyMatchFailure( 322 op, "invalid intechange vector, not a permutation of the entire " 323 "iteration space"); 324 } 325 326 applyPermutationToVector(iterationDomain, interchangeVector); 327 applyPermutationToVector(tileSizeVector, interchangeVector); 328 } 329 330 // 4. Materialize an empty loop nest that iterates over the tiles. 
These 331 // loops for now do not return any values even if the original operation has 332 // results. 333 forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain, 334 tileSizeVector, offsets, sizes, 335 destinationTensors); 336 337 if (!interchangeVector.empty()) { 338 auto inversePermutation = invertPermutationVector(interchangeVector); 339 applyPermutationToVector(offsets, inversePermutation); 340 applyPermutationToVector(sizes, inversePermutation); 341 } 342 } 343 344 LLVM_DEBUG({ 345 if (!forLoops.empty()) { 346 llvm::dbgs() << "LoopNest shell :\n"; 347 forLoops.front().dump(); 348 llvm::dbgs() << "\n"; 349 } 350 }); 351 352 // 5. Generate the tiled implementation within the inner most loop. 353 SmallVector<Value> clonedOpDestination = destinationTensors; 354 if (!forLoops.empty()) { 355 rewriter.setInsertionPointToEnd(forLoops.back().getBody()); 356 clonedOpDestination = 357 llvm::map_to_vector(forLoops.back().getRegionIterArgs(), 358 [](BlockArgument b) -> Value { return b; }); 359 } 360 361 // 5a. Clone the operation within the loop body. 362 auto clonedOp = cast<TilingInterface>( 363 cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination)); 364 365 // 5b. Early return cloned op if tiling is not happening. We can not return 366 // the original op because it could lead to 367 // `rewriter.replaceOp(op, op->getResults())` and user would get crash. 368 if (llvm::all_of(tileSizeVector, isZeroIndex)) { 369 return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{}, 370 clonedOp->getResults()}; 371 } 372 373 // 5c. Tile the cloned operation. 374 FailureOr<TilingResult> tiledImplementation = 375 clonedOp.getTiledImplementation(rewriter, offsets, sizes); 376 if (failed(tiledImplementation)) { 377 return rewriter.notifyMatchFailure(op, "failed to tile operation"); 378 } 379 380 // 5d. Delete the cloned operation. 
381 rewriter.eraseOp(clonedOp); 382 383 // If loops are empty, the tiled op is used as the replacement for the untiled 384 // op. 385 if (forLoops.empty()) { 386 return scf::SCFTilingResult{tiledImplementation->tiledOps, 387 getAsOperations(forLoops), 388 tiledImplementation->tiledValues}; 389 } 390 391 if (op->getNumResults() == 0) { 392 // The innermost loop does not have a `scf.yield` yet. There is nothing to 393 // return, so generate an empty `scf.yield` operation. 394 rewriter.setInsertionPointToEnd(forLoops.back().getBody()); 395 rewriter.create<scf::YieldOp>(op->getLoc()); 396 return scf::SCFTilingResult{ 397 tiledImplementation->tiledOps, getAsOperations(forLoops), {}}; 398 } 399 400 // 6. Yield all the results of the tiled operation. 401 int64_t numResults = op->getNumResults(); 402 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 403 resultSizesList(numResults); 404 SmallVector<Value> yieldedValues; 405 for (auto [index, tiledValue] : 406 llvm::enumerate(tiledImplementation->tiledValues)) { 407 SmallVector<OpFoldResult> resultOffsets, resultSizes; 408 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, 409 resultOffsets, resultSizes))) { 410 return rewriter.notifyMatchFailure( 411 op, "failed to get slice of result produced"); 412 } 413 SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), 414 rewriter.getIndexAttr(1)); 415 auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 416 op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets, 417 resultSizes, resultStrides); 418 yieldedValues.push_back(insertSlice); 419 } 420 rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues); 421 422 SmallVector<Value> replacements = llvm::map_to_vector( 423 forLoops.front().getResults(), [](OpResult r) -> Value { return r; }); 424 LLVM_DEBUG({ 425 if (!forLoops.empty()) { 426 llvm::dbgs() << "After tiled implementation :\n"; 427 forLoops.front().dump(); 428 llvm::dbgs() << "\n"; 429 } 430 }); 431 return 
scf::SCFTilingResult{tiledImplementation->tiledOps, 432 getAsOperations(forLoops), replacements}; 433 } 434 435 FailureOr<scf::SCFReductionTilingResult> 436 mlir::scf::tileReductionUsingScf(RewriterBase &b, 437 PartialReductionOpInterface op, 438 ArrayRef<OpFoldResult> tileSizes) { 439 Location loc = op.getLoc(); 440 // Ops implementing PartialReductionOpInterface are expected to implement 441 // TilingInterface. 442 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 443 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 444 auto tileSizesVector = llvm::to_vector(tileSizes); 445 if (tileSizesVector.size() < iterationDomain.size()) { 446 auto zero = b.getIndexAttr(0); 447 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 448 zero); 449 } 450 if (op->getNumResults() != 1) 451 return b.notifyMatchFailure( 452 op, "don't support ops with multiple results for now"); 453 SmallVector<utils::IteratorType> iterators = 454 tilingInterfaceOp.getLoopIteratorTypes(); 455 456 SmallVector<int> reductionDims; 457 for (auto [idx, iteratorType] : 458 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 459 if (iteratorType == utils::IteratorType::reduction) 460 reductionDims.push_back(idx); 461 } 462 463 // 2. create the inital tensor value. 464 FailureOr<Operation *> identityTensor = 465 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 466 reductionDims); 467 if (failed(identityTensor)) 468 return b.notifyMatchFailure(op, 469 "cannot create a tensor of identity value."); 470 // 3. Create the nested loops. 471 SmallVector<OpFoldResult> offsets, sizes; 472 SmallVector<scf::ForOp> loops = 473 generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets, 474 sizes, identityTensor.value()->getResults()); 475 476 // 4. Generate the tiled implementation within the inner most loop. 477 // 4a. Clone the operation within the loop body. 
478 SmallVector<Value> clonedOpDestination = 479 llvm::map_to_vector(identityTensor.value()->getResults(), 480 [](OpResult res) -> Value { return res; }); 481 if (!loops.empty()) { 482 b.setInsertionPointToEnd(loops.back().getBody()); 483 clonedOpDestination = 484 llvm::map_to_vector(loops.back().getRegionIterArgs(), 485 [](BlockArgument b) -> Value { return b; }); 486 } 487 auto clonedOp = cast<PartialReductionOpInterface>( 488 cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination)); 489 490 // 4b. Tile the cloned operation. 491 Operation *parallelOp = clonedOp.tileToPartialReduction( 492 b, loc, clonedOpDestination, offsets, sizes, reductionDims); 493 // 4c. Delete the cloned operation. 494 b.eraseOp(clonedOp); 495 496 SmallVector<OpFoldResult> outSizes; 497 for (size_t i = 0; i < offsets.size(); i++) { 498 outSizes.push_back( 499 tensor::getMixedSize(b, loc, parallelOp->getResult(0), i)); 500 } 501 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 502 SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1)); 503 SmallVector<Value> yieldedVals; 504 auto bbArgs = loops.back().getRegionIterArgs(); 505 for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) { 506 Value insert = b.create<tensor::InsertSliceOp>( 507 loc, result, bbArg, outOffsets, outSizes, outStrides); 508 yieldedVals.push_back(insert); 509 } 510 b.create<scf::YieldOp>(loc, yieldedVals); 511 512 SmallVector<Value> replacements = llvm::map_to_vector( 513 loops.front().getResults(), [](OpResult r) -> Value { return r; }); 514 515 // 5. Apply the merge reduction to combine all the partial values. 
516 b.setInsertionPointAfter(*loops.begin()); 517 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); 518 b.replaceOp(op, mergeOp->getResults()); 519 520 SCFReductionTilingResult results; 521 results.initialOp = *identityTensor; 522 results.loops = std::move(loops); 523 results.parallelTiledOp = parallelOp; 524 results.mergeOp = mergeOp; 525 return results; 526 } 527 528 //===----------------------------------------------------------------------===// 529 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 530 //===----------------------------------------------------------------------===// 531 532 /// Return the untiled producer whose slice is used in a tiled consumer. The 533 /// method traverses the tile loop nest (`loops`) if needed, and returns the 534 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 535 /// indicates that this is a destination operand of the consumer. If there was 536 /// no loop traversal needed, the second value of the returned tuple is empty. 537 static std::tuple<OpResult, std::optional<OpOperand *>> 538 getUntiledProducerFromSliceSource(OpOperand *source, 539 ArrayRef<scf::ForOp> loops) { 540 std::optional<OpOperand *> destinationIterArg; 541 auto loopIt = loops.rbegin(); 542 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 543 scf::ForOp loop = *loopIt; 544 if (iterArg.getOwner()->getParentOp() != loop) 545 break; 546 source = loop.getTiedLoopInit(iterArg); 547 loopIt++; 548 } 549 if (loopIt == loops.rend()) 550 destinationIterArg = source; 551 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 552 } 553 554 /// Implementation of fusing producer of a single slice by computing the 555 /// slice of the producer in-place. 556 std::optional<scf::SCFFuseProducerOfSliceResult> 557 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 558 tensor::ExtractSliceOp candidateSliceOp, 559 MutableArrayRef<scf::ForOp> loops) { 560 // 1. 
Get the producer of the source (potentially walking through 561 // `iter_args` of nested `scf.for`) 562 auto [fusableProducer, destinationInitArg] = 563 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 564 loops); 565 if (!fusableProducer) 566 return std::nullopt; 567 unsigned resultNumber = fusableProducer.getResultNumber(); 568 569 OpBuilder::InsertionGuard g(rewriter); 570 rewriter.setInsertionPoint(candidateSliceOp); 571 572 // 2. Clone the fused producer 573 // 2a. Compute the destination operands to use for the cloned operation. 574 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 575 Operation *fusableProducerOp = fusableProducer.getOwner(); 576 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 577 failed(tensor::getOrCreateDestinations( 578 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 579 origDestinationTensors))) 580 return std::nullopt; 581 582 clonedOpDestinationTensors = origDestinationTensors; 583 if (destinationInitArg && 584 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 585 // 2b. If the producer is also destination style, then to maintain the 586 // destination passing style, update the destination of the producer to be 587 // the source of the slice. 588 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 589 } 590 // 2c. Clone the fused producer. 591 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 592 rewriter, fusableProducerOp, clonedOpDestinationTensors); 593 // 2d. Update the source of the candidateSlice to be the cloned producer. 
594 // Easier to just clone the slice with different source since replacements 595 // and DCE of cloned ops becomes easier 596 SmallVector<Value> candidateSliceOpOperands = 597 llvm::to_vector(candidateSliceOp->getOperands()); 598 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 599 tensor::ExtractSliceOp clonedCandidateSliceOp = 600 mlir::clone(rewriter, candidateSliceOp, 601 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 602 603 // 3. Generate the tiled implementation of the producer of the source 604 FailureOr<TilingResult> tileAndFuseResult = 605 tensor::replaceExtractSliceWithTiledProducer( 606 rewriter, clonedCandidateSliceOp, 607 clonedProducerOp->getResult(resultNumber)); 608 if (failed(tileAndFuseResult)) 609 return std::nullopt; 610 // Note: Do not delete the candidateSliceOp, since its passed in from the 611 // caller. 612 rewriter.replaceAllUsesWith(candidateSliceOp, 613 tileAndFuseResult->tiledValues[0]); 614 rewriter.eraseOp(clonedCandidateSliceOp); 615 rewriter.eraseOp(clonedProducerOp); 616 617 // 3. If the slice is for a destination operand, for example, 618 // 619 // ```mlir 620 // %0 = linalg.init 621 // %1 = linalg.fill .. outs(%0 : ) 622 // %2 = scf.for .. iter_args(%arg0 = %1) { 623 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 624 // %4 = tensor.extract_slice %arg1 [..] 625 // .. = linalg.matmul .. outs(%4 : ) 626 // } 627 // } 628 // ``` 629 // 630 // the IR is currently 631 // 632 // ``` 633 // %0 = linalg.init 634 // %1 = linalg.fill 635 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 636 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 637 // %4 = tensor.extract_slice %arg1[..] 638 // %5 = linalg.fill .. outs(%4 : ) 639 // .. = linalg.matmul .. outs(%5 : ) 640 // } 641 // } 642 // ``` 643 // 644 // The untiled `linalg.fill` is still used as the `init_value` since it 645 // was originally a destination operand of the untiled `linalg.matmul`. 
646 // When fusing an operand that is a destination operand, the iter_arg of 647 // the outer most loop should be changed to use the destination of the 648 // fused operation. With this the IR will be. 649 // 650 // ``` 651 // %0 = linalg.init 652 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 653 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 654 // %3 = tensor.extract_slice %arg1[..] 655 // %4 = linalg.fill .. outs(%3 : ) 656 // .. = linalg.matmul .. outs(%4 : ) 657 // } 658 // } 659 // ``` 660 if (destinationInitArg && 661 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 662 loops.front() 663 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 664 .set(origDestinationTensors[resultNumber]); 665 } 666 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 667 tileAndFuseResult->tiledValues[0], 668 tileAndFuseResult->tiledOps}; 669 } 670 671 /// Reconstruct the fused producer from within the tiled-and-fused code. 672 void mlir::scf::yieldReplacementForFusedProducer( 673 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 674 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 675 MutableArrayRef<scf::ForOp> loops) { 676 if (loops.empty()) 677 return; 678 679 OpResult fusableProducer = fusedProducerInfo.origProducer; 680 Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer; 681 FailureOr<Value> initValue = tensor::getOrCreateDestination( 682 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 683 if (succeeded(initValue)) { 684 685 auto newYieldValuesFn = 686 [&](RewriterBase &innerRewriter, Value iv, 687 ValueRange newRegionIterArgs) -> SmallVector<Value> { 688 OpBuilder::InsertionGuard g(innerRewriter); 689 if (auto tiledDestStyleOp = 690 tiledAndFusedProducer 691 .getDefiningOp<DestinationStyleOpInterface>()) { 692 rewriter.setInsertionPoint(tiledDestStyleOp); 693 BlockArgument newRegionArg = loops.back().getRegionIterArgs().back(); 694 auto destSlice = 
rewriter.create<tensor::ExtractSliceOp>( 695 sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), 696 sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); 697 unsigned resultNumber = fusableProducer.getResultNumber(); 698 rewriter.updateRootInPlace(tiledDestStyleOp, [&]() { 699 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 700 }); 701 } 702 Block *block = rewriter.getInsertionPoint()->getBlock(); 703 rewriter.setInsertionPoint(block->getTerminator()); 704 Value replacement = rewriter.create<tensor::InsertSliceOp>( 705 fusedProducerInfo.origProducer.getLoc(), 706 fusedProducerInfo.tiledAndFusedProducer, 707 loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(), 708 sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); 709 return {replacement}; 710 }; 711 712 addInitOperandsToLoopNest(rewriter, loops, 713 SmallVector<Value>{initValue.value()}, 714 newYieldValuesFn); 715 } 716 } 717 718 /// Implementation of tile consumer and fuse producer greedily. 719 FailureOr<scf::SCFTileAndFuseResult> 720 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 721 RewriterBase &rewriter, TilingInterface consumer, 722 const scf::SCFTileAndFuseOptions &options) { 723 // This transformation is only valid for ops that return values (i.e. not 724 // valid to use with operations that have memref operands). 725 if (!consumer->getNumResults()) { 726 return rewriter.notifyMatchFailure( 727 consumer, "invalid pattern for op with no results"); 728 } 729 730 // 1. First tile the consumer. 
731 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 732 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 733 FailureOr<scf::SCFTilingResult> tilingResult = 734 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 735 if (failed(tilingResult)) 736 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 737 for (auto *tiledOp : tilingResult->tiledOps) 738 tiledAndFusedOps.insert(tiledOp); 739 SmallVector<scf::ForOp> forLoops = 740 castToTypedOperations<scf::ForOp>(tilingResult->loops); 741 742 // If there are no loops generated, fusion is immaterial. 743 if (forLoops.empty()) { 744 DenseMap<Value, Value> replacements; 745 for (auto [origVal, replacement] : 746 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 747 replacements[origVal] = replacement; 748 } 749 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, 750 getAsOperations(forLoops), replacements}; 751 } 752 753 // To keep track of replacements for now just record the map from the original 754 // untiled value to the result number of the for loop. Since the loop gets 755 // potentially replaced during fusion, keeping the value directly wont work. 756 DenseMap<Value, size_t> origValToResultNumber; 757 for (auto [index, result] : llvm::enumerate(consumer->getResults())) { 758 origValToResultNumber[result] = index; 759 } 760 761 // 2. Typically, the operands of the tiled operation are slices of the 762 // operands of the untiled operation. These are expressed in IR using 763 // `tensor.extract_slice` operations with source being the operands of the 764 // untiled operation. Create a worklist of these `tensor.extract_slice` 765 // operations. If the producers of the source of the `tensor.extract_slice` 766 // can be tiled such that the tiled value is generated in-place, that 767 // effectively tiles + fuses the operations. 
  // Collect the `tensor.extract_slice` operands of `fusedOp` as new fusion
  // candidates.
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  // Seed the worklist with the slices feeding the tiled consumer.
  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // Find the original producer of the slice.
    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                          forLoops);
    if (!fusableProducer)
      continue;

    // Let the fusion-control callback veto fusion of this slice, and decide
    // whether the fused producer's value should also be yielded from the
    // loop nest.
    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
    if (!fuseSlice)
      continue;

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
798 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 799 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); 800 if (!fusedResult) 801 continue; 802 803 if (yieldReplacement) { 804 yieldReplacementForFusedProducer(rewriter, candidateSliceOp, 805 fusedResult.value(), forLoops); 806 origValToResultNumber[fusableProducer] = 807 forLoops.front().getNumResults() - 1; 808 } 809 810 if (Operation *tiledAndFusedOp = 811 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 812 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 813 tiledAndFusedOps.insert(tiledAndFusedOp); 814 addCandidateSlices(tiledAndFusedOp, candidates); 815 } 816 } 817 818 DenseMap<Value, Value> replacements; 819 for (auto [origVal, resultNumber] : origValToResultNumber) { 820 replacements[origVal] = forLoops.front()->getResult(resultNumber); 821 } 822 823 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, 824 getAsOperations(forLoops), replacements}; 825 } 826 827 //===----------------------------------------------------------------------===// 828 // tileUsingSCFForAllOp implementation. 829 //===----------------------------------------------------------------------===// 830 831 FailureOr<scf::SCFTilingResult> 832 mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, 833 const scf::SCFTilingOptions &options) { 834 Location loc = op->getLoc(); 835 OpBuilder::InsertionGuard g(rewriter); 836 837 // 1. Get the range of loops that are represented by the operation. 838 SmallVector<Range> loopRanges = op.getIterationDomain(rewriter); 839 if (loopRanges.empty()) 840 return op->emitOpError("expected non-empty loop ranges"); 841 auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; 842 if (llvm::any_of(loopRanges, hasStrideOne)) 843 return op->emitOpError("only stride-1 supported atm"); 844 845 // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed. 
  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
  SmallVector<OpFoldResult> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));

  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
  //    Dimensions with a constant tile size of 0 are skipped: they are not
  //    materialized as loops in the scf.forall.
  SmallVector<OpFoldResult> lbs, ubs, steps;
  for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
    if (isConstantIntValue(tileSize, 0))
      continue;
    lbs.push_back(loopRange.offset);
    ubs.push_back(loopRange.size);
    steps.push_back(tileSize);
  }

  // 4. Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
    return op->emitOpError("failed to get destination tensors");

  // 5. Build the device mapping attribute (only when the options carry a
  //    non-empty mapping vector).
  std::optional<ArrayAttr> mappingAttr;
  if (!options.mappingVector.empty()) {
    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
  }

  // 6. Create the ForallOp. We don't use the lambda body-builder
  // version because we require the use of RewriterBase in the body, so we
  // manually move the insertion point to the body below.
  auto forallOp =
      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);

  // 7. Get the tile offset and sizes.
  rewriter.setInsertionPoint(forallOp.getTerminator());
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  ValueRange ivs = forallOp.getInductionVars();
  {
    // Walk the full iteration domain: untiled dimensions (tile size 0) cover
    // their whole range, while each tiled dimension consumes the next forall
    // induction variable and gets its size from `getBoundedTileSize`.
    int materializedLoopNum = 0;
    for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
      if (isConstantIntValue(tileSize, 0)) {
        tiledOffsets.push_back(loopRange.offset);
        tiledSizes.push_back(loopRange.size);
        continue;
      }
      Value iv = ivs[materializedLoopNum++];
      tiledOffsets.push_back(iv);
      tiledSizes.push_back(
          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
    }
  }

  // 8. Tile the operation. Clone the operation to allow fix up of destination
  // operands.
  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
  Operation *clonedOp =
      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
  FailureOr<TilingResult> tilingResult =
      cast<TilingInterface>(clonedOp).getTiledImplementation(
          rewriter, tiledOffsets, tiledSizes);
  if (failed(tilingResult))
    return clonedOp->emitError("failed to tile op: ");
  // The clone was only needed to compute the tiled implementation; erase it.
  rewriter.eraseOp(clonedOp);

  // 9. Parallel insert back into the result tensor.
  for (auto [index, tiledValue, destBBArg] :
       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
    // 9.a. Partial subset information is inserted just before the terminator.
    rewriter.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
                                        tiledSizes, resultOffsets,
                                        resultSizes))) {
      return op->emitOpError("output offsets couldn't be calculated");
    }

    SmallVector<OpFoldResult> strides(resultSizes.size(),
                                      rewriter.getIndexAttr(1));
    // 9.b. Parallel insertions are inserted at the end of the combining
    // terminator.
926 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 927 rewriter.create<tensor::ParallelInsertSliceOp>( 928 loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides); 929 } 930 931 // 10. Return the tiling result. 932 return scf::SCFTilingResult{ 933 tilingResult->tiledOps, 934 {forallOp.getOperation()}, 935 llvm::map_to_vector(forallOp.getResults(), 936 [](auto val) -> Value { return val; })}; 937 } 938 939 //===----------------------------------------------------------------------===// 940 // lowerToLoopsUsingSCFForOp implementation. 941 //===----------------------------------------------------------------------===// 942 943 FailureOr<SmallVector<scf::ForOp>> 944 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 945 TilingInterface op) { 946 // TODO: Handle cases where the op has results if needed. 947 if (op->getNumResults() > 0) { 948 return rewriter.notifyMatchFailure( 949 op, "unable to lower to loops operations with return values"); 950 } 951 952 SmallVector<Range> domain = op.getIterationDomain(rewriter); 953 SmallVector<Value> ivs; 954 SmallVector<scf::ForOp> loops; 955 Location loc = op.getLoc(); 956 for (auto loopRange : domain) { 957 Value offsetVal = 958 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 959 Value sizeVal = 960 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 961 Value strideVal = 962 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 963 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 964 strideVal, ValueRange{}); 965 loops.push_back(loop); 966 ivs.push_back(loop.getInductionVar()); 967 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 968 } 969 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 970 return failure(); 971 } 972 return loops; 973 } 974