1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "llvm/Support/Debug.h" 27 #include <optional> 28 29 #define DEBUG_TYPE "tile-using-interface" 30 31 using namespace mlir; 32 33 scf::SCFTilingOptions & 34 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 35 assert(!tileSizeComputationFunction && "tile sizes already set"); 36 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 37 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 38 OpBuilder::InsertionGuard guard(b); 39 b.setInsertionPointToStart( 40 &op->getParentOfType<func::FuncOp>().getBody().front()); 41 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 42 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 43 return v; 44 })); 45 }; 46 return *this; 47 } 48 49 /// Helper method to adjust the interchange vector to match the iteration 50 /// domain. 51 static SmallVector<int64_t> 52 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 53 size_t iterationDomainSize) { 54 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 55 if (filledVector.size() < iterationDomainSize) { 56 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 57 filledVector.append(range.begin(), range.end()); 58 } 59 if (filledVector.size() > iterationDomainSize) 60 filledVector.resize(iterationDomainSize); 61 return filledVector; 62 } 63 64 //===----------------------------------------------------------------------===// 65 // tileUsingSCFForOp implementation. 66 //===----------------------------------------------------------------------===// 67 68 // Check if `stride` evenly divides the trip count `size - offset`. 69 static bool tileDividesIterationDomain(Range loopRange) { 70 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 71 if (!offsetAsInt) 72 return false; 73 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 74 if (!sizeAsInt) 75 return false; 76 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 77 if (!strideAsInt) 78 return false; 79 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 80 } 81 82 /// Returns the bounded tile size given the current `iv`, `loopRange` and 83 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 84 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 85 Range loopRange, Value iv, 86 Value tileSize) { 87 std::optional<int64_t> ts = getConstantIntValue(tileSize); 88 if (ts && ts.value() == 1) 89 return getAsOpFoldResult(tileSize); 90 91 if (tileDividesIterationDomain( 92 Range{loopRange.offset, loopRange.size, tileSize})) 93 return tileSize; 94 95 // The tile size to use (to avoid out of bounds access) is minimum of 96 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 97 // loop. 98 AffineExpr s0, s1, d0; 99 bindDims(b.getContext(), d0); 100 bindSymbols(b.getContext(), s0, s1); 101 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 102 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 103 return makeComposedFoldedAffineMin( 104 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 105 } 106 107 /// Generate an empty loop nest that represents the tiled loop nest shell. 108 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 109 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 110 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 111 /// the 112 /// tile processed within the inner most loop. 113 static SmallVector<scf::ForOp> 114 generateTileLoopNest(OpBuilder &builder, Location loc, 115 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 116 SmallVector<OpFoldResult> &offsets, 117 SmallVector<OpFoldResult> &sizes) { 118 assert(!loopRanges.empty() && "expected at least one loop range"); 119 assert(loopRanges.size() == tileSizeVals.size() && 120 "expected as many tile sizes as loop ranges"); 121 OpBuilder::InsertionGuard guard(builder); 122 SmallVector<scf::ForOp> loops; 123 offsets.resize(loopRanges.size()); 124 sizes.resize(loopRanges.size()); 125 126 for (auto loopRange : llvm::enumerate(loopRanges)) { 127 Value offset = 128 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 129 Value size = 130 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 131 Value tileSize = tileSizeVals[loopRange.index()]; 132 // No loops if tile size is zero. Set offset and size to the loop 133 // offset and size. 134 if (matchPattern(tileSize, m_Zero())) { 135 offsets[loopRange.index()] = offset; 136 sizes[loopRange.index()] = size; 137 continue; 138 } 139 140 auto loop = builder.create<scf::ForOp>( 141 loc, offset, size, tileSize, ValueRange{}, 142 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 143 ValueRange /*iterArgs*/) { 144 sizes[loopRange.index()] = getBoundedTileSize( 145 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 146 builder.create<scf::YieldOp>(loc); 147 }); 148 offsets[loopRange.index()] = loop.getInductionVar(); 149 loops.push_back(loop); 150 builder.setInsertionPoint(loop.getBody()->getTerminator()); 151 } 152 return loops; 153 } 154 155 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 156 /// construct the destructive update pattern that inserts the yielded 157 /// value into a destination tensor provided by `initValue` at offset 158 /// `tileOffsets` and size `tileSizes`. For example, 159 /// 160 /// ```mlir 161 /// scf.for %iv0 = ... { 162 /// %0 = tiled_op 163 /// } 164 /// ``` 165 /// 166 /// is transformed to 167 /// 168 /// ```mlir 169 /// scf.for %iv0 = ... iter_args(%arg = %0) { 170 /// %1 = tensor.extract_slice %arg 171 /// %2 = tiled_op 172 /// %3 = tensor.insert_slice %2 into %arg 173 /// scf.yield %3 174 /// } 175 /// ``` 176 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 177 static SmallVector<Value> 178 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 179 ValueRange yieldedValues, 180 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 181 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 182 MutableArrayRef<scf::ForOp> loops) { 183 NewYieldValueFn yieldValueFn = 184 [&](OpBuilder &b, Location loc, 185 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 186 SmallVector<Value> inserts; 187 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 188 ArrayRef<OpFoldResult> tileOffsets = 189 tileOffsetsList[yieldedValue.index()]; 190 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 191 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 192 b.getIndexAttr(1)); 193 Value insert = b.create<tensor::InsertSliceOp>( 194 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 195 tileOffsets, tileSizes, tileStrides); 196 inserts.push_back(insert); 197 } 198 return inserts; 199 }; 200 201 SmallVector<scf::ForOp> newLoops = 202 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 203 /*replaceIterOperandsUsesInLoop =*/false); 204 for (const auto &loop : llvm::enumerate(loops)) { 205 rewriter.eraseOp(loop.value()); 206 loops[loop.index()] = newLoops[loop.index()]; 207 } 208 return llvm::to_vector(llvm::map_range( 209 loops.front().getResults().take_back(yieldedValues.size()), 210 [](OpResult r) -> Value { return r; })); 211 } 212 213 /// If the tiled operation is destination passing style, update the 214 /// slice of the destination used (which refers to the untiled destination) 215 /// to use the corresponding region argument of the innermost loop. 216 /// 217 /// ```mlir 218 /// %0 = 219 /// scf.for %iv0 = ... iter_args(%arg = %0) { 220 /// %1 = tensor.extract_slice %0 221 /// %2 = tiled_op 222 /// %3 = tensor.insert_slice %2 into %arg 223 /// scf.yield %3 224 /// } 225 /// ``` 226 /// 227 /// is transformed to 228 /// 229 /// ```mlir 230 /// scf.for %iv0 = ... iter_args(%arg = %0) { 231 /// %1 = tensor.extract_slice %arg 232 /// %2 = tiled_op 233 /// %3 = tensor.insert_slice %2 into %arg 234 /// scf.yield %3 235 /// } 236 /// ``` 237 static void 238 updateDestinationOperandsForTiledOp(OpBuilder &builder, 239 ValueRange tiledOpDestinationValues, 240 ValueRange bbArgsList) { 241 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 242 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 243 if (!sliceOp) 244 continue; 245 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 246 } 247 } 248 249 /// Helper method to yield the values of the tiled op, as well as 250 /// update the destination operands of the tiled op, if it is 251 /// a destination passing style op. 252 static SmallVector<Value> 253 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, 254 TilingResult tilingResult, 255 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 256 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 257 MutableArrayRef<scf::ForOp> loops) { 258 SmallVector<Value> replacements = 259 yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, 260 tileOffsetsList, tileSizesList, loops); 261 for (auto tiledOp : tilingResult.tiledOps) { 262 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { 263 auto innerMostLoop = loops.back(); 264 SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); 265 updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, 266 innerMostLoop.getRegionIterArgs()); 267 } 268 } 269 return replacements; 270 } 271 272 /// Implementation of tiling transformation of `op` that implements the 273 /// `TilingInterface` using `scf.for` to iterate over the tiles. 274 FailureOr<scf::SCFTilingResult> 275 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 276 const scf::SCFTilingOptions &options) { 277 OpBuilder::InsertionGuard guard(rewriter); 278 rewriter.setInsertionPointAfter(op); 279 280 if (!options.tileSizeComputationFunction) { 281 return rewriter.notifyMatchFailure( 282 op, "missing tile size computation function"); 283 } 284 285 // 1. Get the range of the loops that are represented by the operation. 286 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 287 size_t numLoops = iterationDomain.size(); 288 if (numLoops == 0) { 289 return rewriter.notifyMatchFailure( 290 op, "unable to tile op with no iteration domain"); 291 } 292 293 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 294 // skips tiling a particular dimension. This convention is significantly 295 // simpler to handle instead of adjusting affine maps to account for missing 296 // dimensions. 297 SmallVector<Value> tileSizeVector = 298 options.tileSizeComputationFunction(rewriter, op); 299 if (tileSizeVector.size() < iterationDomain.size()) { 300 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 301 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 302 } 303 304 scf::SCFTilingResult tilingResult; 305 SmallVector<OpFoldResult> offsets, sizes; 306 { 307 // If there is an interchange specified, permute the iteration domain and 308 // the tile sizes. 309 SmallVector<int64_t> interchangeVector; 310 if (!options.interchangeVector.empty()) { 311 interchangeVector = fillInterchangeVector(options.interchangeVector, 312 iterationDomain.size()); 313 } 314 if (!interchangeVector.empty()) { 315 if (!isPermutationVector(interchangeVector)) { 316 return rewriter.notifyMatchFailure( 317 op, "invalid intechange vector, not a permutation of the entire " 318 "iteration space"); 319 } 320 321 applyPermutationToVector(iterationDomain, interchangeVector); 322 applyPermutationToVector(tileSizeVector, interchangeVector); 323 } 324 325 // 3. Materialize an empty loop nest that iterates over the tiles. These 326 // loops for now do not return any values even if the original operation has 327 // results. 328 tilingResult.loops = generateTileLoopNest( 329 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 330 331 if (!interchangeVector.empty()) { 332 auto inversePermutation = invertPermutationVector(interchangeVector); 333 applyPermutationToVector(offsets, inversePermutation); 334 applyPermutationToVector(sizes, inversePermutation); 335 } 336 } 337 338 LLVM_DEBUG({ 339 if (!tilingResult.loops.empty()) { 340 llvm::dbgs() << "LoopNest shell :\n"; 341 tilingResult.loops.front().dump(); 342 llvm::dbgs() << "\n"; 343 } 344 }); 345 346 // 4. Generate the tiled implementation within the inner most loop. 347 if (!tilingResult.loops.empty()) 348 rewriter.setInsertionPoint( 349 tilingResult.loops.back().getBody()->getTerminator()); 350 FailureOr<TilingResult> tiledImplementation = 351 op.getTiledImplementation(rewriter, offsets, sizes); 352 tilingResult.tiledOps.append(tiledImplementation->tiledOps); 353 if (op->getNumResults() == 0) { 354 // nothing more to do. 355 return tilingResult; 356 } 357 358 // If loops are empty, the tiled op is used as the replacement for the untiled 359 // op. 360 if (tilingResult.loops.empty()) { 361 tilingResult.replacements = tiledImplementation->tiledValues; 362 return tilingResult; 363 } 364 365 // 5. Yield all the results of the tiled operation. The surrounding loop 366 // nest is modified to insert a destructive update pattern to yield 367 // from the loop nest values to replace the untiled op with. 368 int64_t numResults = op->getNumResults(); 369 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 370 resultSizesList(numResults); 371 for (const auto &result : llvm::enumerate(op->getResults())) { 372 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 373 sizes, 374 resultOffsetsList[result.index()], 375 resultSizesList[result.index()]))) { 376 return rewriter.notifyMatchFailure( 377 op, "failed to get slice of result produced"); 378 } 379 } 380 381 SmallVector<Value> destinationTensors; 382 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 383 destinationTensors))) 384 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 385 386 tilingResult.replacements = yieldTiledValues( 387 rewriter, destinationTensors, tiledImplementation.value(), 388 resultOffsetsList, resultSizesList, tilingResult.loops); 389 390 LLVM_DEBUG({ 391 if (!tilingResult.loops.empty()) { 392 llvm::dbgs() << "After tiled implementation :\n"; 393 tilingResult.loops.front().dump(); 394 llvm::dbgs() << "\n"; 395 } 396 }); 397 return tilingResult; 398 } 399 400 FailureOr<scf::SCFReductionTilingResult> 401 mlir::scf::tileReductionUsingScf(RewriterBase &b, 402 PartialReductionOpInterface op, 403 ArrayRef<OpFoldResult> tileSize) { 404 Location loc = op.getLoc(); 405 // Ops implementing PartialReductionOpInterface are expected to implement 406 // TilingInterface. 407 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 408 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 409 SmallVector<Value> tileSizeVector = 410 getValueOrCreateConstantIndexOp(b, loc, tileSize); 411 if (tileSizeVector.size() < iterationDomain.size()) { 412 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 413 tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); 414 } 415 if (op->getNumResults() != 1) 416 return b.notifyMatchFailure( 417 op, "don't support ops with multiple results for now"); 418 SmallVector<utils::IteratorType> iterators = 419 tilingInterfaceOp.getLoopIteratorTypes(); 420 int64_t numReductionDims = llvm::count( 421 tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction); 422 if (numReductionDims != 1) 423 return b.notifyMatchFailure( 424 op, "only support ops with one reduction dimension."); 425 int reductionDim; 426 for (auto [idx, iteratorType] : 427 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 428 if (iteratorType == utils::IteratorType::reduction) { 429 reductionDim = idx; 430 break; 431 } 432 } 433 if (static_cast<size_t>(reductionDim) >= tileSize.size()) 434 return b.notifyMatchFailure(op, "reduction dimension must be tiled"); 435 436 // 1. create the inital tensor value. 437 FailureOr<Operation *> identityTensor = 438 op.generateInitialTensorForPartialReduction(b, loc, tileSize, 439 reductionDim); 440 if (failed(identityTensor)) 441 return b.notifyMatchFailure(op, 442 "cannot create a tensor of identity value."); 443 // 2. Create the nested loops. 444 SmallVector<OpFoldResult> offsets, sizes; 445 SmallVector<scf::ForOp> loops = generateTileLoopNest( 446 b, loc, iterationDomain, tileSizeVector, offsets, sizes); 447 448 // 3. Generate the tiled implementation within the inner most loop. 449 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 450 Operation *parallelOp = op.tileToPartialReduction( 451 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim); 452 453 SmallVector<OpFoldResult> resultSizesList; 454 for (size_t i = 0; i < offsets.size(); i++) 455 resultSizesList.push_back( 456 b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i)); 457 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 458 SmallVector<Value> replacements = yieldTiledValues( 459 b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, 460 resultSizesList, loops); 461 462 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 463 auto innerMostLoop = loops.back(); 464 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 465 assert(destinationTensors.size() == 466 innerMostLoop.getRegionIterArgs().size() && 467 "unexpected number of outputs"); 468 updateDestinationOperandsForTiledOp(b, destinationTensors, 469 innerMostLoop.getRegionIterArgs()); 470 471 // 4. Apply the merge reduction to combine all the partial values. 472 b.setInsertionPointAfter(*loops.begin()); 473 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); 474 b.replaceOp(op, mergeOp->getResults()); 475 476 SCFReductionTilingResult results; 477 results.initialOp = *identityTensor; 478 results.loops = std::move(loops); 479 results.parallelTiledOp = parallelOp; 480 results.mergeOp = mergeOp; 481 return results; 482 } 483 //===----------------------------------------------------------------------===// 484 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 485 //===----------------------------------------------------------------------===// 486 487 /// Return the untiled producer whose slice is used in a tiled consumer. The 488 /// method traverses the tile loop nest (`loops`) if needed, and returns the 489 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 490 /// indicates that this is a destination operand of the consumer. If there was 491 /// no loop traversal needed, the second value of the returned tuple is empty. 492 static std::tuple<OpResult, std::optional<OpOperand *>> 493 getUntiledProducerFromSliceSource(OpOperand *source, 494 ArrayRef<scf::ForOp> loops) { 495 std::optional<OpOperand *> destinationIterArg; 496 auto loopIt = loops.rbegin(); 497 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 498 scf::ForOp loop = *loopIt; 499 if (iterArg.getOwner()->getParentOp() != loop) 500 break; 501 source = &loop.getOpOperandForRegionIterArg(iterArg); 502 loopIt++; 503 } 504 if (loopIt == loops.rend()) 505 destinationIterArg = source; 506 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 507 } 508 509 /// Implementation of fusing producer of a single slice by computing the 510 /// slice of the producer in-place. 511 std::optional<scf::SCFFuseProducerOfSliceResult> 512 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 513 tensor::ExtractSliceOp candidateSliceOp, 514 MutableArrayRef<scf::ForOp> loops) { 515 // 1. Get the producer of the source (potentially walking through 516 // `iter_args` of nested `scf.for`) 517 auto [fusableProducer, destinationIterArg] = 518 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 519 loops); 520 if (!fusableProducer) 521 return std::nullopt; 522 523 // 2. Generate the tiled implementation of the producer of the source 524 OpBuilder::InsertionGuard g(rewriter); 525 rewriter.setInsertionPoint(candidateSliceOp); 526 FailureOr<TilingResult> tileAndFuseResult = 527 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 528 fusableProducer); 529 if (failed(tileAndFuseResult)) 530 return std::nullopt; 531 rewriter.replaceAllUsesWith(candidateSliceOp, 532 tileAndFuseResult->tiledValues[0]); 533 534 // 3. If the slice is for a destination operand, for example, 535 // 536 // ```mlir 537 // %0 = linalg.init 538 // %1 = linalg.fill .. outs(%0 : ) 539 // %2 = scf.for .. iter_args(%arg0 = %1) { 540 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 541 // %4 = tensor.extract_slice %arg1 [..] 542 // .. = linalg.matmul .. outs(%4 : ) 543 // } 544 // } 545 // ``` 546 // 547 // the IR is currently 548 // 549 // ``` 550 // %0 = linalg.init 551 // %1 = linalg.fill 552 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 553 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 554 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 555 // %5 = linalg.fill .. outs(%4 : ) 556 // .. = linalg.matmul .. outs(%5 : ) 557 // } 558 // } 559 // ``` 560 // 561 // The untiled `linalg.fill` is still used as the `init_value` since it 562 // was originally a destination operand of the untiled `linalg.matmul`. 563 // When fusing an operand that is a destination operand. 564 // - Update the iter_arg of the outer most loop to use the destination 565 // of the untiled producer. 566 // - Update the destination of the slice of the tiled producer generated 567 // to use the same basic block argument as the slice that was used to 568 // generate inplace the tiled implementation of the producer. 569 // With this the IR will be. 570 // 571 // ``` 572 // %0 = linalg.init 573 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 574 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 575 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 576 // %4 = linalg.fill .. outs(%3 : ) 577 // .. = linalg.matmul .. outs(%4 : ) 578 // } 579 // } 580 // ``` 581 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 582 // Update to use that when it does become available. 583 scf::ForOp outerMostLoop = loops.front(); 584 std::optional<unsigned> iterArgNumber; 585 if (destinationIterArg) { 586 iterArgNumber = 587 outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value()); 588 } 589 if (iterArgNumber) { 590 int64_t resultNumber = fusableProducer.getResultNumber(); 591 if (auto dstOp = 592 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { 593 outerMostLoop.setIterArg(iterArgNumber.value(), 594 dstOp.getTiedOpOperand(fusableProducer)->get()); 595 } 596 for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { 597 auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 598 if (!dstOp) 599 continue; 600 scf::ForOp innerMostLoop = loops.back(); 601 updateDestinationOperandsForTiledOp( 602 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 603 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 604 } 605 } 606 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 607 tileAndFuseResult->tiledValues[0], 608 tileAndFuseResult->tiledOps}; 609 } 610 611 /// Reconstruct the fused producer from within the tiled-and-fused code. 612 void mlir::scf::yieldReplacementForFusedProducer( 613 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 614 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 615 MutableArrayRef<scf::ForOp> loops) { 616 auto [fusableProducer, fusedProducerValue, tileAndFusedOps] = 617 fusedProducerInfo; 618 SmallVector<Value> initValues; 619 FailureOr<Value> initValue = tensor::getOrCreateDestination( 620 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 621 if (succeeded(initValue)) { 622 SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); 623 SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); 624 SmallVector<Value> yieldedVals = 625 yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, 626 resultOffsets, resultSizes, loops); 627 } 628 for (auto tileAndFusedOp : tileAndFusedOps) { 629 auto dstStyleProducer = 630 dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 631 if (!dstStyleProducer) 632 continue; 633 Value dstValue = 634 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) 635 ->get(); 636 updateDestinationOperandsForTiledOp( 637 rewriter, dstValue, loops.back().getRegionIterArgs().back()); 638 } 639 } 640 641 /// Implementation of tile consumer and fuse producer greedily. 642 FailureOr<scf::SCFTileAndFuseResult> 643 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 644 RewriterBase &rewriter, TilingInterface consumer, 645 const scf::SCFTileAndFuseOptions &options) { 646 // This transformation is only valid for ops that return values (i.e. not 647 // valid to use with operations that have memref operands). 648 if (!consumer->getNumResults()) { 649 return rewriter.notifyMatchFailure( 650 consumer, "invalid pattern for op with no results"); 651 } 652 653 // 1. First tile the consumer. 654 scf::SCFTileAndFuseResult tileAndFuseResult; 655 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 656 { 657 FailureOr<scf::SCFTilingResult> tilingResult = 658 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 659 if (failed(tilingResult)) 660 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 661 for (auto *tiledOp : tilingResult->tiledOps) 662 tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); 663 tileAndFuseResult.loops = std::move(tilingResult->loops); 664 for (const auto &result : llvm::enumerate( 665 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 666 tileAndFuseResult.replacements[std::get<0>(result.value())] = 667 std::get<1>(result.value()); 668 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( 669 result.index())] = result.index(); 670 } 671 } 672 673 // If there are no loops generated, fusion is immaterial. 674 if (tileAndFuseResult.loops.empty()) 675 return tileAndFuseResult; 676 677 // 2. Typically, the operands of the tiled operation are slices of the 678 // operands of the untiled operation. These are expressed in IR using 679 // `tensor.extract_slice` operations with source being the operands of the 680 // untiled operation. Create a worklist of these `tensor.extract_slice` 681 // operations. If the producers of the source of the `tensor.extract_slice` 682 // can be tiled such that the tiled value is generated in-place, that 683 // effectively tiles + fuses the operations. 684 auto addCandidateSlices = [](Operation *fusedOp, 685 std::deque<tensor::ExtractSliceOp> &candidates) { 686 for (Value operand : fusedOp->getOperands()) 687 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 688 candidates.push_back(sliceOp); 689 }; 690 691 std::deque<tensor::ExtractSliceOp> candidates; 692 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 693 OpBuilder::InsertionGuard g(rewriter); 694 while (!candidates.empty()) { 695 // Traverse the slices in BFS fashion. 696 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 697 candidates.pop_front(); 698 699 // The operands of the fused producer might themselved be slices of 700 // values produced by operations that implement the `TilingInterface`. 701 // Add these operations to the worklist. 702 std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = 703 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, 704 tileAndFuseResult.loops); 705 if (!fusedProducer) 706 continue; 707 708 if (Operation *tiledAndFusedOp = 709 fusedProducer->tiledAndFusedProducer.getDefiningOp()) { 710 tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); 711 addCandidateSlices(tiledAndFusedOp, candidates); 712 } 713 } 714 return tileAndFuseResult; 715 } 716 717 //===----------------------------------------------------------------------===// 718 // lowerToLoopsUsingSCFForOp implementation. 719 //===----------------------------------------------------------------------===// 720 721 FailureOr<SmallVector<scf::ForOp>> 722 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 723 TilingInterface op) { 724 // TODO: Handle cases where the op has results if needed. 725 if (op->getNumResults() > 0) { 726 return rewriter.notifyMatchFailure( 727 op, "unable to lower to loops operations with return values"); 728 } 729 730 SmallVector<Range> domain = op.getIterationDomain(rewriter); 731 SmallVector<Value> ivs; 732 SmallVector<scf::ForOp> loops; 733 Location loc = op.getLoc(); 734 for (auto loopRange : domain) { 735 Value offsetVal = 736 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 737 Value sizeVal = 738 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 739 Value strideVal = 740 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 741 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 742 strideVal, ValueRange{}); 743 loops.push_back(loop); 744 ivs.push_back(loop.getInductionVar()); 745 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 746 } 747 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 748 return failure(); 749 } 750 return loops; 751 } 752