1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "llvm/Support/Debug.h" 27 #include <optional> 28 29 #define DEBUG_TYPE "tile-using-interface" 30 31 using namespace mlir; 32 33 scf::SCFTilingOptions & 34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 35 assert(!tileSizeComputationFunction && "tile sizes already set"); 36 auto tileSizes = llvm::to_vector(ts); 37 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 38 return tileSizes; 39 }; 40 return *this; 41 } 42 43 /// Helper method to adjust the interchange vector to match the iteration 44 /// domain. 45 static SmallVector<int64_t> 46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 47 size_t iterationDomainSize) { 48 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 49 if (filledVector.size() < iterationDomainSize) { 50 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 51 filledVector.append(range.begin(), range.end()); 52 } 53 if (filledVector.size() > iterationDomainSize) 54 filledVector.resize(iterationDomainSize); 55 return filledVector; 56 } 57 58 /// Convert a list of ops of type `SrcOpTy` to list of `Operation *`. 59 template <typename SrcOpTy> 60 static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) { 61 return llvm::to_vector( 62 llvm::map_range(ops, [](auto op) -> Operation * { return op; })); 63 } 64 template <typename SrcOpTy> 65 static SmallVector<Operation *> 66 getAsOperations(const SmallVector<SrcOpTy> &ops) { 67 return getAsOperations(ArrayRef<SrcOpTy>(ops)); 68 } 69 70 /// Convert a list of `Operation *` to a list of `DstOpTy. 71 template <typename DstOpTy> 72 static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) { 73 return llvm::to_vector( 74 llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); })); 75 } 76 template <typename DstOpTy> 77 static SmallVector<DstOpTy> 78 castToTypedOperations(const SmallVector<Operation *> &ops) { 79 return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops)); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // tileUsingSCFForOp implementation. 84 //===----------------------------------------------------------------------===// 85 86 // Check if `stride` evenly divides the trip count `size - offset`. 87 static bool tileDividesIterationDomain(Range loopRange) { 88 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 89 if (!offsetAsInt) 90 return false; 91 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 92 if (!sizeAsInt) 93 return false; 94 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 95 if (!strideAsInt) 96 return false; 97 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 98 } 99 100 /// Returns the bounded tile size given the current `iv`, `loopRange` and 101 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 102 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 103 Range loopRange, Value iv, 104 OpFoldResult tileSize) { 105 std::optional<int64_t> ts = getConstantIntValue(tileSize); 106 if (ts && ts.value() == 1) 107 return tileSize; 108 109 if (tileDividesIterationDomain( 110 Range{loopRange.offset, loopRange.size, tileSize})) 111 return tileSize; 112 113 // The tile size to use (to avoid out of bounds access) is minimum of 114 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 115 // loop. 116 AffineExpr s0, s1, d0; 117 bindDims(b.getContext(), d0); 118 bindSymbols(b.getContext(), s0, s1); 119 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 120 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 121 return affine::makeComposedFoldedAffineMin( 122 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 123 } 124 125 /// Clones the operation and updates the destination if the operation 126 /// implements the `DestinationStyleOpInterface`. 127 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 128 Operation *op, 129 ValueRange newDestArgs) { 130 Operation *clonedOp = rewriter.clone(*op); 131 if (auto destinationStyleOp = 132 dyn_cast<DestinationStyleOpInterface>(clonedOp)) { 133 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 134 } 135 return clonedOp; 136 } 137 138 /// Generate an empty loop nest that represents the tiled loop nest shell. 139 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 140 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 141 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 142 /// the 143 /// tile processed within the inner most loop. 144 static SmallVector<scf::ForOp> generateTileLoopNest( 145 OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges, 146 ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets, 147 SmallVector<OpFoldResult> &sizes) { 148 assert(!loopRanges.empty() && "expected at least one loop range"); 149 assert(loopRanges.size() == tileSizes.size() && 150 "expected as many tile sizes as loop ranges"); 151 OpBuilder::InsertionGuard guard(builder); 152 SmallVector<scf::ForOp> loops; 153 offsets.resize(loopRanges.size()); 154 sizes.resize(loopRanges.size()); 155 156 for (auto loopRange : llvm::enumerate(loopRanges)) { 157 Value offset = 158 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 159 Value size = 160 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 161 Value tileSize = getValueOrCreateConstantIndexOp( 162 builder, loc, tileSizes[loopRange.index()]); 163 // No loops if tile size is zero. Set offset and size to the loop 164 // offset and size. 165 if (matchPattern(tileSize, m_Zero())) { 166 offsets[loopRange.index()] = offset; 167 sizes[loopRange.index()] = size; 168 continue; 169 } 170 171 auto loop = builder.create<scf::ForOp>( 172 loc, offset, size, tileSize, ValueRange{}, 173 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 174 ValueRange /*iterArgs*/) { 175 sizes[loopRange.index()] = 176 getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv, 177 getAsOpFoldResult(tileSize)); 178 builder.create<scf::YieldOp>(loc); 179 }); 180 offsets[loopRange.index()] = loop.getInductionVar(); 181 loops.push_back(loop); 182 builder.setInsertionPoint(loop.getBody()->getTerminator()); 183 } 184 return loops; 185 } 186 187 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 188 /// construct the destructive update pattern that inserts the yielded 189 /// value into a destination tensor provided by `initValue` at offset 190 /// `tileOffsets` and size `tileSizes`. For example, 191 /// 192 /// ```mlir 193 /// scf.for %iv0 = ... { 194 /// %0 = tiled_op 195 /// } 196 /// ``` 197 /// 198 /// is transformed to 199 /// 200 /// ```mlir 201 /// scf.for %iv0 = ... iter_args(%arg = %0) { 202 /// %1 = tensor.extract_slice %arg 203 /// %2 = tiled_op 204 /// %3 = tensor.insert_slice %2 into %arg 205 /// scf.yield %3 206 /// } 207 /// ``` 208 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 209 static SmallVector<Value> 210 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 211 ValueRange yieldedValues, 212 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 213 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 214 MutableArrayRef<scf::ForOp> loops) { 215 NewYieldValuesFn yieldValueFn = 216 [&](OpBuilder &b, Location loc, 217 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 218 SmallVector<Value> inserts; 219 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 220 ArrayRef<OpFoldResult> tileOffsets = 221 tileOffsetsList[yieldedValue.index()]; 222 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 223 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 224 b.getIndexAttr(1)); 225 Value insert = b.create<tensor::InsertSliceOp>( 226 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 227 tileOffsets, tileSizes, tileStrides); 228 inserts.push_back(insert); 229 } 230 return inserts; 231 }; 232 233 SmallVector<scf::ForOp> newLoops = 234 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 235 /*replaceIterOperandsUsesInLoop =*/false); 236 for (const auto &loop : llvm::enumerate(loops)) { 237 loops[loop.index()] = newLoops[loop.index()]; 238 } 239 return llvm::to_vector(llvm::map_range( 240 loops.front().getResults().take_back(yieldedValues.size()), 241 [](OpResult r) -> Value { return r; })); 242 } 243 244 /// If the tiled operation is destination passing style, update the 245 /// slice of the destination used (which refers to the untiled destination) 246 /// to use the corresponding region argument of the innermost loop. 247 /// 248 /// ```mlir 249 /// %0 = 250 /// scf.for %iv0 = ... iter_args(%arg = %0) { 251 /// %1 = tensor.extract_slice %0 252 /// %2 = tiled_op 253 /// %3 = tensor.insert_slice %2 into %arg 254 /// scf.yield %3 255 /// } 256 /// ``` 257 /// 258 /// is transformed to 259 /// 260 /// ```mlir 261 /// scf.for %iv0 = ... iter_args(%arg = %0) { 262 /// %1 = tensor.extract_slice %arg 263 /// %2 = tiled_op 264 /// %3 = tensor.insert_slice %2 into %arg 265 /// scf.yield %3 266 /// } 267 /// ``` 268 static void 269 updateDestinationOperandsForTiledOp(OpBuilder &builder, 270 ValueRange tiledOpDestinationValues, 271 ValueRange bbArgsList) { 272 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 273 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 274 if (!sliceOp) 275 continue; 276 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 277 } 278 } 279 280 /// Helper method to yield the values of the tiled op, as well as 281 /// update the destination operands of the tiled op, if it is 282 /// a destination passing style op. 283 static SmallVector<Value> 284 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, 285 TilingResult tilingResult, 286 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 287 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 288 MutableArrayRef<scf::ForOp> loops) { 289 SmallVector<Value> replacements = 290 yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, 291 tileOffsetsList, tileSizesList, loops); 292 for (auto tiledOp : tilingResult.tiledOps) { 293 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { 294 auto innerMostLoop = loops.back(); 295 SmallVector<Value> tiledOpDestinationTensors = 296 llvm::to_vector(dstOp.getDpsInits()); 297 updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, 298 innerMostLoop.getRegionIterArgs()); 299 } 300 } 301 return replacements; 302 } 303 304 /// Implementation of tiling transformation of `op` that implements the 305 /// `TilingInterface` using `scf.for` to iterate over the tiles. 306 FailureOr<scf::SCFTilingResult> 307 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 308 const scf::SCFTilingOptions &options) { 309 OpBuilder::InsertionGuard guard(rewriter); 310 rewriter.setInsertionPointAfter(op); 311 312 if (!options.tileSizeComputationFunction) { 313 return rewriter.notifyMatchFailure( 314 op, "missing tile size computation function"); 315 } 316 317 // 1. Get the range of the loops that are represented by the operation. 318 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 319 size_t numLoops = iterationDomain.size(); 320 if (numLoops == 0) { 321 return rewriter.notifyMatchFailure( 322 op, "unable to tile op with no iteration domain"); 323 } 324 325 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 326 // skips tiling a particular dimension. This convention is significantly 327 // simpler to handle instead of adjusting affine maps to account for missing 328 // dimensions. 329 SmallVector<OpFoldResult> tileSizeVector = 330 options.tileSizeComputationFunction(rewriter, op); 331 if (tileSizeVector.size() < iterationDomain.size()) { 332 auto zero = rewriter.getIndexAttr(0); 333 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 334 } 335 336 SmallVector<OpFoldResult> offsets, sizes; 337 SmallVector<scf::ForOp> forLoops; 338 { 339 // If there is an interchange specified, permute the iteration domain and 340 // the tile sizes. 341 SmallVector<int64_t> interchangeVector; 342 if (!options.interchangeVector.empty()) { 343 interchangeVector = fillInterchangeVector(options.interchangeVector, 344 iterationDomain.size()); 345 } 346 if (!interchangeVector.empty()) { 347 if (!isPermutationVector(interchangeVector)) { 348 return rewriter.notifyMatchFailure( 349 op, "invalid intechange vector, not a permutation of the entire " 350 "iteration space"); 351 } 352 353 applyPermutationToVector(iterationDomain, interchangeVector); 354 applyPermutationToVector(tileSizeVector, interchangeVector); 355 } 356 357 // 3. Materialize an empty loop nest that iterates over the tiles. These 358 // loops for now do not return any values even if the original operation has 359 // results. 360 forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain, 361 tileSizeVector, offsets, sizes); 362 363 if (!interchangeVector.empty()) { 364 auto inversePermutation = invertPermutationVector(interchangeVector); 365 applyPermutationToVector(offsets, inversePermutation); 366 applyPermutationToVector(sizes, inversePermutation); 367 } 368 } 369 370 LLVM_DEBUG({ 371 if (!forLoops.empty()) { 372 llvm::dbgs() << "LoopNest shell :\n"; 373 forLoops.front().dump(); 374 llvm::dbgs() << "\n"; 375 } 376 }); 377 378 // 4. Generate the tiled implementation within the inner most loop. 379 if (!forLoops.empty()) 380 rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator()); 381 FailureOr<TilingResult> tiledImplementation = 382 op.getTiledImplementation(rewriter, offsets, sizes); 383 384 if (op->getNumResults() == 0) { 385 return scf::SCFTilingResult{ 386 tiledImplementation->tiledOps, getAsOperations(forLoops), {}}; 387 } 388 389 // If loops are empty, the tiled op is used as the replacement for the untiled 390 // op. 391 if (forLoops.empty()) { 392 return scf::SCFTilingResult{tiledImplementation->tiledOps, 393 getAsOperations(forLoops), 394 tiledImplementation->tiledValues}; 395 } 396 397 // 5. Yield all the results of the tiled operation. The surrounding loop 398 // nest is modified to insert a destructive update pattern to yield 399 // from the loop nest values to replace the untiled op with. 400 int64_t numResults = op->getNumResults(); 401 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 402 resultSizesList(numResults); 403 for (const auto &result : llvm::enumerate(op->getResults())) { 404 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 405 sizes, 406 resultOffsetsList[result.index()], 407 resultSizesList[result.index()]))) { 408 return rewriter.notifyMatchFailure( 409 op, "failed to get slice of result produced"); 410 } 411 } 412 413 SmallVector<Value> destinationTensors; 414 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 415 destinationTensors))) 416 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 417 418 SmallVector<Value> replacements = yieldTiledValues( 419 rewriter, destinationTensors, tiledImplementation.value(), 420 resultOffsetsList, resultSizesList, forLoops); 421 LLVM_DEBUG({ 422 if (!forLoops.empty()) { 423 llvm::dbgs() << "After tiled implementation :\n"; 424 forLoops.front().dump(); 425 llvm::dbgs() << "\n"; 426 } 427 }); 428 return scf::SCFTilingResult{tiledImplementation->tiledOps, 429 getAsOperations(forLoops), replacements}; 430 } 431 432 FailureOr<scf::SCFReductionTilingResult> 433 mlir::scf::tileReductionUsingScf(RewriterBase &b, 434 PartialReductionOpInterface op, 435 ArrayRef<OpFoldResult> tileSizes) { 436 Location loc = op.getLoc(); 437 // Ops implementing PartialReductionOpInterface are expected to implement 438 // TilingInterface. 439 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 440 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 441 auto tileSizesVector = llvm::to_vector(tileSizes); 442 if (tileSizesVector.size() < iterationDomain.size()) { 443 auto zero = b.getIndexAttr(0); 444 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 445 zero); 446 } 447 if (op->getNumResults() != 1) 448 return b.notifyMatchFailure( 449 op, "don't support ops with multiple results for now"); 450 SmallVector<utils::IteratorType> iterators = 451 tilingInterfaceOp.getLoopIteratorTypes(); 452 453 SmallVector<int> reductionDims; 454 for (auto [idx, iteratorType] : 455 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 456 if (iteratorType == utils::IteratorType::reduction) 457 reductionDims.push_back(idx); 458 } 459 460 // 1. create the inital tensor value. 461 FailureOr<Operation *> identityTensor = 462 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 463 reductionDims); 464 if (failed(identityTensor)) 465 return b.notifyMatchFailure(op, 466 "cannot create a tensor of identity value."); 467 // 2. Create the nested loops. 468 SmallVector<OpFoldResult> offsets, sizes; 469 SmallVector<scf::ForOp> loops = generateTileLoopNest( 470 b, loc, iterationDomain, tileSizesVector, offsets, sizes); 471 472 // 3. Generate the tiled implementation within the inner most loop. 473 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 474 Operation *parallelOp = op.tileToPartialReduction( 475 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims); 476 477 SmallVector<OpFoldResult> resultSizesList; 478 for (size_t i = 0; i < offsets.size(); i++) 479 resultSizesList.push_back( 480 tensor::getMixedSize(b, loc, parallelOp->getResult(0), i)); 481 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 482 SmallVector<Value> replacements = yieldTiledValues( 483 b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, 484 resultSizesList, loops); 485 486 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 487 auto innerMostLoop = loops.back(); 488 SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits()); 489 assert(destinationTensors.size() == 490 innerMostLoop.getRegionIterArgs().size() && 491 "unexpected number of outputs"); 492 updateDestinationOperandsForTiledOp(b, destinationTensors, 493 innerMostLoop.getRegionIterArgs()); 494 495 // 4. Apply the merge reduction to combine all the partial values. 496 b.setInsertionPointAfter(*loops.begin()); 497 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); 498 b.replaceOp(op, mergeOp->getResults()); 499 500 SCFReductionTilingResult results; 501 results.initialOp = *identityTensor; 502 results.loops = std::move(loops); 503 results.parallelTiledOp = parallelOp; 504 results.mergeOp = mergeOp; 505 return results; 506 } 507 508 //===----------------------------------------------------------------------===// 509 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 510 //===----------------------------------------------------------------------===// 511 512 /// Return the untiled producer whose slice is used in a tiled consumer. The 513 /// method traverses the tile loop nest (`loops`) if needed, and returns the 514 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 515 /// indicates that this is a destination operand of the consumer. If there was 516 /// no loop traversal needed, the second value of the returned tuple is empty. 517 static std::tuple<OpResult, std::optional<OpOperand *>> 518 getUntiledProducerFromSliceSource(OpOperand *source, 519 ArrayRef<scf::ForOp> loops) { 520 std::optional<OpOperand *> destinationIterArg; 521 auto loopIt = loops.rbegin(); 522 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 523 scf::ForOp loop = *loopIt; 524 if (iterArg.getOwner()->getParentOp() != loop) 525 break; 526 source = &loop.getOpOperandForRegionIterArg(iterArg); 527 loopIt++; 528 } 529 if (loopIt == loops.rend()) 530 destinationIterArg = source; 531 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 532 } 533 534 /// Implementation of fusing producer of a single slice by computing the 535 /// slice of the producer in-place. 536 std::optional<scf::SCFFuseProducerOfSliceResult> 537 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, 538 tensor::ExtractSliceOp candidateSliceOp, 539 MutableArrayRef<scf::ForOp> loops) { 540 // 1. Get the producer of the source (potentially walking through 541 // `iter_args` of nested `scf.for`) 542 auto [fusableProducer, destinationInitArg] = 543 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 544 loops); 545 if (!fusableProducer) 546 return std::nullopt; 547 548 // 2. Generate the tiled implementation of the producer of the source 549 OpBuilder::InsertionGuard g(rewriter); 550 rewriter.setInsertionPoint(candidateSliceOp); 551 FailureOr<TilingResult> tileAndFuseResult = 552 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 553 fusableProducer); 554 if (failed(tileAndFuseResult)) 555 return std::nullopt; 556 rewriter.replaceAllUsesWith(candidateSliceOp, 557 tileAndFuseResult->tiledValues[0]); 558 559 // 3. If the slice is for a destination operand, for example, 560 // 561 // ```mlir 562 // %0 = linalg.init 563 // %1 = linalg.fill .. outs(%0 : ) 564 // %2 = scf.for .. iter_args(%arg0 = %1) { 565 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 566 // %4 = tensor.extract_slice %arg1 [..] 567 // .. = linalg.matmul .. outs(%4 : ) 568 // } 569 // } 570 // ``` 571 // 572 // the IR is currently 573 // 574 // ``` 575 // %0 = linalg.init 576 // %1 = linalg.fill 577 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 578 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 579 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 580 // %5 = linalg.fill .. outs(%4 : ) 581 // .. = linalg.matmul .. outs(%5 : ) 582 // } 583 // } 584 // ``` 585 // 586 // The untiled `linalg.fill` is still used as the `init_value` since it 587 // was originally a destination operand of the untiled `linalg.matmul`. 588 // When fusing an operand that is a destination operand. 589 // - Update the iter_arg of the outer most loop to use the destination 590 // of the untiled producer. 591 // - Update the destination of the slice of the tiled producer generated 592 // to use the same basic block argument as the slice that was used to 593 // generate inplace the tiled implementation of the producer. 594 // With this the IR will be. 595 // 596 // ``` 597 // %0 = linalg.init 598 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 599 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 600 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 601 // %4 = linalg.fill .. outs(%3 : ) 602 // .. = linalg.matmul .. outs(%4 : ) 603 // } 604 // } 605 // ``` 606 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 607 // Update to use that when it does become available. 608 scf::ForOp outerMostLoop = loops.front(); 609 if (destinationInitArg && 610 (*destinationInitArg)->getOwner() == outerMostLoop) { 611 unsigned iterArgNumber = 612 outerMostLoop.getResultForOpOperand(**destinationInitArg) 613 .getResultNumber(); 614 int64_t resultNumber = fusableProducer.getResultNumber(); 615 if (auto dstOp = 616 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { 617 (*destinationInitArg) 618 ->set(dstOp.getTiedOpOperand(fusableProducer)->get()); 619 } 620 for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { 621 auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 622 if (!dstOp) 623 continue; 624 scf::ForOp innerMostLoop = loops.back(); 625 updateDestinationOperandsForTiledOp( 626 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 627 innerMostLoop.getRegionIterArgs()[iterArgNumber]); 628 } 629 } 630 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 631 tileAndFuseResult->tiledValues[0], 632 tileAndFuseResult->tiledOps}; 633 } 634 635 /// Reconstruct the fused producer from within the tiled-and-fused code. 636 void mlir::scf::yieldReplacementForFusedProducer( 637 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 638 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 639 MutableArrayRef<scf::ForOp> loops) { 640 auto [fusableProducer, fusedProducerValue, tileAndFusedOps] = 641 fusedProducerInfo; 642 SmallVector<Value> initValues; 643 FailureOr<Value> initValue = tensor::getOrCreateDestination( 644 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 645 if (succeeded(initValue)) { 646 SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); 647 SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); 648 SmallVector<Value> yieldedVals = 649 yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, 650 resultOffsets, resultSizes, loops); 651 } 652 for (auto tileAndFusedOp : tileAndFusedOps) { 653 auto dstStyleProducer = 654 dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); 655 if (!dstStyleProducer) 656 continue; 657 Value dstValue = 658 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) 659 ->get(); 660 updateDestinationOperandsForTiledOp( 661 rewriter, dstValue, loops.back().getRegionIterArgs().back()); 662 } 663 } 664 665 /// Implementation of tile consumer and fuse producer greedily. 666 FailureOr<scf::SCFTileAndFuseResult> 667 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 668 RewriterBase &rewriter, TilingInterface consumer, 669 const scf::SCFTileAndFuseOptions &options) { 670 // This transformation is only valid for ops that return values (i.e. not 671 // valid to use with operations that have memref operands). 672 if (!consumer->getNumResults()) { 673 return rewriter.notifyMatchFailure( 674 consumer, "invalid pattern for op with no results"); 675 } 676 677 // 1. First tile the consumer. 678 SmallVector<scf::ForOp> forLoops; 679 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 680 DenseMap<Value, Value> replacements; 681 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 682 { 683 FailureOr<scf::SCFTilingResult> tilingResult = 684 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 685 if (failed(tilingResult)) 686 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 687 for (auto *tiledOp : tilingResult->tiledOps) 688 tiledAndFusedOps.insert(tiledOp); 689 forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops); 690 for (auto [index, origValue, replacement] : 691 llvm::enumerate(consumer->getResults(), tilingResult->replacements)) { 692 replacements[origValue] = replacement; 693 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( 694 index)] = index; 695 } 696 } 697 698 // If there are no loops generated, fusion is immaterial. 699 if (forLoops.empty()) { 700 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, 701 getAsOperations(forLoops), replacements}; 702 } 703 704 // 2. Typically, the operands of the tiled operation are slices of the 705 // operands of the untiled operation. These are expressed in IR using 706 // `tensor.extract_slice` operations with source being the operands of the 707 // untiled operation. Create a worklist of these `tensor.extract_slice` 708 // operations. If the producers of the source of the `tensor.extract_slice` 709 // can be tiled such that the tiled value is generated in-place, that 710 // effectively tiles + fuses the operations. 711 auto addCandidateSlices = [](Operation *fusedOp, 712 std::deque<tensor::ExtractSliceOp> &candidates) { 713 for (Value operand : fusedOp->getOperands()) 714 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 715 candidates.push_back(sliceOp); 716 }; 717 718 std::deque<tensor::ExtractSliceOp> candidates; 719 addCandidateSlices(tiledAndFusedOps.back(), candidates); 720 OpBuilder::InsertionGuard g(rewriter); 721 while (!candidates.empty()) { 722 // Traverse the slices in BFS fashion. 723 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 724 candidates.pop_front(); 725 726 // The operands of the fused producer might themselved be slices of 727 // values produced by operations that implement the `TilingInterface`. 728 // Add these operations to the worklist. 729 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 730 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); 731 if (!fusedResult) 732 continue; 733 734 if (Operation *tiledAndFusedOp = 735 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 736 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 737 tiledAndFusedOps.insert(tiledAndFusedOp); 738 addCandidateSlices(tiledAndFusedOp, candidates); 739 } 740 } 741 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, 742 getAsOperations(forLoops), replacements}; 743 } 744 745 //===----------------------------------------------------------------------===// 746 // tileUsingSCFForAllOp implementation. 747 //===----------------------------------------------------------------------===// 748 749 FailureOr<scf::SCFTilingResult> 750 mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, 751 const scf::SCFTilingOptions &options) { 752 Location loc = op->getLoc(); 753 OpBuilder::InsertionGuard g(rewriter); 754 755 // 1. Get the range of loops that are represented by the operation. 756 SmallVector<Range> loopRanges = op.getIterationDomain(rewriter); 757 if (loopRanges.empty()) 758 return op->emitOpError("expected non-empty loop ranges"); 759 auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; 760 if (llvm::any_of(loopRanges, hasStrideOne)) 761 return op->emitOpError("only stride-1 supported atm"); 762 763 // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed. 764 // To make it easier, pad the tile sizes to loopRanges.size with value 0. 765 SmallVector<OpFoldResult> tileSizeVector = 766 options.tileSizeComputationFunction(rewriter, op); 767 tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0)); 768 769 // 3. Build the offsets, sizes and steps for the tile and distributed loops. 770 SmallVector<OpFoldResult> lbs, ubs, steps; 771 for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) { 772 if (isConstantIntValue(tileSize, 0)) 773 continue; 774 lbs.push_back(loopRange.offset); 775 ubs.push_back(loopRange.size); 776 steps.push_back(tileSize); 777 } 778 779 // 4. Gather destination tensors. 780 SmallVector<Value> dest; 781 if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest))) 782 return op->emitOpError("failed to get destination tensors"); 783 784 // 5. Build the device mapping attribute. 785 std::optional<ArrayAttr> mappingAttr; 786 if (!options.mappingVector.empty()) { 787 mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector)); 788 } 789 790 // 6. Create the ForallOp. We don't use the lambda body-builder 791 // version because we require the use of RewriterBase in the body, so we 792 // manually move the insertion point to the body below. 793 auto forallOp = 794 rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr); 795 796 // 7. Get the tile offset and sizes. 797 rewriter.setInsertionPoint(forallOp.getTerminator()); 798 SmallVector<OpFoldResult> tiledOffsets, tiledSizes; 799 ValueRange ivs = forallOp.getInductionVars(); 800 { 801 int materializedLoopNum = 0; 802 for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) { 803 if (isConstantIntValue(tileSize, 0)) { 804 tiledOffsets.push_back(loopRange.offset); 805 tiledSizes.push_back(loopRange.size); 806 continue; 807 } 808 Value iv = ivs[materializedLoopNum++]; 809 tiledOffsets.push_back(iv); 810 tiledSizes.push_back( 811 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 812 } 813 } 814 815 // 8. Tile the operation. Clone the operation to allow fix up of destination 816 // operands. 817 ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments(); 818 Operation *clonedOp = 819 cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs); 820 FailureOr<TilingResult> tilingResult = 821 cast<TilingInterface>(clonedOp).getTiledImplementation( 822 rewriter, tiledOffsets, tiledSizes); 823 if (failed(tilingResult)) 824 return clonedOp->emitError("failed to tile op: "); 825 rewriter.eraseOp(clonedOp); 826 827 // 9. Parallel insert back into the result tensor. 828 for (auto [index, tiledValue, destBBArg] : 829 llvm::enumerate(tilingResult->tiledValues, destBbArgs)) { 830 // 9.a. Partial subset information is inserted just before the terminator. 831 rewriter.setInsertionPoint(forallOp.getTerminator()); 832 833 SmallVector<OpFoldResult> resultOffsets, resultSizes; 834 if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets, 835 tiledSizes, resultOffsets, 836 resultSizes))) { 837 return op->emitOpError("output offsets couldn't be calculated"); 838 } 839 840 SmallVector<OpFoldResult> strides(resultSizes.size(), 841 rewriter.getIndexAttr(1)); 842 // 9.b. Parallel insertions are inserted at the end of the combining 843 // terminator. 844 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 845 rewriter.create<tensor::ParallelInsertSliceOp>( 846 loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides); 847 } 848 849 // 10. Return the tiling result. 850 return scf::SCFTilingResult{ 851 tilingResult->tiledOps, 852 {forallOp.getOperation()}, 853 llvm::map_to_vector(forallOp.getResults(), 854 [](auto val) -> Value { return val; })}; 855 } 856 857 //===----------------------------------------------------------------------===// 858 // lowerToLoopsUsingSCFForOp implementation. 859 //===----------------------------------------------------------------------===// 860 861 FailureOr<SmallVector<scf::ForOp>> 862 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 863 TilingInterface op) { 864 // TODO: Handle cases where the op has results if needed. 865 if (op->getNumResults() > 0) { 866 return rewriter.notifyMatchFailure( 867 op, "unable to lower to loops operations with return values"); 868 } 869 870 SmallVector<Range> domain = op.getIterationDomain(rewriter); 871 SmallVector<Value> ivs; 872 SmallVector<scf::ForOp> loops; 873 Location loc = op.getLoc(); 874 for (auto loopRange : domain) { 875 Value offsetVal = 876 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 877 Value sizeVal = 878 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 879 Value strideVal = 880 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 881 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 882 strideVal, ValueRange{}); 883 loops.push_back(loop); 884 ivs.push_back(loop.getInductionVar()); 885 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 886 } 887 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 888 return failure(); 889 } 890 return loops; 891 } 892