1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "llvm/Support/Debug.h" 27 #include <optional> 28 29 #define DEBUG_TYPE "tile-using-interface" 30 31 using namespace mlir; 32 33 scf::SCFTilingOptions & 34 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 35 assert(!tileSizeComputationFunction && "tile sizes already set"); 36 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 37 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 38 OpBuilder::InsertionGuard guard(b); 39 b.setInsertionPointToStart( 40 &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>() 41 ->getRegion(0) 42 .front()); 43 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 44 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 45 return v; 46 })); 47 }; 48 return *this; 49 } 50 51 /// Helper method to adjust the interchange vector to match the iteration 52 /// domain. 53 static SmallVector<int64_t> 54 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 55 size_t iterationDomainSize) { 56 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 57 if (filledVector.size() < iterationDomainSize) { 58 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 59 filledVector.append(range.begin(), range.end()); 60 } 61 if (filledVector.size() > iterationDomainSize) 62 filledVector.resize(iterationDomainSize); 63 return filledVector; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // tileUsingSCFForOp implementation. 68 //===----------------------------------------------------------------------===// 69 70 // Check if `stride` evenly divides the trip count `size - offset`. 71 static bool tileDividesIterationDomain(Range loopRange) { 72 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 73 if (!offsetAsInt) 74 return false; 75 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 76 if (!sizeAsInt) 77 return false; 78 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 79 if (!strideAsInt) 80 return false; 81 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 82 } 83 84 /// Returns the bounded tile size given the current `iv`, `loopRange` and 85 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 86 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 87 Range loopRange, Value iv, 88 Value tileSize) { 89 std::optional<int64_t> ts = getConstantIntValue(tileSize); 90 if (ts && ts.value() == 1) 91 return getAsOpFoldResult(tileSize); 92 93 if (tileDividesIterationDomain( 94 Range{loopRange.offset, loopRange.size, tileSize})) 95 return tileSize; 96 97 // The tile size to use (to avoid out of bounds access) is minimum of 98 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 99 // loop. 100 AffineExpr s0, s1, d0; 101 bindDims(b.getContext(), d0); 102 bindSymbols(b.getContext(), s0, s1); 103 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 104 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 105 return affine::makeComposedFoldedAffineMin( 106 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 107 } 108 109 /// Generate an empty loop nest that represents the tiled loop nest shell. 110 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 111 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 112 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 113 /// the 114 /// tile processed within the inner most loop. 115 static SmallVector<scf::ForOp> 116 generateTileLoopNest(OpBuilder &builder, Location loc, 117 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 118 SmallVector<OpFoldResult> &offsets, 119 SmallVector<OpFoldResult> &sizes) { 120 assert(!loopRanges.empty() && "expected at least one loop range"); 121 assert(loopRanges.size() == tileSizeVals.size() && 122 "expected as many tile sizes as loop ranges"); 123 OpBuilder::InsertionGuard guard(builder); 124 SmallVector<scf::ForOp> loops; 125 offsets.resize(loopRanges.size()); 126 sizes.resize(loopRanges.size()); 127 128 for (auto loopRange : llvm::enumerate(loopRanges)) { 129 Value offset = 130 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 131 Value size = 132 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 133 Value tileSize = tileSizeVals[loopRange.index()]; 134 // No loops if tile size is zero. Set offset and size to the loop 135 // offset and size. 136 if (matchPattern(tileSize, m_Zero())) { 137 offsets[loopRange.index()] = offset; 138 sizes[loopRange.index()] = size; 139 continue; 140 } 141 142 auto loop = builder.create<scf::ForOp>( 143 loc, offset, size, tileSize, ValueRange{}, 144 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 145 ValueRange /*iterArgs*/) { 146 sizes[loopRange.index()] = getBoundedTileSize( 147 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 148 builder.create<scf::YieldOp>(loc); 149 }); 150 offsets[loopRange.index()] = loop.getInductionVar(); 151 loops.push_back(loop); 152 builder.setInsertionPoint(loop.getBody()->getTerminator()); 153 } 154 return loops; 155 } 156 157 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 158 /// construct the destructive update pattern that inserts the yielded 159 /// value into a destination tensor provided by `initValue` at offset 160 /// `tileOffsets` and size `tileSizes`. For example, 161 /// 162 /// ```mlir 163 /// scf.for %iv0 = ... { 164 /// %0 = tiled_op 165 /// } 166 /// ``` 167 /// 168 /// is transformed to 169 /// 170 /// ```mlir 171 /// scf.for %iv0 = ... iter_args(%arg = %0) { 172 /// %1 = tensor.extract_slice %arg 173 /// %2 = tiled_op 174 /// %3 = tensor.insert_slice %2 into %arg 175 /// scf.yield %3 176 /// } 177 /// ``` 178 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 179 static SmallVector<Value> 180 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 181 ValueRange yieldedValues, 182 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 183 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 184 MutableArrayRef<scf::ForOp> loops) { 185 NewYieldValueFn yieldValueFn = 186 [&](OpBuilder &b, Location loc, 187 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 188 SmallVector<Value> inserts; 189 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 190 ArrayRef<OpFoldResult> tileOffsets = 191 tileOffsetsList[yieldedValue.index()]; 192 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 193 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 194 b.getIndexAttr(1)); 195 Value insert = b.create<tensor::InsertSliceOp>( 196 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 197 tileOffsets, tileSizes, tileStrides); 198 inserts.push_back(insert); 199 } 200 return inserts; 201 }; 202 203 SmallVector<scf::ForOp> newLoops = 204 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 205 /*replaceIterOperandsUsesInLoop =*/false); 206 for (const auto &loop : llvm::enumerate(loops)) { 207 rewriter.eraseOp(loop.value()); 208 loops[loop.index()] = newLoops[loop.index()]; 209 } 210 return llvm::to_vector(llvm::map_range( 211 loops.front().getResults().take_back(yieldedValues.size()), 212 [](OpResult r) -> Value { return r; })); 213 } 214 215 /// If the tiled operation is destination passing style, update the 216 /// slice of the destination used (which refers to the untiled destination) 217 /// to use the corresponding region argument of the innermost loop. 218 /// 219 /// ```mlir 220 /// %0 = 221 /// scf.for %iv0 = ... iter_args(%arg = %0) { 222 /// %1 = tensor.extract_slice %0 223 /// %2 = tiled_op 224 /// %3 = tensor.insert_slice %2 into %arg 225 /// scf.yield %3 226 /// } 227 /// ``` 228 /// 229 /// is transformed to 230 /// 231 /// ```mlir 232 /// scf.for %iv0 = ... iter_args(%arg = %0) { 233 /// %1 = tensor.extract_slice %arg 234 /// %2 = tiled_op 235 /// %3 = tensor.insert_slice %2 into %arg 236 /// scf.yield %3 237 /// } 238 /// ``` 239 static void 240 updateDestinationOperandsForTiledOp(OpBuilder &builder, 241 ValueRange tiledOpDestinationValues, 242 ValueRange bbArgsList) { 243 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 244 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 245 if (!sliceOp) 246 continue; 247 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 248 } 249 } 250 251 /// Helper method to yield the values of the tiled op, as well as 252 /// update the destination operands of the tiled op, if it is 253 /// a destination passing style op. 254 static SmallVector<Value> 255 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, 256 TilingResult tilingResult, 257 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 258 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 259 MutableArrayRef<scf::ForOp> loops) { 260 SmallVector<Value> replacements = 261 yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, 262 tileOffsetsList, tileSizesList, loops); 263 for (auto tiledOp : tilingResult.tiledOps) { 264 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { 265 auto innerMostLoop = loops.back(); 266 SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); 267 updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, 268 innerMostLoop.getRegionIterArgs()); 269 } 270 } 271 return replacements; 272 } 273 274 /// Implementation of tiling transformation of `op` that implements the 275 /// `TilingInterface` using `scf.for` to iterate over the tiles. 276 FailureOr<scf::SCFTilingResult> 277 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 278 const scf::SCFTilingOptions &options) { 279 OpBuilder::InsertionGuard guard(rewriter); 280 rewriter.setInsertionPointAfter(op); 281 282 if (!options.tileSizeComputationFunction) { 283 return rewriter.notifyMatchFailure( 284 op, "missing tile size computation function"); 285 } 286 287 // 1. Get the range of the loops that are represented by the operation. 288 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 289 size_t numLoops = iterationDomain.size(); 290 if (numLoops == 0) { 291 return rewriter.notifyMatchFailure( 292 op, "unable to tile op with no iteration domain"); 293 } 294 295 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 296 // skips tiling a particular dimension. This convention is significantly 297 // simpler to handle instead of adjusting affine maps to account for missing 298 // dimensions. 299 SmallVector<Value> tileSizeVector = 300 options.tileSizeComputationFunction(rewriter, op); 301 if (tileSizeVector.size() < iterationDomain.size()) { 302 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 303 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 304 } 305 306 scf::SCFTilingResult tilingResult; 307 SmallVector<OpFoldResult> offsets, sizes; 308 { 309 // If there is an interchange specified, permute the iteration domain and 310 // the tile sizes. 311 SmallVector<int64_t> interchangeVector; 312 if (!options.interchangeVector.empty()) { 313 interchangeVector = fillInterchangeVector(options.interchangeVector, 314 iterationDomain.size()); 315 } 316 if (!interchangeVector.empty()) { 317 if (!isPermutationVector(interchangeVector)) { 318 return rewriter.notifyMatchFailure( 319 op, "invalid intechange vector, not a permutation of the entire " 320 "iteration space"); 321 } 322 323 applyPermutationToVector(iterationDomain, interchangeVector); 324 applyPermutationToVector(tileSizeVector, interchangeVector); 325 } 326 327 // 3. Materialize an empty loop nest that iterates over the tiles. These 328 // loops for now do not return any values even if the original operation has 329 // results. 330 tilingResult.loops = generateTileLoopNest( 331 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 332 333 if (!interchangeVector.empty()) { 334 auto inversePermutation = invertPermutationVector(interchangeVector); 335 applyPermutationToVector(offsets, inversePermutation); 336 applyPermutationToVector(sizes, inversePermutation); 337 } 338 } 339 340 LLVM_DEBUG({ 341 if (!tilingResult.loops.empty()) { 342 llvm::dbgs() << "LoopNest shell :\n"; 343 tilingResult.loops.front().dump(); 344 llvm::dbgs() << "\n"; 345 } 346 }); 347 348 // 4. Generate the tiled implementation within the inner most loop. 349 if (!tilingResult.loops.empty()) 350 rewriter.setInsertionPoint( 351 tilingResult.loops.back().getBody()->getTerminator()); 352 FailureOr<TilingResult> tiledImplementation = 353 op.getTiledImplementation(rewriter, offsets, sizes); 354 tilingResult.tiledOps.append(tiledImplementation->tiledOps); 355 if (op->getNumResults() == 0) { 356 // nothing more to do. 357 return tilingResult; 358 } 359 360 // If loops are empty, the tiled op is used as the replacement for the untiled 361 // op. 362 if (tilingResult.loops.empty()) { 363 tilingResult.replacements = tiledImplementation->tiledValues; 364 return tilingResult; 365 } 366 367 // 5. Yield all the results of the tiled operation. The surrounding loop 368 // nest is modified to insert a destructive update pattern to yield 369 // from the loop nest values to replace the untiled op with. 370 int64_t numResults = op->getNumResults(); 371 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 372 resultSizesList(numResults); 373 for (const auto &result : llvm::enumerate(op->getResults())) { 374 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 375 sizes, 376 resultOffsetsList[result.index()], 377 resultSizesList[result.index()]))) { 378 return rewriter.notifyMatchFailure( 379 op, "failed to get slice of result produced"); 380 } 381 } 382 383 SmallVector<Value> destinationTensors; 384 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 385 destinationTensors))) 386 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 387 388 tilingResult.replacements = yieldTiledValues( 389 rewriter, destinationTensors, tiledImplementation.value(), 390 resultOffsetsList, resultSizesList, tilingResult.loops); 391 392 LLVM_DEBUG({ 393 if (!tilingResult.loops.empty()) { 394 llvm::dbgs() << "After tiled implementation :\n"; 395 tilingResult.loops.front().dump(); 396 llvm::dbgs() << "\n"; 397 } 398 }); 399 return tilingResult; 400 } 401 402 FailureOr<scf::SCFReductionTilingResult> 403 mlir::scf::tileReductionUsingScf(RewriterBase &b, 404 PartialReductionOpInterface op, 405 ArrayRef<OpFoldResult> tileSize) { 406 Location loc = op.getLoc(); 407 // Ops implementing PartialReductionOpInterface are expected to implement 408 // TilingInterface. 409 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 410 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 411 SmallVector<Value> tileSizeVector = 412 getValueOrCreateConstantIndexOp(b, loc, tileSize); 413 if (tileSizeVector.size() < iterationDomain.size()) { 414 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 415 tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); 416 } 417 if (op->getNumResults() != 1) 418 return b.notifyMatchFailure( 419 op, "don't support ops with multiple results for now"); 420 SmallVector<utils::IteratorType> iterators = 421 tilingInterfaceOp.getLoopIteratorTypes(); 422 int64_t numReductionDims = llvm::count( 423 tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction); 424 if (numReductionDims != 1) 425 return b.notifyMatchFailure( 426 op, "only support ops with one reduction dimension."); 427 int reductionDim; 428 for (auto [idx, iteratorType] : 429 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 430 if (iteratorType == utils::IteratorType::reduction) { 431 reductionDim = idx; 432 break; 433 } 434 } 435 if (static_cast<size_t>(reductionDim) >= tileSize.size()) 436 return b.notifyMatchFailure(op, "reduction dimension must be tiled"); 437 438 // 1. create the inital tensor value. 439 FailureOr<Operation *> identityTensor = 440 op.generateInitialTensorForPartialReduction(b, loc, tileSize, 441 reductionDim); 442 if (failed(identityTensor)) 443 return b.notifyMatchFailure(op, 444 "cannot create a tensor of identity value."); 445 // 2. Create the nested loops. 446 SmallVector<OpFoldResult> offsets, sizes; 447 SmallVector<scf::ForOp> loops = generateTileLoopNest( 448 b, loc, iterationDomain, tileSizeVector, offsets, sizes); 449 450 // 3. Generate the tiled implementation within the inner most loop. 451 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 452 Operation *parallelOp = op.tileToPartialReduction( 453 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim); 454 455 SmallVector<OpFoldResult> resultSizesList; 456 for (size_t i = 0; i < offsets.size(); i++) 457 resultSizesList.push_back( 458 b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i)); 459 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 460 SmallVector<Value> replacements = yieldTiledValues( 461 b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, 462 resultSizesList, loops); 463 464 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 465 auto innerMostLoop = loops.back(); 466 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 467 assert(destinationTensors.size() == 468 innerMostLoop.getRegionIterArgs().size() && 469 "unexpected number of outputs"); 470 updateDestinationOperandsForTiledOp(b, destinationTensors, 471 innerMostLoop.getRegionIterArgs()); 472 473 // 4. Apply the merge reduction to combine all the partial values. 474 b.setInsertionPointAfter(*loops.begin()); 475 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); 476 b.replaceOp(op, mergeOp->getResults()); 477 478 SCFReductionTilingResult results; 479 results.initialOp = *identityTensor; 480 results.loops = std::move(loops); 481 results.parallelTiledOp = parallelOp; 482 results.mergeOp = mergeOp; 483 return results; 484 } 485 //===----------------------------------------------------------------------===// 486 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 487 //===----------------------------------------------------------------------===// 488 489 /// Return the untiled producer whose slice is used in a tiled consumer. The 490 /// method traverses the tile loop nest (`loops`) if needed, and returns the 491 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 492 /// indicates that this is a destination operand of the consumer. If there was 493 /// no loop traversal needed, the second value of the returned tuple is empty. 494 static std::tuple<OpResult, std::optional<OpOperand *>> 495 getUntiledProducerFromSliceSource(OpOperand *source, 496 ArrayRef<scf::ForOp> loops) { 497 std::optional<OpOperand *> destinationIterArg; 498 auto loopIt = loops.rbegin(); 499 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 500 scf::ForOp loop = *loopIt; 501 if (iterArg.getOwner()->getParentOp() != loop) 502 break; 503 source = &loop.getOpOperandForRegionIterArg(iterArg); 504 loopIt++; 505 } 506 if (loopIt == loops.rend()) 507 destinationIterArg = source; 508 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 509 } 510 511 /// Implementation of fusing producer of a single slice by computing the 512 /// slice of the producer in-place. 513 std::optional<scf::SCFFuseProducerOfSliceResult> 514 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 515 tensor::ExtractSliceOp candidateSliceOp, 516 MutableArrayRef<scf::ForOp> loops) { 517 // 1. Get the producer of the source (potentially walking through 518 // `iter_args` of nested `scf.for`) 519 auto [fusableProducer, destinationIterArg] = 520 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 521 loops); 522 if (!fusableProducer) 523 return std::nullopt; 524 525 // 2. Generate the tiled implementation of the producer of the source 526 OpBuilder::InsertionGuard g(rewriter); 527 rewriter.setInsertionPoint(candidateSliceOp); 528 FailureOr<TilingResult> tileAndFuseResult = 529 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 530 fusableProducer); 531 if (failed(tileAndFuseResult)) 532 return std::nullopt; 533 rewriter.replaceAllUsesWith(candidateSliceOp, 534 tileAndFuseResult->tiledValues[0]); 535 536 // 3. If the slice is for a destination operand, for example, 537 // 538 // ```mlir 539 // %0 = linalg.init 540 // %1 = linalg.fill .. outs(%0 : ) 541 // %2 = scf.for .. iter_args(%arg0 = %1) { 542 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 543 // %4 = tensor.extract_slice %arg1 [..] 544 // .. = linalg.matmul .. outs(%4 : ) 545 // } 546 // } 547 // ``` 548 // 549 // the IR is currently 550 // 551 // ``` 552 // %0 = linalg.init 553 // %1 = linalg.fill 554 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 555 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 556 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 557 // %5 = linalg.fill .. outs(%4 : ) 558 // .. = linalg.matmul .. outs(%5 : ) 559 // } 560 // } 561 // ``` 562 // 563 // The untiled `linalg.fill` is still used as the `init_value` since it 564 // was originally a destination operand of the untiled `linalg.matmul`. 565 // When fusing an operand that is a destination operand. 566 // - Update the iter_arg of the outer most loop to use the destination 567 // of the untiled producer. 568 // - Update the destination of the slice of the tiled producer generated 569 // to use the same basic block argument as the slice that was used to 570 // generate inplace the tiled implementation of the producer. 571 // With this the IR will be. 572 // 573 // ``` 574 // %0 = linalg.init 575 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 576 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 577 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 578 // %4 = linalg.fill .. outs(%3 : ) 579 // .. = linalg.matmul .. outs(%4 : ) 580 // } 581 // } 582 // ``` 583 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 584 // Update to use that when it does become available. 585 scf::ForOp outerMostLoop = loops.front(); 586 std::optional<unsigned> iterArgNumber; 587 if (destinationIterArg) { 588 iterArgNumber = 589 outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value()); 590 } 591 if (iterArgNumber) { 592 int64_t resultNumber = fusableProducer.getResultNumber(); 593 if (auto dstOp = 594 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { 595 outerMostLoop.setIterArg(iterArgNumber.value(), 596 dstOp.getTiedOpOperand(fusableProducer)->get()); 597 } 598 for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { 599 auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 600 if (!dstOp) 601 continue; 602 scf::ForOp innerMostLoop = loops.back(); 603 updateDestinationOperandsForTiledOp( 604 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 605 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 606 } 607 } 608 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 609 tileAndFuseResult->tiledValues[0], 610 tileAndFuseResult->tiledOps}; 611 } 612 613 /// Reconstruct the fused producer from within the tiled-and-fused code. 614 void mlir::scf::yieldReplacementForFusedProducer( 615 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 616 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 617 MutableArrayRef<scf::ForOp> loops) { 618 auto [fusableProducer, fusedProducerValue, tileAndFusedOps] = 619 fusedProducerInfo; 620 SmallVector<Value> initValues; 621 FailureOr<Value> initValue = tensor::getOrCreateDestination( 622 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 623 if (succeeded(initValue)) { 624 SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); 625 SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); 626 SmallVector<Value> yieldedVals = 627 yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, 628 resultOffsets, resultSizes, loops); 629 } 630 for (auto tileAndFusedOp : tileAndFusedOps) { 631 auto dstStyleProducer = 632 dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 633 if (!dstStyleProducer) 634 continue; 635 Value dstValue = 636 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) 637 ->get(); 638 updateDestinationOperandsForTiledOp( 639 rewriter, dstValue, loops.back().getRegionIterArgs().back()); 640 } 641 } 642 643 /// Implementation of tile consumer and fuse producer greedily. 644 FailureOr<scf::SCFTileAndFuseResult> 645 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 646 RewriterBase &rewriter, TilingInterface consumer, 647 const scf::SCFTileAndFuseOptions &options) { 648 // This transformation is only valid for ops that return values (i.e. not 649 // valid to use with operations that have memref operands). 650 if (!consumer->getNumResults()) { 651 return rewriter.notifyMatchFailure( 652 consumer, "invalid pattern for op with no results"); 653 } 654 655 // 1. First tile the consumer. 656 scf::SCFTileAndFuseResult tileAndFuseResult; 657 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 658 { 659 FailureOr<scf::SCFTilingResult> tilingResult = 660 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 661 if (failed(tilingResult)) 662 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 663 for (auto *tiledOp : tilingResult->tiledOps) 664 tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); 665 tileAndFuseResult.loops = std::move(tilingResult->loops); 666 for (const auto &result : llvm::enumerate( 667 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 668 tileAndFuseResult.replacements[std::get<0>(result.value())] = 669 std::get<1>(result.value()); 670 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( 671 result.index())] = result.index(); 672 } 673 } 674 675 // If there are no loops generated, fusion is immaterial. 676 if (tileAndFuseResult.loops.empty()) 677 return tileAndFuseResult; 678 679 // 2. Typically, the operands of the tiled operation are slices of the 680 // operands of the untiled operation. These are expressed in IR using 681 // `tensor.extract_slice` operations with source being the operands of the 682 // untiled operation. Create a worklist of these `tensor.extract_slice` 683 // operations. If the producers of the source of the `tensor.extract_slice` 684 // can be tiled such that the tiled value is generated in-place, that 685 // effectively tiles + fuses the operations. 686 auto addCandidateSlices = [](Operation *fusedOp, 687 std::deque<tensor::ExtractSliceOp> &candidates) { 688 for (Value operand : fusedOp->getOperands()) 689 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 690 candidates.push_back(sliceOp); 691 }; 692 693 std::deque<tensor::ExtractSliceOp> candidates; 694 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 695 OpBuilder::InsertionGuard g(rewriter); 696 while (!candidates.empty()) { 697 // Traverse the slices in BFS fashion. 698 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 699 candidates.pop_front(); 700 701 // The operands of the fused producer might themselved be slices of 702 // values produced by operations that implement the `TilingInterface`. 703 // Add these operations to the worklist. 704 std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = 705 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, 706 tileAndFuseResult.loops); 707 if (!fusedProducer) 708 continue; 709 710 if (Operation *tiledAndFusedOp = 711 fusedProducer->tiledAndFusedProducer.getDefiningOp()) { 712 tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); 713 addCandidateSlices(tiledAndFusedOp, candidates); 714 } 715 } 716 return tileAndFuseResult; 717 } 718 719 //===----------------------------------------------------------------------===// 720 // lowerToLoopsUsingSCFForOp implementation. 721 //===----------------------------------------------------------------------===// 722 723 FailureOr<SmallVector<scf::ForOp>> 724 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 725 TilingInterface op) { 726 // TODO: Handle cases where the op has results if needed. 727 if (op->getNumResults() > 0) { 728 return rewriter.notifyMatchFailure( 729 op, "unable to lower to loops operations with return values"); 730 } 731 732 SmallVector<Range> domain = op.getIterationDomain(rewriter); 733 SmallVector<Value> ivs; 734 SmallVector<scf::ForOp> loops; 735 Location loc = op.getLoc(); 736 for (auto loopRange : domain) { 737 Value offsetVal = 738 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 739 Value sizeVal = 740 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 741 Value strideVal = 742 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 743 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 744 strideVal, ValueRange{}); 745 loops.push_back(loop); 746 ivs.push_back(loop.getInductionVar()); 747 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 748 } 749 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 750 return failure(); 751 } 752 return loops; 753 } 754