1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "llvm/Support/Debug.h" 27 28 #define DEBUG_TYPE "tile-using-interface" 29 30 using namespace mlir; 31 32 scf::SCFTilingOptions & 33 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 34 assert(!tileSizeComputationFunction && "tile sizes already set"); 35 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 36 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 37 OpBuilder::InsertionGuard guard(b); 38 b.setInsertionPointToStart( 39 &op->getParentOfType<func::FuncOp>().getBody().front()); 40 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 41 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 42 return v; 43 })); 44 }; 45 return *this; 46 } 47 48 /// Helper method to adjust the interchange vector to match the iteration 49 /// domain. 50 static SmallVector<int64_t> 51 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 52 size_t iterationDomainSize) { 53 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 54 if (filledVector.size() < iterationDomainSize) { 55 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 56 filledVector.append(range.begin(), range.end()); 57 } 58 if (filledVector.size() > iterationDomainSize) 59 filledVector.resize(iterationDomainSize); 60 return filledVector; 61 } 62 63 //===----------------------------------------------------------------------===// 64 // tileUsingSCFForOp implementation. 65 //===----------------------------------------------------------------------===// 66 67 // Check if `stride` evenly divides the trip count `size - offset`. 68 static bool tileDividesIterationDomain(Range loopRange) { 69 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 70 if (!offsetAsInt) 71 return false; 72 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 73 if (!sizeAsInt) 74 return false; 75 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 76 if (!strideAsInt) 77 return false; 78 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 79 } 80 81 /// Returns the bounded tile size given the current `iv`, `loopRange` and 82 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 83 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 84 Range loopRange, Value iv, 85 Value tileSize) { 86 std::optional<int64_t> ts = getConstantIntValue(tileSize); 87 if (ts && ts.value() == 1) 88 return getAsOpFoldResult(tileSize); 89 90 if (tileDividesIterationDomain( 91 Range{loopRange.offset, loopRange.size, tileSize})) 92 return tileSize; 93 94 // The tile size to use (to avoid out of bounds access) is minimum of 95 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 96 // loop. 97 AffineExpr s0, s1, d0; 98 bindDims(b.getContext(), d0); 99 bindSymbols(b.getContext(), s0, s1); 100 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 101 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 102 return makeComposedFoldedAffineMin( 103 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 104 } 105 106 /// Generate an empty loop nest that represents the tiled loop nest shell. 107 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 108 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 109 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 110 /// the 111 /// tile processed within the inner most loop. 112 static SmallVector<scf::ForOp> 113 generateTileLoopNest(OpBuilder &builder, Location loc, 114 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 115 SmallVector<OpFoldResult> &offsets, 116 SmallVector<OpFoldResult> &sizes) { 117 assert(!loopRanges.empty() && "expected at least one loop range"); 118 assert(loopRanges.size() == tileSizeVals.size() && 119 "expected as many tile sizes as loop ranges"); 120 OpBuilder::InsertionGuard guard(builder); 121 SmallVector<scf::ForOp> loops; 122 offsets.resize(loopRanges.size()); 123 sizes.resize(loopRanges.size()); 124 125 for (auto loopRange : llvm::enumerate(loopRanges)) { 126 Value offset = 127 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 128 Value size = 129 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 130 Value tileSize = tileSizeVals[loopRange.index()]; 131 // No loops if tile size is zero. Set offset and size to the loop 132 // offset and size. 133 if (matchPattern(tileSize, m_Zero())) { 134 offsets[loopRange.index()] = offset; 135 sizes[loopRange.index()] = size; 136 continue; 137 } 138 139 auto loop = builder.create<scf::ForOp>( 140 loc, offset, size, tileSize, ValueRange{}, 141 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 142 ValueRange /*iterArgs*/) { 143 sizes[loopRange.index()] = getBoundedTileSize( 144 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 145 builder.create<scf::YieldOp>(loc); 146 }); 147 offsets[loopRange.index()] = loop.getInductionVar(); 148 loops.push_back(loop); 149 builder.setInsertionPoint(loop.getBody()->getTerminator()); 150 } 151 return loops; 152 } 153 154 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 155 /// construct the destructive update pattern that inserts the yielded 156 /// value into a destination tensor provided by `initValue` at offset 157 /// `tileOffsets` and size `tileSizes`. For example, 158 /// 159 /// ```mlir 160 /// scf.for %iv0 = ... { 161 /// %0 = tiled_op 162 /// } 163 /// ``` 164 /// 165 /// is transformed to 166 /// 167 /// ```mlir 168 /// scf.for %iv0 = ... iter_args(%arg = %0) { 169 /// %1 = tensor.extract_slice %arg 170 /// %2 = tiled_op 171 /// %3 = tensor.insert_slice %2 into %arg 172 /// scf.yield %3 173 /// } 174 /// ``` 175 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 176 static SmallVector<Value> 177 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 178 ValueRange yieldedValues, 179 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 180 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 181 MutableArrayRef<scf::ForOp> loops) { 182 NewYieldValueFn yieldValueFn = 183 [&](OpBuilder &b, Location loc, 184 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 185 SmallVector<Value> inserts; 186 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 187 ArrayRef<OpFoldResult> tileOffsets = 188 tileOffsetsList[yieldedValue.index()]; 189 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 190 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 191 b.getIndexAttr(1)); 192 Value insert = b.create<tensor::InsertSliceOp>( 193 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 194 tileOffsets, tileSizes, tileStrides); 195 inserts.push_back(insert); 196 } 197 return inserts; 198 }; 199 200 SmallVector<scf::ForOp> newLoops = 201 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 202 /*replaceIterOperandsUsesInLoop =*/false); 203 for (const auto &loop : llvm::enumerate(loops)) { 204 rewriter.eraseOp(loop.value()); 205 loops[loop.index()] = newLoops[loop.index()]; 206 } 207 return llvm::to_vector(llvm::map_range( 208 loops.front().getResults().take_back(yieldedValues.size()), 209 [](OpResult r) -> Value { return r; })); 210 } 211 212 /// If the tiled operation is destination passing style, update the 213 /// slice of the destination used (which refers to the untiled destination) 214 /// to use the corresponding region argument of the innermost loop. 215 /// 216 /// ```mlir 217 /// %0 = 218 /// scf.for %iv0 = ... iter_args(%arg = %0) { 219 /// %1 = tensor.extract_slice %0 220 /// %2 = tiled_op 221 /// %3 = tensor.insert_slice %2 into %arg 222 /// scf.yield %3 223 /// } 224 /// ``` 225 /// 226 /// is transformed to 227 /// 228 /// ```mlir 229 /// scf.for %iv0 = ... iter_args(%arg = %0) { 230 /// %1 = tensor.extract_slice %arg 231 /// %2 = tiled_op 232 /// %3 = tensor.insert_slice %2 into %arg 233 /// scf.yield %3 234 /// } 235 /// ``` 236 static void 237 updateDestinationOperandsForTiledOp(OpBuilder &builder, 238 ValueRange tiledOpDestinationValues, 239 ValueRange bbArgsList) { 240 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 241 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 242 if (!sliceOp) 243 continue; 244 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 245 } 246 } 247 248 /// Helper method to yield the values of the tiled op, as well as 249 /// update the destination operands of the tiled op, if it is 250 /// a destination passing style op. 251 static SmallVector<Value> 252 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, 253 Operation *tiledOp, 254 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 255 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 256 MutableArrayRef<scf::ForOp> loops) { 257 SmallVector<Value> replacements = 258 yieldTiledValues(rewriter, initValues, tiledOp->getResults(), 259 tileOffsetsList, tileSizesList, loops); 260 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { 261 auto innerMostLoop = loops.back(); 262 SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); 263 updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, 264 innerMostLoop.getRegionIterArgs()); 265 } 266 return replacements; 267 } 268 269 /// Implementation of tiling transformation of `op` that implements the 270 /// `TilingInterface` using `scf.for` to iterate over the tiles. 271 FailureOr<scf::SCFTilingResult> 272 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 273 const scf::SCFTilingOptions &options) { 274 OpBuilder::InsertionGuard guard(rewriter); 275 rewriter.setInsertionPointAfter(op); 276 277 if (!options.tileSizeComputationFunction) { 278 return rewriter.notifyMatchFailure( 279 op, "missing tile size computation function"); 280 } 281 282 // 1. Get the range of the loops that are represented by the operation. 283 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 284 size_t numLoops = iterationDomain.size(); 285 if (numLoops == 0) { 286 return rewriter.notifyMatchFailure( 287 op, "unable to tile op with no iteration domain"); 288 } 289 290 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 291 // skips tiling a particular dimension. This convention is significantly 292 // simpler to handle instead of adjusting affine maps to account for missing 293 // dimensions. 294 SmallVector<Value> tileSizeVector = 295 options.tileSizeComputationFunction(rewriter, op); 296 if (tileSizeVector.size() < iterationDomain.size()) { 297 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 298 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 299 } 300 301 scf::SCFTilingResult tilingResult; 302 SmallVector<OpFoldResult> offsets, sizes; 303 { 304 // If there is an interchange specified, permute the iteration domain and 305 // the tile sizes. 306 SmallVector<int64_t> interchangeVector; 307 if (!options.interchangeVector.empty()) { 308 interchangeVector = fillInterchangeVector(options.interchangeVector, 309 iterationDomain.size()); 310 } 311 if (!interchangeVector.empty()) { 312 if (!isPermutationVector(interchangeVector)) { 313 return rewriter.notifyMatchFailure( 314 op, "invalid intechange vector, not a permutation of the entire " 315 "iteration space"); 316 } 317 318 applyPermutationToVector(iterationDomain, interchangeVector); 319 applyPermutationToVector(tileSizeVector, interchangeVector); 320 } 321 322 // 3. Materialize an empty loop nest that iterates over the tiles. These 323 // loops for now do not return any values even if the original operation has 324 // results. 325 tilingResult.loops = generateTileLoopNest( 326 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 327 328 if (!interchangeVector.empty()) { 329 auto inversePermutation = invertPermutationVector(interchangeVector); 330 applyPermutationToVector(offsets, inversePermutation); 331 applyPermutationToVector(sizes, inversePermutation); 332 } 333 } 334 335 LLVM_DEBUG({ 336 if (!tilingResult.loops.empty()) { 337 llvm::dbgs() << "LoopNest shell :\n"; 338 tilingResult.loops.front().dump(); 339 llvm::dbgs() << "\n"; 340 } 341 }); 342 343 // 4. Generate the tiled implementation within the inner most loop. 344 if (!tilingResult.loops.empty()) 345 rewriter.setInsertionPoint( 346 tilingResult.loops.back().getBody()->getTerminator()); 347 SmallVector<Operation *> tiledImplementation = 348 op.getTiledImplementation(rewriter, offsets, sizes); 349 tilingResult.tiledOps.append(tiledImplementation); 350 if (op->getNumResults() == 0) { 351 // nothing more to do. 352 return tilingResult; 353 } 354 355 // If loops are empty, the tiled op is used as the replacement for the untiled 356 // op. 357 if (tilingResult.loops.empty()) { 358 tilingResult.replacements = llvm::to_vector( 359 llvm::map_range(tiledImplementation[0]->getResults(), 360 [](OpResult result) -> Value { return result; })); 361 return tilingResult; 362 } 363 364 // 5. Yield all the results of the tiled operation. The surrounding loop 365 // nest is modified to insert a destructive update pattern to yield 366 // from the loop nest values to replace the untiled op with. 367 int64_t numResults = op->getNumResults(); 368 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 369 resultSizesList(numResults); 370 for (const auto &result : llvm::enumerate(op->getResults())) { 371 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 372 sizes, 373 resultOffsetsList[result.index()], 374 resultSizesList[result.index()]))) { 375 return rewriter.notifyMatchFailure( 376 op, "failed to get slice of result produced"); 377 } 378 } 379 380 SmallVector<Value> destinationTensors; 381 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 382 destinationTensors))) 383 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 384 385 tilingResult.replacements = yieldTiledValues( 386 rewriter, destinationTensors, tilingResult.tiledOps.back(), 387 resultOffsetsList, resultSizesList, tilingResult.loops); 388 389 LLVM_DEBUG({ 390 if (!tilingResult.loops.empty()) { 391 llvm::dbgs() << "After tiled implementation :\n"; 392 tilingResult.loops.front().dump(); 393 llvm::dbgs() << "\n"; 394 } 395 }); 396 return tilingResult; 397 } 398 399 FailureOr<scf::SCFReductionTilingResult> 400 mlir::scf::tileReductionUsingScf(PatternRewriter &b, 401 PartialReductionOpInterface op, 402 ArrayRef<OpFoldResult> tileSize) { 403 Location loc = op.getLoc(); 404 // Ops implementing PartialReductionOpInterface are expected to implement 405 // TilingInterface. 406 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 407 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 408 SmallVector<Value> tileSizeVector = 409 getValueOrCreateConstantIndexOp(b, loc, tileSize); 410 if (tileSizeVector.size() < iterationDomain.size()) { 411 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 412 tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); 413 } 414 if (op->getNumResults() != 1) 415 return b.notifyMatchFailure( 416 op, "don't support ops with multiple results for now"); 417 SmallVector<utils::IteratorType> iterators = 418 tilingInterfaceOp.getLoopIteratorTypes(); 419 int64_t numReductionDims = llvm::count( 420 tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction); 421 if (numReductionDims != 1) 422 return b.notifyMatchFailure( 423 op, "only support ops with one reduction dimension."); 424 int reductionDim; 425 for (auto &[idx, iteratorType] : 426 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 427 if (iteratorType == utils::IteratorType::reduction) { 428 reductionDim = idx; 429 break; 430 } 431 } 432 if (static_cast<size_t>(reductionDim) >= tileSize.size()) 433 return b.notifyMatchFailure(op, "reduction dimension must be tiled"); 434 435 // 1. create the inital tensor value. 436 FailureOr<Operation *> identityTensor = 437 op.generateInitialTensorForPartialReduction(b, loc, tileSize, 438 reductionDim); 439 if (failed(identityTensor)) 440 return b.notifyMatchFailure(op, 441 "cannot create a tensor of identity value."); 442 // 2. Create the nested loops. 443 SmallVector<OpFoldResult> offsets, sizes; 444 SmallVector<scf::ForOp> loops = generateTileLoopNest( 445 b, loc, iterationDomain, tileSizeVector, offsets, sizes); 446 447 // 3. Generate the tiled implementation within the inner most loop. 448 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 449 Operation *parallelOp = op.tileToPartialReduction( 450 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim); 451 452 SmallVector<OpFoldResult> resultSizesList; 453 for (size_t i = 0; i < offsets.size(); i++) 454 resultSizesList.push_back( 455 b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i)); 456 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 457 SmallVector<Value> replacements = yieldTiledValues( 458 b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, 459 resultSizesList, loops); 460 461 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 462 auto innerMostLoop = loops.back(); 463 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 464 assert(destinationTensors.size() == 465 innerMostLoop.getRegionIterArgs().size() && 466 "unexpected number of outputs"); 467 updateDestinationOperandsForTiledOp(b, destinationTensors, 468 innerMostLoop.getRegionIterArgs()); 469 470 // 4. Apply the merge reduction to combine all the partial values. 471 b.setInsertionPointAfter(*loops.begin()); 472 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); 473 b.replaceOp(op, mergeOp->getResults()); 474 475 SCFReductionTilingResult results; 476 results.initialOp = *identityTensor; 477 results.loops = std::move(loops); 478 results.parallelTiledOp = parallelOp; 479 results.mergeOp = mergeOp; 480 return results; 481 } 482 //===----------------------------------------------------------------------===// 483 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 484 //===----------------------------------------------------------------------===// 485 486 /// Return the untiled producer whose slice is used in a tiled consumer. The 487 /// method traverses the tile loop nest (`loops`) if needed, and returns the 488 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 489 /// indicates that this is a destination operand of the consumer. If there was 490 /// no loop traversal needed, the second value of the returned tuple is empty. 491 static std::tuple<OpResult, std::optional<OpOperand *>> 492 getUntiledProducerFromSliceSource(OpOperand *source, 493 ArrayRef<scf::ForOp> loops) { 494 std::optional<OpOperand *> destinationIterArg; 495 auto loopIt = loops.rbegin(); 496 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 497 scf::ForOp loop = *loopIt; 498 if (iterArg.getOwner()->getParentOp() != loop) 499 break; 500 source = &loop.getOpOperandForRegionIterArg(iterArg); 501 loopIt++; 502 } 503 if (loopIt == loops.rend()) 504 destinationIterArg = source; 505 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 506 } 507 508 /// Implementation of fusing producer of a single slice by computing the 509 /// slice of the producer in-place. 510 std::optional<scf::SCFFuseProducerOfSliceResult> 511 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 512 tensor::ExtractSliceOp candidateSliceOp, 513 MutableArrayRef<scf::ForOp> loops) { 514 // 1. Get the producer of the source (potentially walking through 515 // `iter_args` of nested `scf.for`) 516 auto [fusableProducer, destinationIterArg] = 517 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 518 loops); 519 if (!fusableProducer) 520 return std::nullopt; 521 522 // 2. Generate the tiled implementation of the producer of the source 523 OpBuilder::InsertionGuard g(rewriter); 524 rewriter.setInsertionPoint(candidateSliceOp); 525 FailureOr<Value> fusedProducerValue = 526 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 527 fusableProducer); 528 if (failed(fusedProducerValue)) 529 return std::nullopt; 530 rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value()); 531 532 // 3. If the slice is for a destination operand, for example, 533 // 534 // ```mlir 535 // %0 = linalg.init 536 // %1 = linalg.fill .. outs(%0 : ) 537 // %2 = scf.for .. iter_args(%arg0 = %1) { 538 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 539 // %4 = tensor.extract_slice %arg1 [..] 540 // .. = linalg.matmul .. outs(%4 : ) 541 // } 542 // } 543 // ``` 544 // 545 // the IR is currently 546 // 547 // ``` 548 // %0 = linalg.init 549 // %1 = linalg.fill 550 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 551 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 552 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 553 // %5 = linalg.fill .. outs(%4 : ) 554 // .. = linalg.matmul .. outs(%5 : ) 555 // } 556 // } 557 // ``` 558 // 559 // The untiled `linalg.fill` is still used as the `init_value` since it 560 // was originally a destination operand of the untiled `linalg.matmul`. 561 // When fusing an operand that is a destination operand. 562 // - Update the iter_arg of the outer most loop to use the destination 563 // of the untiled producer. 564 // - Update the destination of the slice of the tiled producer generated 565 // to use the same basic block argument as the slice that was used to 566 // generate inplace the tiled implementation of the producer. 567 // With this the IR will be. 568 // 569 // ``` 570 // %0 = linalg.init 571 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 572 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 573 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 574 // %4 = linalg.fill .. outs(%3 : ) 575 // .. = linalg.matmul .. outs(%4 : ) 576 // } 577 // } 578 // ``` 579 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 580 // Update to use that when it does become available. 581 scf::ForOp outerMostLoop = loops.front(); 582 Optional<unsigned> iterArgNumber; 583 if (destinationIterArg) { 584 iterArgNumber = 585 outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value()); 586 } 587 if (iterArgNumber) { 588 int64_t resultNumber = fusableProducer.getResultNumber(); 589 if (auto dstOp = 590 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { 591 outerMostLoop.setIterArg(iterArgNumber.value(), 592 dstOp.getTiedOpOperand(fusableProducer)->get()); 593 } 594 if (auto dstOp = fusedProducerValue.value() 595 .getDefiningOp<DestinationStyleOpInterface>()) { 596 scf::ForOp innerMostLoop = loops.back(); 597 updateDestinationOperandsForTiledOp( 598 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 599 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 600 } 601 } 602 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 603 fusedProducerValue.value()}; 604 } 605 606 /// Reconstruct the fused producer from within the tiled-and-fused code. 607 void mlir::scf::yieldReplacementForFusedProducer( 608 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 609 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 610 MutableArrayRef<scf::ForOp> loops) { 611 auto [fusableProducer, fusedProducerValue] = fusedProducerInfo; 612 SmallVector<Value> initValues; 613 FailureOr<Value> initValue = tensor::getOrCreateDestination( 614 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 615 if (succeeded(initValue)) { 616 SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); 617 SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); 618 SmallVector<Value> yieldedVals = 619 yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, 620 resultOffsets, resultSizes, loops); 621 } 622 if (auto dstStyleProducer = 623 fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) { 624 Value dstValue = 625 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) 626 ->get(); 627 updateDestinationOperandsForTiledOp( 628 rewriter, dstValue, loops.back().getRegionIterArgs().back()); 629 } 630 } 631 632 /// Implementation of tile consumer and fuse producer greedily. 633 FailureOr<scf::SCFTileAndFuseResult> 634 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 635 RewriterBase &rewriter, TilingInterface consumer, 636 const scf::SCFTileAndFuseOptions &options) { 637 // This transformation is only valid for ops that return values (i.e. not 638 // valid to use with operations that have memref operands). 639 if (!consumer->getNumResults()) { 640 return rewriter.notifyMatchFailure( 641 consumer, "invalid pattern for op with no results"); 642 } 643 644 // 1. First tile the consumer. 645 scf::SCFTileAndFuseResult tileAndFuseResult; 646 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 647 { 648 FailureOr<scf::SCFTilingResult> tilingResult = 649 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 650 if (failed(tilingResult)) 651 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 652 for (auto *tiledOp : tilingResult->tiledOps) 653 tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); 654 tileAndFuseResult.loops = std::move(tilingResult->loops); 655 for (const auto &result : llvm::enumerate( 656 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 657 tileAndFuseResult.replacements[std::get<0>(result.value())] = 658 std::get<1>(result.value()); 659 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( 660 result.index())] = result.index(); 661 } 662 } 663 664 // If there are no loops generated, fusion is immaterial. 665 if (tileAndFuseResult.loops.empty()) 666 return tileAndFuseResult; 667 668 // 2. Typically, the operands of the tiled operation are slices of the 669 // operands of the untiled operation. These are expressed in IR using 670 // `tensor.extract_slice` operations with source being the operands of the 671 // untiled operation. Create a worklist of these `tensor.extract_slice` 672 // operations. If the producers of the source of the `tensor.extract_slice` 673 // can be tiled such that the tiled value is generated in-place, that 674 // effectively tiles + fuses the operations. 675 auto addCandidateSlices = [](Operation *fusedOp, 676 std::deque<tensor::ExtractSliceOp> &candidates) { 677 for (Value operand : fusedOp->getOperands()) 678 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 679 candidates.push_back(sliceOp); 680 }; 681 682 std::deque<tensor::ExtractSliceOp> candidates; 683 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 684 OpBuilder::InsertionGuard g(rewriter); 685 while (!candidates.empty()) { 686 // Traverse the slices in BFS fashion. 687 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 688 candidates.pop_front(); 689 690 // The operands of the fused producer might themselved be slices of 691 // values produced by operations that implement the `TilingInterface`. 692 // Add these operations to the worklist. 693 std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = 694 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, 695 tileAndFuseResult.loops); 696 if (!fusedProducer) 697 continue; 698 699 if (Operation *tiledAndFusedOp = 700 fusedProducer->tiledAndFusedProducer.getDefiningOp()) { 701 tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); 702 addCandidateSlices(tiledAndFusedOp, candidates); 703 } 704 } 705 return tileAndFuseResult; 706 } 707 708 //===----------------------------------------------------------------------===// 709 // lowerToLoopsUsingSCFForOp implementation. 710 //===----------------------------------------------------------------------===// 711 712 FailureOr<SmallVector<scf::ForOp>> 713 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 714 TilingInterface op) { 715 // TODO: Handle cases where the op has results if needed. 716 if (op->getNumResults() > 0) { 717 return rewriter.notifyMatchFailure( 718 op, "unable to lower to loops operations with return values"); 719 } 720 721 SmallVector<Range> domain = op.getIterationDomain(rewriter); 722 SmallVector<Value> ivs; 723 SmallVector<scf::ForOp> loops; 724 Location loc = op.getLoc(); 725 for (auto loopRange : domain) { 726 Value offsetVal = 727 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 728 Value sizeVal = 729 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 730 Value strideVal = 731 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 732 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 733 strideVal, ValueRange{}); 734 loops.push_back(loop); 735 ivs.push_back(loop.getInductionVar()); 736 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 737 } 738 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 739 return failure(); 740 } 741 return loops; 742 } 743