//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements tiling using the TilingInterface.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "tile-using-interface"

using namespace mlir;

scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  auto tileSizes = llvm::to_vector(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    return tileSizes;
  };
  return *this;
}

/// Helper method to adjust the interchange vector to match the iteration
/// domain.
static SmallVector<int64_t>
fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
                      size_t iterationDomainSize) {
  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
  if (filledVector.size() < iterationDomainSize) {
    auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
    filledVector.append(range.begin(), range.end());
  }
  if (filledVector.size() > iterationDomainSize)
    filledVector.resize(iterationDomainSize);
  return filledVector;
}

/// Convert a list of ops of type `SrcOpTy` to a list of `Operation *`.
template <typename SrcOpTy>
static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
  return llvm::to_vector(
      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
}
template <typename SrcOpTy>
static SmallVector<Operation *>
getAsOperations(const SmallVector<SrcOpTy> &ops) {
  return getAsOperations(ArrayRef<SrcOpTy>(ops));
}

/// Convert a list of `Operation *` to a list of `DstOpTy`.
template <typename DstOpTy>
static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
  return llvm::to_vector(
      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
}
template <typename DstOpTy>
static SmallVector<DstOpTy>
castToTypedOperations(const SmallVector<Operation *> &ops) {
  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
}

//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

// Check if `stride` evenly divides the trip count `size - offset`.
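// For illustration (values chosen arbitrarily): with offset = 0, size = 128,
// and stride = 32 the check succeeds since the trip count 128 is a multiple
// of 32, whereas stride = 48 would fail; any non-constant operand makes the
// check conservatively return false.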
static bool tileDividesIterationDomain(Range loopRange) {
  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
  if (!offsetAsInt)
    return false;
  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
  if (!sizeAsInt)
    return false;
  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
  if (!strideAsInt)
    return false;
  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
}

/// Returns the bounded tile size given the current `iv`, `loopRange` and
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, Value iv,
                                       OpFoldResult tileSize) {
  if (isConstantIntValue(tileSize, 1))
    return tileSize;

  if (tileDividesIterationDomain(
          Range{loopRange.offset, loopRange.size, tileSize}))
    return tileSize;

  // The tile size to use (to avoid out of bounds access) is the minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
  // tiled loop.
  AffineExpr s0, s1, d0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0, s1);
  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  return affine::makeComposedFoldedAffineMin(
      b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration
///   space.
/// - `tileSizes` is the tile sizes to use. A tile size of zero represents an
///   untiled loop.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
///   the tile processed within the innermost loop.
static SmallVector<scf::ForOp> generateTileLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes) {
  assert(!loopRanges.empty() && "expected at least one loop range");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  for (auto loopRange : llvm::enumerate(loopRanges)) {
    Value offset =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
    Value size =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
    Value tileSize = getValueOrCreateConstantIndexOp(
        builder, loc, tileSizes[loopRange.index()]);
    // No loops if tile size is zero. Set offset and size to the loop
    // offset and size.
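    // For illustration (sizes chosen arbitrarily): with tile sizes [0, 32] on
    // a 2-D iteration domain, only the second dimension gets a loop; the
    // first dimension keeps the full loop offset and size.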
    if (matchPattern(tileSize, m_Zero())) {
      offsets[loopRange.index()] = offset;
      sizes[loopRange.index()] = size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, offset, size, tileSize, ValueRange{},
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          sizes[loopRange.index()] = getBoundedTileSize(
              bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
          builder.create<scf::YieldOp>(loc);
        });
    offsets[loopRange.index()] = loop.getInductionVar();
    loops.push_back(loop);
    builder.setInsertionPoint(loop.getBody()->getTerminator());
  }
  return loops;
}

/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
/// construct the destructive update pattern that inserts the yielded
/// value into a destination tensor provided by `initValue` at offset
/// `tileOffsets` and size `tileSizes`. For example,
///
/// ```mlir
/// scf.for %iv0 = ... {
///   %0 = tiled_op
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                 ValueRange yieldedValues,
                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                 MutableArrayRef<scf::ForOp> loops) {
  NewYieldValueFn yieldValueFn =
      [&](OpBuilder &b, Location loc,
          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
    SmallVector<Value> inserts;
    for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
      ArrayRef<OpFoldResult> tileOffsets =
          tileOffsetsList[yieldedValue.index()];
      ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
      SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
                                            b.getIndexAttr(1));
      Value insert = b.create<tensor::InsertSliceOp>(
          loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
          tileOffsets, tileSizes, tileStrides);
      inserts.push_back(insert);
    }
    return inserts;
  };

  SmallVector<scf::ForOp> newLoops =
      replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
                                   /*replaceIterOperandsUsesInLoop=*/false);
  for (const auto &loop : llvm::enumerate(loops)) {
    rewriter.eraseOp(loop.value());
    loops[loop.index()] = newLoops[loop.index()];
  }
  return llvm::to_vector(llvm::map_range(
      loops.front().getResults().take_back(yieldedValues.size()),
      [](OpResult r) -> Value { return r; }));
}

/// If the tiled operation is destination passing style, update the
/// slice of the destination used (which refers to the untiled destination)
/// to use the corresponding region argument of the innermost loop.
///
/// ```mlir
/// %0 =
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %0
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
static void
updateDestinationOperandsForTiledOp(OpBuilder &builder,
                                    ValueRange tiledOpDestinationValues,
                                    ValueRange bbArgsList) {
  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
    auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      continue;
    sliceOp.setOperand(0, bbArgsList[destValue.index()]);
  }
}

/// Helper method to yield the values of the tiled op, as well as
/// update the destination operands of the tiled op, if it is
/// a destination passing style op.
static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
                 TilingResult tilingResult,
                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                 MutableArrayRef<scf::ForOp> loops) {
  SmallVector<Value> replacements =
      yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
                       tileOffsetsList, tileSizesList, loops);
  for (auto tiledOp : tilingResult.tiledOps) {
    if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
      auto innerMostLoop = loops.back();
      SmallVector<Value> tiledOpDestinationTensors =
          llvm::to_vector(dstOp.getDpsInits());
      updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
                                          innerMostLoop.getRegionIterArgs());
    }
  }
  return replacements;
}

/// Implementation of tiling transformation of `op` that implements the
/// `TilingInterface` using `scf.for` to iterate over the tiles.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                             const scf::SCFTilingOptions &options) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // 1. Get the range of the loops that are represented by the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  size_t numLoops = iterationDomain.size();
  if (numLoops == 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to tile op with no iteration domain");
  }

  // 2. Materialize the tile sizes. Enforce the convention that "tiling by
  // zero" skips tiling a particular dimension. This convention is
  // significantly simpler to handle than adjusting affine maps to account
  // for missing dimensions.
  SmallVector<OpFoldResult> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < iterationDomain.size()) {
    auto zero = rewriter.getIndexAttr(0);
    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
  }

  SmallVector<OpFoldResult> offsets, sizes;
  SmallVector<scf::ForOp> forLoops;
  {
    // If there is an interchange specified, permute the iteration domain and
    // the tile sizes.
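    // For illustration (an assumed 3-D iteration domain): an interchange
    // vector of [2, 0, 1] makes the original third dimension the outermost
    // loop; a shorter vector is padded with the remaining trailing dimension
    // indices by `fillInterchangeVector` and must form a valid permutation.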
    SmallVector<int64_t> interchangeVector;
    if (!options.interchangeVector.empty()) {
      interchangeVector = fillInterchangeVector(options.interchangeVector,
                                                iterationDomain.size());
    }
    if (!interchangeVector.empty()) {
      if (!isPermutationVector(interchangeVector)) {
        return rewriter.notifyMatchFailure(
            op, "invalid interchange vector, not a permutation of the entire "
                "iteration space");
      }

      applyPermutationToVector(iterationDomain, interchangeVector);
      applyPermutationToVector(tileSizeVector, interchangeVector);
    }

    // 3. Materialize an empty loop nest that iterates over the tiles. These
    // loops for now do not return any values even if the original operation
    // has results.
    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
                                    tileSizeVector, offsets, sizes);

    if (!interchangeVector.empty()) {
      auto inversePermutation = invertPermutationVector(interchangeVector);
      applyPermutationToVector(offsets, inversePermutation);
      applyPermutationToVector(sizes, inversePermutation);
    }
  }

  LLVM_DEBUG({
    if (!forLoops.empty()) {
      llvm::dbgs() << "LoopNest shell :\n";
      forLoops.front().dump();
      llvm::dbgs() << "\n";
    }
  });

  // 4. Generate the tiled implementation within the innermost loop.
  if (!forLoops.empty())
    rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
  FailureOr<TilingResult> tiledImplementation =
      op.getTiledImplementation(rewriter, offsets, sizes);

  if (op->getNumResults() == 0) {
    return scf::SCFTilingResult{
        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
  }

  // If loops are empty, the tiled op is used as the replacement for the
  // untiled op.
  if (forLoops.empty()) {
    return scf::SCFTilingResult{tiledImplementation->tiledOps,
                                getAsOperations(forLoops),
                                tiledImplementation->tiledValues};
  }

  // 5. Yield all the results of the tiled operation. The surrounding loop
  // nest is modified to insert a destructive update pattern to yield
  // from the loop nest values to replace the untiled op with.
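  // For illustration only (op and tile sizes chosen arbitrarily), tiling a
  // `linalg.matmul` with tile sizes [10, 20, 0] produces IR of roughly the
  // following shape:
  //
  // ```mlir
  // %res = scf.for %iv0 = ... iter_args(%arg0 = %init) {
  //   %r = scf.for %iv1 = ... iter_args(%arg1 = %arg0) {
  //     %lhs = tensor.extract_slice ...
  //     %rhs = tensor.extract_slice ...
  //     %out = tensor.extract_slice %arg1 ...
  //     %mm = linalg.matmul ins(%lhs, %rhs : ...) outs(%out : ...)
  //     %ins = tensor.insert_slice %mm into %arg1 ...
  //     scf.yield %ins
  //   }
  //   scf.yield %r
  // }
  // ```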
  int64_t numResults = op->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
      resultSizesList(numResults);
  for (const auto &result : llvm::enumerate(op->getResults())) {
    if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
                                        sizes,
                                        resultOffsetsList[result.index()],
                                        resultSizesList[result.index()]))) {
      return rewriter.notifyMatchFailure(
          op, "failed to get slice of result produced");
    }
  }

  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
                                             destinationTensors)))
    return rewriter.notifyMatchFailure(op, "failed to get destinations");

  SmallVector<Value> replacements = yieldTiledValues(
      rewriter, destinationTensors, tiledImplementation.value(),
      resultOffsetsList, resultSizesList, forLoops);
  LLVM_DEBUG({
    if (!forLoops.empty()) {
      llvm::dbgs() << "After tiled implementation :\n";
      forLoops.front().dump();
      llvm::dbgs() << "\n";
    }
  });
  return scf::SCFTilingResult{tiledImplementation->tiledOps,
                              getAsOperations(forLoops), replacements};
}

FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSizes) {
  Location loc = op.getLoc();
  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  auto tileSizesVector = llvm::to_vector(tileSizes);
  if (tileSizesVector.size() < iterationDomain.size()) {
    auto zero = b.getIndexAttr(0);
    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
                           zero);
  }
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");
  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();

  SmallVector<int> reductionDims;
  for (auto [idx, iteratorType] :
       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
    if (iteratorType == utils::IteratorType::reduction)
      reductionDims.push_back(idx);
  }

  // 1. Create the initial tensor value.
  FailureOr<Operation *> identityTensor =
      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                  reductionDims);
  if (failed(identityTensor))
    return b.notifyMatchFailure(op,
                                "cannot create a tensor of identity value.");
  // 2. Create the nested loops.
  SmallVector<OpFoldResult> offsets, sizes;
  SmallVector<scf::ForOp> loops = generateTileLoopNest(
      b, loc, iterationDomain, tileSizesVector, offsets, sizes);

  // 3. Generate the tiled implementation within the innermost loop.
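  // For illustration only (an assumed 1-D sum reduction tiled by 4): the
  // partial-reduction op computes a 4-element partial accumulator inside the
  // loop nest, and the merge step below combines those partial values into
  // the final result.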
  b.setInsertionPoint(loops.back().getBody()->getTerminator());
  Operation *parallelOp = op.tileToPartialReduction(
      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);

  SmallVector<OpFoldResult> resultSizesList;
  for (size_t i = 0; i < offsets.size(); i++)
    resultSizesList.push_back(
        tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
  SmallVector<Value> replacements = yieldTiledValues(
      b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
      resultSizesList, loops);

  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
  auto innerMostLoop = loops.back();
  SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
  assert(destinationTensors.size() ==
             innerMostLoop.getRegionIterArgs().size() &&
         "unexpected number of outputs");
  updateDestinationOperandsForTiledOp(b, destinationTensors,
                                      innerMostLoop.getRegionIterArgs());

  // 4. Apply the merge reduction to combine all the partial values.
  b.setInsertionPointAfter(*loops.begin());
  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
  b.replaceOp(op, mergeOp->getResults());

  SCFReductionTilingResult results;
  results.initialOp = *identityTensor;
  results.loops = std::move(loops);
  results.parallelTiledOp = parallelOp;
  results.mergeOp = mergeOp;
  return results;
}

//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outermost loop that is encountered. Traversing the
/// `iter_args` indicates that this is a destination operand of the consumer.
/// If there was no loop traversal needed, the second value of the returned
/// tuple is empty.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<scf::ForOp> loops) {
  std::optional<OpOperand *> destinationIterArg;
  auto loopIt = loops.rbegin();
  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
    scf::ForOp loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
    source = &loop.getOpOperandForRegionIterArg(iterArg);
    loopIt++;
  }
  if (loopIt == loops.rend())
    destinationIterArg = source;
  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}

/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                      tensor::ExtractSliceOp candidateSliceOp,
                                      MutableArrayRef<scf::ForOp> loops) {
  // 1. Get the producer of the source (potentially walking through
  // `iter_args` of nested `scf.for`).
  auto [fusableProducer, destinationInitArg] =
      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
                                        loops);
  if (!fusableProducer)
    return std::nullopt;

  // 2. Generate the tiled implementation of the producer of the source.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
                                                   fusableProducer);
  if (failed(tileAndFuseResult))
    return std::nullopt;
  rewriter.replaceAllUsesWith(candidateSliceOp,
                              tileAndFuseResult->tiledValues[0]);

  // 3. If the slice is for a destination operand, for example,
  //
  // ```mlir
  // %0 = linalg.init
  // %1 = linalg.fill .. outs(%0 : )
  // %2 = scf.for .. iter_args(%arg0 = %1) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1 [..]
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  //
  // the IR is currently
  //
  // ```
  // %0 = linalg.init
  // %1 = linalg.fill
  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %0 /* incorrect value */ [..]
  //     %5 = linalg.fill .. outs(%4 : )
  //     .. = linalg.matmul .. outs(%5 : )
  //   }
  // }
  // ```
  //
  // The untiled `linalg.fill` is still used as the `init_value` since it
  // was originally a destination operand of the untiled `linalg.matmul`.
  // When fusing an operand that is a destination operand:
  // - Update the iter_arg of the outermost loop to use the destination
  //   of the untiled producer.
  // - Update the destination of the slice of the tiled producer generated
  //   to use the same basic block argument as the slice that was used to
  //   generate inplace the tiled implementation of the producer.
  // With this the IR will be:
  //
  // ```
  // %0 = linalg.init
  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
  //     %4 = linalg.fill .. outs(%3 : )
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  // TODO: This can be modeled better with the `DestinationStyleOpInterface`.
  // Update to use that when it becomes available.
  scf::ForOp outerMostLoop = loops.front();
  if (destinationInitArg &&
      (*destinationInitArg)->getOwner() == outerMostLoop) {
    unsigned iterArgNumber =
        outerMostLoop.getResultForOpOperand(**destinationInitArg)
            .getResultNumber();
    int64_t resultNumber = fusableProducer.getResultNumber();
    if (auto dstOp =
            dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
      (*destinationInitArg)
          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
    }
    for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
      auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
      if (!dstOp)
        continue;
      scf::ForOp innerMostLoop = loops.back();
      updateDestinationOperandsForTiledOp(
          rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
    }
  }
  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
                                           tileAndFuseResult->tiledValues[0],
                                           tileAndFuseResult->tiledOps};
}

/// Reconstruct the fused producer from within the tiled-and-fused code.
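/// For example (illustrative, producer chosen arbitrarily): when a
/// `linalg.fill` producer has been fused into the tiled loop nest, this
/// creates or looks up a destination for the fill, yields the computed fill
/// tile from the loop nest using the destructive-update pattern above, and
/// points the fused op's destination at the corresponding region iter_arg.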
void mlir::scf::yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<scf::ForOp> loops) {
  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
      fusedProducerInfo;
  SmallVector<Value> initValues;
  FailureOr<Value> initValue = tensor::getOrCreateDestination(
      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
  if (succeeded(initValue)) {
    SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
    SmallVector<Value> yieldedVals =
        yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
                         resultOffsets, resultSizes, loops);
  }
  for (auto tileAndFusedOp : tileAndFusedOps) {
    auto dstStyleProducer =
        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
    if (!dstStyleProducer)
      continue;
    Value dstValue =
        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
            ->get();
    updateDestinationOperandsForTiledOp(
        rewriter, dstValue, loops.back().getRegionIterArgs().back());
  }
}

/// Implementation of tile consumer and fuse producer greedily.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  // This transformation is only valid for ops that return values (i.e. not
  // valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  SmallVector<scf::ForOp> forLoops;
  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
  DenseMap<Value, Value> replacements;
  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
  {
    FailureOr<scf::SCFTilingResult> tilingResult =
        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
    if (failed(tilingResult))
      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
    for (auto *tiledOp : tilingResult->tiledOps)
      tiledAndFusedOps.insert(tiledOp);
    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
    for (auto [index, origValue, replacement] :
         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
      replacements[origValue] = replacement;
      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
          index)] = index;
    }
  }

  // If there are no loops generated, fusion is immaterial.
  if (forLoops.empty()) {
    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                     getAsOperations(forLoops), replacements};
  }

  // 2. Typically, the operands of the tiled operation are slices of the
  // operands of the untiled operation. These are expressed in IR using
  // `tensor.extract_slice` operations with source being the operands of the
  // untiled operation. Create a worklist of these `tensor.extract_slice`
  // operations. If the producers of the source of the `tensor.extract_slice`
  // can be tiled such that the tiled value is generated in-place, that
  // effectively tiles + fuses the operations.
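  // For illustration (producer op chosen arbitrarily): if a `linalg.fill`
  // feeds the tiled consumer, the `tensor.extract_slice` of the fill result
  // becomes a candidate; fusing it materializes a fill of just that tile
  // inside the loop nest, and the slices used by the fused fill are in turn
  // added to the worklist.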
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
    if (!fusedResult)
      continue;

    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
      addCandidateSlices(tiledAndFusedOp, candidates);
    }
  }
  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                   getAsOperations(forLoops), replacements};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}