1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "llvm/Support/Debug.h" 27 #include <optional> 28 29 #define DEBUG_TYPE "tile-using-interface" 30 31 using namespace mlir; 32 33 scf::SCFTilingOptions & 34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 35 assert(!tileSizeComputationFunction && "tile sizes already set"); 36 auto tileSizes = llvm::to_vector(ts); 37 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 38 return tileSizes; 39 }; 40 return *this; 41 } 42 43 /// Helper method to adjust the interchange vector to match the iteration 44 /// domain. 45 static SmallVector<int64_t> 46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 47 size_t iterationDomainSize) { 48 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 49 if (filledVector.size() < iterationDomainSize) { 50 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 51 filledVector.append(range.begin(), range.end()); 52 } 53 if (filledVector.size() > iterationDomainSize) 54 filledVector.resize(iterationDomainSize); 55 return filledVector; 56 } 57 58 //===----------------------------------------------------------------------===// 59 // tileUsingSCFForOp implementation. 60 //===----------------------------------------------------------------------===// 61 62 // Check if `stride` evenly divides the trip count `size - offset`. 63 static bool tileDividesIterationDomain(Range loopRange) { 64 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 65 if (!offsetAsInt) 66 return false; 67 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 68 if (!sizeAsInt) 69 return false; 70 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 71 if (!strideAsInt) 72 return false; 73 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 74 } 75 76 /// Returns the bounded tile size given the current `iv`, `loopRange` and 77 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 78 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 79 Range loopRange, Value iv, 80 Value tileSize) { 81 std::optional<int64_t> ts = getConstantIntValue(tileSize); 82 if (ts && ts.value() == 1) 83 return getAsOpFoldResult(tileSize); 84 85 if (tileDividesIterationDomain( 86 Range{loopRange.offset, loopRange.size, tileSize})) 87 return tileSize; 88 89 // The tile size to use (to avoid out of bounds access) is minimum of 90 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 91 // loop. 92 AffineExpr s0, s1, d0; 93 bindDims(b.getContext(), d0); 94 bindSymbols(b.getContext(), s0, s1); 95 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 96 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 97 return affine::makeComposedFoldedAffineMin( 98 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 99 } 100 101 /// Generate an empty loop nest that represents the tiled loop nest shell. 102 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 103 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 104 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 105 /// the 106 /// tile processed within the inner most loop. 107 static SmallVector<scf::ForOp> generateTileLoopNest( 108 OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges, 109 ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets, 110 SmallVector<OpFoldResult> &sizes) { 111 assert(!loopRanges.empty() && "expected at least one loop range"); 112 assert(loopRanges.size() == tileSizes.size() && 113 "expected as many tile sizes as loop ranges"); 114 OpBuilder::InsertionGuard guard(builder); 115 SmallVector<scf::ForOp> loops; 116 offsets.resize(loopRanges.size()); 117 sizes.resize(loopRanges.size()); 118 119 for (auto loopRange : llvm::enumerate(loopRanges)) { 120 Value offset = 121 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 122 Value size = 123 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 124 Value tileSize = getValueOrCreateConstantIndexOp( 125 builder, loc, tileSizes[loopRange.index()]); 126 // No loops if tile size is zero. Set offset and size to the loop 127 // offset and size. 128 if (matchPattern(tileSize, m_Zero())) { 129 offsets[loopRange.index()] = offset; 130 sizes[loopRange.index()] = size; 131 continue; 132 } 133 134 auto loop = builder.create<scf::ForOp>( 135 loc, offset, size, tileSize, ValueRange{}, 136 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 137 ValueRange /*iterArgs*/) { 138 sizes[loopRange.index()] = getBoundedTileSize( 139 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 140 builder.create<scf::YieldOp>(loc); 141 }); 142 offsets[loopRange.index()] = loop.getInductionVar(); 143 loops.push_back(loop); 144 builder.setInsertionPoint(loop.getBody()->getTerminator()); 145 } 146 return loops; 147 } 148 149 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 150 /// construct the destructive update pattern that inserts the yielded 151 /// value into a destination tensor provided by `initValue` at offset 152 /// `tileOffsets` and size `tileSizes`. For example, 153 /// 154 /// ```mlir 155 /// scf.for %iv0 = ... { 156 /// %0 = tiled_op 157 /// } 158 /// ``` 159 /// 160 /// is transformed to 161 /// 162 /// ```mlir 163 /// scf.for %iv0 = ... iter_args(%arg = %0) { 164 /// %1 = tensor.extract_slice %arg 165 /// %2 = tiled_op 166 /// %3 = tensor.insert_slice %2 into %arg 167 /// scf.yield %3 168 /// } 169 /// ``` 170 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 171 static SmallVector<Value> 172 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 173 ValueRange yieldedValues, 174 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 175 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 176 MutableArrayRef<scf::ForOp> loops) { 177 NewYieldValueFn yieldValueFn = 178 [&](OpBuilder &b, Location loc, 179 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 180 SmallVector<Value> inserts; 181 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 182 ArrayRef<OpFoldResult> tileOffsets = 183 tileOffsetsList[yieldedValue.index()]; 184 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 185 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 186 b.getIndexAttr(1)); 187 Value insert = b.create<tensor::InsertSliceOp>( 188 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 189 tileOffsets, tileSizes, tileStrides); 190 inserts.push_back(insert); 191 } 192 return inserts; 193 }; 194 195 SmallVector<scf::ForOp> newLoops = 196 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 197 /*replaceIterOperandsUsesInLoop =*/false); 198 for (const auto &loop : llvm::enumerate(loops)) { 199 rewriter.eraseOp(loop.value()); 200 loops[loop.index()] = newLoops[loop.index()]; 201 } 202 return llvm::to_vector(llvm::map_range( 203 loops.front().getResults().take_back(yieldedValues.size()), 204 [](OpResult r) -> Value { return r; })); 205 } 206 207 /// If the tiled operation is destination passing style, update the 208 /// slice of the destination used (which refers to the untiled destination) 209 /// to use the corresponding region argument of the innermost loop. 210 /// 211 /// ```mlir 212 /// %0 = 213 /// scf.for %iv0 = ... iter_args(%arg = %0) { 214 /// %1 = tensor.extract_slice %0 215 /// %2 = tiled_op 216 /// %3 = tensor.insert_slice %2 into %arg 217 /// scf.yield %3 218 /// } 219 /// ``` 220 /// 221 /// is transformed to 222 /// 223 /// ```mlir 224 /// scf.for %iv0 = ... iter_args(%arg = %0) { 225 /// %1 = tensor.extract_slice %arg 226 /// %2 = tiled_op 227 /// %3 = tensor.insert_slice %2 into %arg 228 /// scf.yield %3 229 /// } 230 /// ``` 231 static void 232 updateDestinationOperandsForTiledOp(OpBuilder &builder, 233 ValueRange tiledOpDestinationValues, 234 ValueRange bbArgsList) { 235 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 236 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 237 if (!sliceOp) 238 continue; 239 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 240 } 241 } 242 243 /// Helper method to yield the values of the tiled op, as well as 244 /// update the destination operands of the tiled op, if it is 245 /// a destination passing style op. 246 static SmallVector<Value> 247 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, 248 TilingResult tilingResult, 249 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 250 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 251 MutableArrayRef<scf::ForOp> loops) { 252 SmallVector<Value> replacements = 253 yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, 254 tileOffsetsList, tileSizesList, loops); 255 for (auto tiledOp : tilingResult.tiledOps) { 256 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { 257 auto innerMostLoop = loops.back(); 258 SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); 259 updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, 260 innerMostLoop.getRegionIterArgs()); 261 } 262 } 263 return replacements; 264 } 265 266 /// Implementation of tiling transformation of `op` that implements the 267 /// `TilingInterface` using `scf.for` to iterate over the tiles. 268 FailureOr<scf::SCFTilingResult> 269 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 270 const scf::SCFTilingOptions &options) { 271 OpBuilder::InsertionGuard guard(rewriter); 272 rewriter.setInsertionPointAfter(op); 273 274 if (!options.tileSizeComputationFunction) { 275 return rewriter.notifyMatchFailure( 276 op, "missing tile size computation function"); 277 } 278 279 // 1. Get the range of the loops that are represented by the operation. 280 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 281 size_t numLoops = iterationDomain.size(); 282 if (numLoops == 0) { 283 return rewriter.notifyMatchFailure( 284 op, "unable to tile op with no iteration domain"); 285 } 286 287 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 288 // skips tiling a particular dimension. This convention is significantly 289 // simpler to handle instead of adjusting affine maps to account for missing 290 // dimensions. 291 SmallVector<OpFoldResult> tileSizeVector = 292 options.tileSizeComputationFunction(rewriter, op); 293 if (tileSizeVector.size() < iterationDomain.size()) { 294 auto zero = rewriter.getIndexAttr(0); 295 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 296 } 297 298 scf::SCFTilingResult tilingResult; 299 SmallVector<OpFoldResult> offsets, sizes; 300 { 301 // If there is an interchange specified, permute the iteration domain and 302 // the tile sizes. 303 SmallVector<int64_t> interchangeVector; 304 if (!options.interchangeVector.empty()) { 305 interchangeVector = fillInterchangeVector(options.interchangeVector, 306 iterationDomain.size()); 307 } 308 if (!interchangeVector.empty()) { 309 if (!isPermutationVector(interchangeVector)) { 310 return rewriter.notifyMatchFailure( 311 op, "invalid intechange vector, not a permutation of the entire " 312 "iteration space"); 313 } 314 315 applyPermutationToVector(iterationDomain, interchangeVector); 316 applyPermutationToVector(tileSizeVector, interchangeVector); 317 } 318 319 // 3. Materialize an empty loop nest that iterates over the tiles. These 320 // loops for now do not return any values even if the original operation has 321 // results. 322 tilingResult.loops = generateTileLoopNest( 323 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 324 325 if (!interchangeVector.empty()) { 326 auto inversePermutation = invertPermutationVector(interchangeVector); 327 applyPermutationToVector(offsets, inversePermutation); 328 applyPermutationToVector(sizes, inversePermutation); 329 } 330 } 331 332 LLVM_DEBUG({ 333 if (!tilingResult.loops.empty()) { 334 llvm::dbgs() << "LoopNest shell :\n"; 335 tilingResult.loops.front().dump(); 336 llvm::dbgs() << "\n"; 337 } 338 }); 339 340 // 4. Generate the tiled implementation within the inner most loop. 341 if (!tilingResult.loops.empty()) 342 rewriter.setInsertionPoint( 343 tilingResult.loops.back().getBody()->getTerminator()); 344 FailureOr<TilingResult> tiledImplementation = 345 op.getTiledImplementation(rewriter, offsets, sizes); 346 tilingResult.tiledOps.append(tiledImplementation->tiledOps); 347 if (op->getNumResults() == 0) { 348 // nothing more to do. 349 return tilingResult; 350 } 351 352 // If loops are empty, the tiled op is used as the replacement for the untiled 353 // op. 354 if (tilingResult.loops.empty()) { 355 tilingResult.replacements = tiledImplementation->tiledValues; 356 return tilingResult; 357 } 358 359 // 5. Yield all the results of the tiled operation. The surrounding loop 360 // nest is modified to insert a destructive update pattern to yield 361 // from the loop nest values to replace the untiled op with. 362 int64_t numResults = op->getNumResults(); 363 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 364 resultSizesList(numResults); 365 for (const auto &result : llvm::enumerate(op->getResults())) { 366 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 367 sizes, 368 resultOffsetsList[result.index()], 369 resultSizesList[result.index()]))) { 370 return rewriter.notifyMatchFailure( 371 op, "failed to get slice of result produced"); 372 } 373 } 374 375 SmallVector<Value> destinationTensors; 376 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 377 destinationTensors))) 378 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 379 380 tilingResult.replacements = yieldTiledValues( 381 rewriter, destinationTensors, tiledImplementation.value(), 382 resultOffsetsList, resultSizesList, tilingResult.loops); 383 384 LLVM_DEBUG({ 385 if (!tilingResult.loops.empty()) { 386 llvm::dbgs() << "After tiled implementation :\n"; 387 tilingResult.loops.front().dump(); 388 llvm::dbgs() << "\n"; 389 } 390 }); 391 return tilingResult; 392 } 393 394 FailureOr<scf::SCFReductionTilingResult> 395 mlir::scf::tileReductionUsingScf(RewriterBase &b, 396 PartialReductionOpInterface op, 397 ArrayRef<OpFoldResult> tileSizes) { 398 Location loc = op.getLoc(); 399 // Ops implementing PartialReductionOpInterface are expected to implement 400 // TilingInterface. 401 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 402 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 403 auto tileSizesVector = llvm::to_vector(tileSizes); 404 if (tileSizesVector.size() < iterationDomain.size()) { 405 auto zero = b.getIndexAttr(0); 406 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 407 zero); 408 } 409 if (op->getNumResults() != 1) 410 return b.notifyMatchFailure( 411 op, "don't support ops with multiple results for now"); 412 SmallVector<utils::IteratorType> iterators = 413 tilingInterfaceOp.getLoopIteratorTypes(); 414 415 SmallVector<int> reductionDims; 416 for (auto [idx, iteratorType] : 417 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 418 if (iteratorType == utils::IteratorType::reduction) 419 reductionDims.push_back(idx); 420 } 421 422 // 1. create the inital tensor value. 423 FailureOr<Operation *> identityTensor = 424 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 425 reductionDims); 426 if (failed(identityTensor)) 427 return b.notifyMatchFailure(op, 428 "cannot create a tensor of identity value."); 429 // 2. Create the nested loops. 430 SmallVector<OpFoldResult> offsets, sizes; 431 SmallVector<scf::ForOp> loops = generateTileLoopNest( 432 b, loc, iterationDomain, tileSizesVector, offsets, sizes); 433 434 // 3. Generate the tiled implementation within the inner most loop. 435 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 436 Operation *parallelOp = op.tileToPartialReduction( 437 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims); 438 439 SmallVector<OpFoldResult> resultSizesList; 440 for (size_t i = 0; i < offsets.size(); i++) 441 resultSizesList.push_back( 442 tensor::getMixedSize(b, loc, parallelOp->getResult(0), i)); 443 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 444 SmallVector<Value> replacements = yieldTiledValues( 445 b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, 446 resultSizesList, loops); 447 448 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 449 auto innerMostLoop = loops.back(); 450 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 451 assert(destinationTensors.size() == 452 innerMostLoop.getRegionIterArgs().size() && 453 "unexpected number of outputs"); 454 updateDestinationOperandsForTiledOp(b, destinationTensors, 455 innerMostLoop.getRegionIterArgs()); 456 457 // 4. Apply the merge reduction to combine all the partial values. 458 b.setInsertionPointAfter(*loops.begin()); 459 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); 460 b.replaceOp(op, mergeOp->getResults()); 461 462 SCFReductionTilingResult results; 463 results.initialOp = *identityTensor; 464 results.loops = std::move(loops); 465 results.parallelTiledOp = parallelOp; 466 results.mergeOp = mergeOp; 467 return results; 468 } 469 //===----------------------------------------------------------------------===// 470 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 471 //===----------------------------------------------------------------------===// 472 473 /// Return the untiled producer whose slice is used in a tiled consumer. The 474 /// method traverses the tile loop nest (`loops`) if needed, and returns the 475 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 476 /// indicates that this is a destination operand of the consumer. If there was 477 /// no loop traversal needed, the second value of the returned tuple is empty. 478 static std::tuple<OpResult, std::optional<OpOperand *>> 479 getUntiledProducerFromSliceSource(OpOperand *source, 480 ArrayRef<scf::ForOp> loops) { 481 std::optional<OpOperand *> destinationIterArg; 482 auto loopIt = loops.rbegin(); 483 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 484 scf::ForOp loop = *loopIt; 485 if (iterArg.getOwner()->getParentOp() != loop) 486 break; 487 source = &loop.getOpOperandForRegionIterArg(iterArg); 488 loopIt++; 489 } 490 if (loopIt == loops.rend()) 491 destinationIterArg = source; 492 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 493 } 494 495 /// Implementation of fusing producer of a single slice by computing the 496 /// slice of the producer in-place. 497 std::optional<scf::SCFFuseProducerOfSliceResult> 498 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 499 tensor::ExtractSliceOp candidateSliceOp, 500 MutableArrayRef<scf::ForOp> loops) { 501 // 1. Get the producer of the source (potentially walking through 502 // `iter_args` of nested `scf.for`) 503 auto [fusableProducer, destinationInitArg] = 504 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0], 505 loops); 506 if (!fusableProducer) 507 return std::nullopt; 508 509 // 2. Generate the tiled implementation of the producer of the source 510 OpBuilder::InsertionGuard g(rewriter); 511 rewriter.setInsertionPoint(candidateSliceOp); 512 FailureOr<TilingResult> tileAndFuseResult = 513 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 514 fusableProducer); 515 if (failed(tileAndFuseResult)) 516 return std::nullopt; 517 rewriter.replaceAllUsesWith(candidateSliceOp, 518 tileAndFuseResult->tiledValues[0]); 519 520 // 3. If the slice is for a destination operand, for example, 521 // 522 // ```mlir 523 // %0 = linalg.init 524 // %1 = linalg.fill .. outs(%0 : ) 525 // %2 = scf.for .. iter_args(%arg0 = %1) { 526 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 527 // %4 = tensor.extract_slice %arg1 [..] 528 // .. = linalg.matmul .. outs(%4 : ) 529 // } 530 // } 531 // ``` 532 // 533 // the IR is currently 534 // 535 // ``` 536 // %0 = linalg.init 537 // %1 = linalg.fill 538 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 539 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 540 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 541 // %5 = linalg.fill .. outs(%4 : ) 542 // .. = linalg.matmul .. outs(%5 : ) 543 // } 544 // } 545 // ``` 546 // 547 // The untiled `linalg.fill` is still used as the `init_value` since it 548 // was originally a destination operand of the untiled `linalg.matmul`. 549 // When fusing an operand that is a destination operand. 550 // - Update the iter_arg of the outer most loop to use the destination 551 // of the untiled producer. 552 // - Update the destination of the slice of the tiled producer generated 553 // to use the same basic block argument as the slice that was used to 554 // generate inplace the tiled implementation of the producer. 555 // With this the IR will be. 556 // 557 // ``` 558 // %0 = linalg.init 559 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 560 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 561 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 562 // %4 = linalg.fill .. outs(%3 : ) 563 // .. = linalg.matmul .. outs(%4 : ) 564 // } 565 // } 566 // ``` 567 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 568 // Update to use that when it does become available. 569 scf::ForOp outerMostLoop = loops.front(); 570 if (destinationInitArg && 571 (*destinationInitArg)->getOwner() == outerMostLoop) { 572 unsigned iterArgNumber = 573 outerMostLoop.getResultForOpOperand(**destinationInitArg) 574 .getResultNumber(); 575 int64_t resultNumber = fusableProducer.getResultNumber(); 576 if (auto dstOp = 577 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { 578 (*destinationInitArg) 579 ->set(dstOp.getTiedOpOperand(fusableProducer)->get()); 580 } 581 for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { 582 auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 583 if (!dstOp) 584 continue; 585 scf::ForOp innerMostLoop = loops.back(); 586 updateDestinationOperandsForTiledOp( 587 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 588 innerMostLoop.getRegionIterArgs()[iterArgNumber]); 589 } 590 } 591 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 592 tileAndFuseResult->tiledValues[0], 593 tileAndFuseResult->tiledOps}; 594 } 595 596 /// Reconstruct the fused producer from within the tiled-and-fused code. 597 void mlir::scf::yieldReplacementForFusedProducer( 598 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 599 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 600 MutableArrayRef<scf::ForOp> loops) { 601 auto [fusableProducer, fusedProducerValue, tileAndFusedOps] = 602 fusedProducerInfo; 603 SmallVector<Value> initValues; 604 FailureOr<Value> initValue = tensor::getOrCreateDestination( 605 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 606 if (succeeded(initValue)) { 607 SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); 608 SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); 609 SmallVector<Value> yieldedVals = 610 yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, 611 resultOffsets, resultSizes, loops); 612 } 613 for (auto tileAndFusedOp : tileAndFusedOps) { 614 auto dstStyleProducer = 615 dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 616 if (!dstStyleProducer) 617 continue; 618 Value dstValue = 619 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) 620 ->get(); 621 updateDestinationOperandsForTiledOp( 622 rewriter, dstValue, loops.back().getRegionIterArgs().back()); 623 } 624 } 625 626 /// Implementation of tile consumer and fuse producer greedily. 627 FailureOr<scf::SCFTileAndFuseResult> 628 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 629 RewriterBase &rewriter, TilingInterface consumer, 630 const scf::SCFTileAndFuseOptions &options) { 631 // This transformation is only valid for ops that return values (i.e. not 632 // valid to use with operations that have memref operands). 633 if (!consumer->getNumResults()) { 634 return rewriter.notifyMatchFailure( 635 consumer, "invalid pattern for op with no results"); 636 } 637 638 // 1. First tile the consumer. 639 scf::SCFTileAndFuseResult tileAndFuseResult; 640 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 641 { 642 FailureOr<scf::SCFTilingResult> tilingResult = 643 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 644 if (failed(tilingResult)) 645 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 646 for (auto *tiledOp : tilingResult->tiledOps) 647 tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); 648 tileAndFuseResult.loops = std::move(tilingResult->loops); 649 for (const auto &result : llvm::enumerate( 650 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 651 tileAndFuseResult.replacements[std::get<0>(result.value())] = 652 std::get<1>(result.value()); 653 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( 654 result.index())] = result.index(); 655 } 656 } 657 658 // If there are no loops generated, fusion is immaterial. 659 if (tileAndFuseResult.loops.empty()) 660 return tileAndFuseResult; 661 662 // 2. Typically, the operands of the tiled operation are slices of the 663 // operands of the untiled operation. These are expressed in IR using 664 // `tensor.extract_slice` operations with source being the operands of the 665 // untiled operation. Create a worklist of these `tensor.extract_slice` 666 // operations. If the producers of the source of the `tensor.extract_slice` 667 // can be tiled such that the tiled value is generated in-place, that 668 // effectively tiles + fuses the operations. 669 auto addCandidateSlices = [](Operation *fusedOp, 670 std::deque<tensor::ExtractSliceOp> &candidates) { 671 for (Value operand : fusedOp->getOperands()) 672 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 673 candidates.push_back(sliceOp); 674 }; 675 676 std::deque<tensor::ExtractSliceOp> candidates; 677 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 678 OpBuilder::InsertionGuard g(rewriter); 679 while (!candidates.empty()) { 680 // Traverse the slices in BFS fashion. 681 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 682 candidates.pop_front(); 683 684 // The operands of the fused producer might themselved be slices of 685 // values produced by operations that implement the `TilingInterface`. 686 // Add these operations to the worklist. 687 std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = 688 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, 689 tileAndFuseResult.loops); 690 if (!fusedProducer) 691 continue; 692 693 if (Operation *tiledAndFusedOp = 694 fusedProducer->tiledAndFusedProducer.getDefiningOp()) { 695 tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); 696 addCandidateSlices(tiledAndFusedOp, candidates); 697 } 698 } 699 return tileAndFuseResult; 700 } 701 702 //===----------------------------------------------------------------------===// 703 // lowerToLoopsUsingSCFForOp implementation. 704 //===----------------------------------------------------------------------===// 705 706 FailureOr<SmallVector<scf::ForOp>> 707 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 708 TilingInterface op) { 709 // TODO: Handle cases where the op has results if needed. 710 if (op->getNumResults() > 0) { 711 return rewriter.notifyMatchFailure( 712 op, "unable to lower to loops operations with return values"); 713 } 714 715 SmallVector<Range> domain = op.getIterationDomain(rewriter); 716 SmallVector<Value> ivs; 717 SmallVector<scf::ForOp> loops; 718 Location loc = op.getLoc(); 719 for (auto loopRange : domain) { 720 Value offsetVal = 721 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 722 Value sizeVal = 723 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 724 Value strideVal = 725 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 726 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 727 strideVal, ValueRange{}); 728 loops.push_back(loop); 729 ivs.push_back(loop.getInductionVar()); 730 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 731 } 732 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 733 return failure(); 734 } 735 return loops; 736 } 737