1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Analysis/TopologicalSortUtils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Arith/Utils/Utils.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/SCF/Utils/Utils.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Dialect/Utils/IndexingUtils.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 28 #include "mlir/Interfaces/TilingInterface.h" 29 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 #include "llvm/ADT/ScopeExit.h" 32 #include "llvm/ADT/TypeSwitch.h" 33 #include "llvm/Support/Debug.h" 34 #include <optional> 35 36 #define DEBUG_TYPE "tile-using-interface" 37 38 using namespace mlir; 39 40 scf::SCFTilingOptions & 41 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 42 assert(!tileSizeComputationFunction && "tile sizes already set"); 43 auto tileSizes = llvm::to_vector(ts); 44 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 45 return tileSizes; 46 }; 47 return *this; 48 } 49 50 scf::SCFTilingOptions & 51 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { 52 
assert(!numThreadsComputationFunction && "num tiles already set"); 53 auto numThreads = llvm::to_vector(nt); 54 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { 55 return numThreads; 56 }; 57 return *this; 58 } 59 60 /// Helper method to adjust the interchange vector to match the iteration 61 /// domain. 62 static SmallVector<int64_t> 63 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 64 size_t iterationDomainSize) { 65 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 66 if (filledVector.size() < iterationDomainSize) { 67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 68 filledVector.append(range.begin(), range.end()); 69 } 70 if (filledVector.size() > iterationDomainSize) 71 filledVector.resize(iterationDomainSize); 72 return filledVector; 73 } 74 75 //===----------------------------------------------------------------------===// 76 // tileUsingSCF implementation. 77 //===----------------------------------------------------------------------===// 78 79 /// Verify the tile size options are set in a consistent manner. 80 static LogicalResult 81 verifyTileSizeOptions(RewriterBase &rewriter, Location loc, 82 const scf::SCFTilingOptions &options) { 83 // Specifying number of threads is only supported on `scf.forall` op. 84 if (options.numThreadsComputationFunction && 85 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { 86 return rewriter.notifyMatchFailure( 87 loc, "number of threads can only by specified when loop type is " 88 "set to use `scf.forall`"); 89 } 90 91 // If specified, check that the interchange vector is a permutation. 
92 if (!options.interchangeVector.empty()) { 93 if (!isPermutationVector(options.interchangeVector)) { 94 return rewriter.notifyMatchFailure( 95 loc, "invalid interchange vector, not a permutation of the entire " 96 "iteration space"); 97 } 98 } 99 return success(); 100 } 101 102 /// Method to instantiate the tile sizes and/or number of threads specified 103 /// by the user. 104 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 105 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, 106 ArrayRef<Range> iterationDomain, 107 const scf::SCFTilingOptions &options) { 108 OpFoldResult zero = rewriter.getIndexAttr(0); 109 SmallVector<OpFoldResult> tileSizes, numThreads; 110 size_t numLoops = iterationDomain.size(); 111 112 // Check whether the number of tiles to use is specified. 113 if (options.numThreadsComputationFunction) { 114 numThreads = options.numThreadsComputationFunction(rewriter, op); 115 numThreads.resize(numLoops, zero); 116 117 // If the number of tiles is also specified, use that. 118 if (options.tileSizeComputationFunction) { 119 tileSizes = options.tileSizeComputationFunction(rewriter, op); 120 tileSizes.resize(numLoops, zero); 121 return {tileSizes, numThreads}; 122 } 123 124 // Compute the tile sizes from the iteration domain and number 125 // of tiles as follows 126 // - niters = ceilDiv(ub - lb, step) 127 // - tileSize = ceilDiv(niters, numThreads) 128 AffineExpr s0, s1, s2; 129 bindSymbols(rewriter.getContext(), s0, s1, s2); 130 // TODO: The step here is assumed to be 1. 
131 AffineExpr numItersExpr = (s1 - s0); 132 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); 133 tileSizes.resize(numLoops, zero); 134 for (auto [index, range, nt] : 135 llvm::enumerate(iterationDomain, numThreads)) { 136 if (isConstantIntValue(nt, 0)) 137 continue; 138 139 tileSizes[index] = affine::makeComposedFoldedAffineApply( 140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); 141 } 142 tileSizes.resize(numLoops, zero); 143 return {tileSizes, numThreads}; 144 } 145 146 // Enforce the convention that "tiling by zero" 147 // skips tiling a particular dimension. This convention is significantly 148 // simpler to handle instead of adjusting affine maps to account for missing 149 // dimensions. 150 assert(options.tileSizeComputationFunction && 151 "expected tile sizes to be specified"); 152 tileSizes = options.tileSizeComputationFunction(rewriter, op); 153 tileSizes.resize(numLoops, zero); 154 155 return {tileSizes, numThreads}; 156 } 157 158 /// Checks if any of the tiled loops are not parallel. 159 static void checkSafeToTileToForall(TilingInterface op, 160 ArrayRef<OpFoldResult> tileSizes, 161 ArrayRef<OpFoldResult> numThreads) { 162 auto iterators = op.getLoopIteratorTypes(); 163 assert(iterators.size() == tileSizes.size() && 164 "expected as many tile size values as number of loops"); 165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) && 166 "when specified, expected number of threads to use for each loop"); 167 168 for (auto [index, iterator, tileSize] : 169 llvm::enumerate(iterators, tileSizes)) { 170 // If num threads is specified, check that it is greater than one only for 171 // parallel dimensions. 
172 if (!numThreads.empty()) { 173 if (std::optional<int64_t> constNumThreads = 174 getConstantIntValue(numThreads[index])) { 175 if (constNumThreads.value() > 1 && 176 iterator != utils::IteratorType::parallel) { 177 op.emitWarning() << "tiling is not thread safe at axis #" << index; 178 } 179 } 180 continue; 181 } 182 183 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { 184 if (constTileSize.value() > 0 && 185 iterator != utils::IteratorType::parallel) { 186 op.emitWarning() << "tiling is not thread safe at axis #" << index; 187 } 188 } 189 } 190 } 191 192 /// Check if `stride` evenly divides the trip count `size - offset`. 193 static bool tileDividesIterationDomain(Range loopRange) { 194 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 195 if (!offsetAsInt) 196 return false; 197 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 198 if (!sizeAsInt) 199 return false; 200 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 201 if (!strideAsInt) 202 return false; 203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 204 } 205 206 /// Returns the bounded tile size given the current `offset`, `loopRange` and 207 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 208 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 209 Range loopRange, OpFoldResult offset, 210 OpFoldResult tileSize) { 211 std::optional<int64_t> ts = getConstantIntValue(tileSize); 212 if (ts && ts.value() == 1) 213 return tileSize; 214 215 if (tileDividesIterationDomain( 216 Range{loopRange.offset, loopRange.size, tileSize})) 217 return tileSize; 218 219 // The tile size to use (to avoid out of bounds access) is minimum of 220 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 221 // loop. 
222 AffineExpr s0, s1, d0; 223 bindDims(b.getContext(), d0); 224 bindSymbols(b.getContext(), s0, s1); 225 AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); 226 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 227 return affine::makeComposedFoldedAffineMin( 228 b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize}); 229 } 230 231 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less 232 /// than `iterationSize`. 233 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, 234 OpFoldResult numThreads, 235 OpFoldResult iterationSize) { 236 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize); 237 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads); 238 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize); 239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst) 240 return false; 241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; 242 } 243 244 /// Compute the `OpFoldResult`s that represents the multi-dimensional 245 /// `offset`s and `size`s of the tile of the iteration space that the 246 /// innermost loop body of the generated tiled loops corresponds to. 
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                      ArrayRef<Range> iterationDomain,
                      ArrayRef<OpFoldResult> tileSizes,
                      ArrayRef<OpFoldResult> numThreads) {
  SmallVector<OpFoldResult> offsets, sizes;
  // `ivs` only contains induction variables for the loops that were actually
  // materialized (dims with non-zero tile size / num-threads), so it is
  // consumed with a separate running index.
  int materializedLoopNum = 0;

  if (!numThreads.empty()) {
    // Distribution case: the iv is a thread/tile id, so the tile offset is
    // `lb + id * tileSize` and the size may need clamping at the boundary.
    AffineExpr d0, d1, s0, s1;
    AffineExpr offsetExpr, residualTileSizeExpr;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0, s1);
    offsetExpr = d0 + d1 * s0;
    residualTileSizeExpr = s1 - (d0 + d1 * s0);

    for (auto [nt, tileSize, loopRange] :
         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isConstantIntValue(nt, 0)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      // offset = lb + iv * tileSize.
      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, offsetExpr,
          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
      // residual = ub - (lb + numThreads * tileSize); non-zero means the
      // last tile may hang over the end of the iteration space.
      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
          rewriter, loc, residualTileSizeExpr,
          {loopRange.offset, nt, tileSize, loopRange.size});

      OpFoldResult size = tileSize;
      if (!isConstantIntValue(residualTileSize, 0)) {
        // Clamp the size of the boundary tile: min(tileSize, ub - offset).
        OpFoldResult sizeMinusOffsetPerThread =
            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                  {offset, loopRange.size});
        size = affine::makeComposedFoldedAffineMin(
            rewriter, loc,
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
            {sizeMinusOffsetPerThread, tileSize});
      }

      // Consider the case where the original loop was `[0, 100)`.
      // If number of threads are `7`, the tile size would be computed as
      // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
      // - `offset = 0 + 6 * 15 = 105`
      // - `tileSize = min(15, 100 - 105) = -5`
      // To avoid negative tile sizes, we need to do a further
      // `nonNegativeTileSize = affine.max(0, tileSize)`.
      // This `max` can be avoided if
      // `offset + tileSize * (numThreads - 1) < (ub - lb)`
      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
        AffineMap maxMap =
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
        size = affine::makeComposedFoldedAffineMax(
            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
      }

      offsets.push_back(offset);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  } else {
    // `scf.for`-style tiling: the iv already is the tile offset.
    for (auto [tileSize, loopRange] :
         llvm::zip_equal(tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isConstantIntValue(tileSize, 0)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = getAsOpFoldResult(iv);
      offsets.push_back(offset);
      // Clamp the tile size at the upper bound when needed.
      OpFoldResult size =
          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  }
}

/// Function to return the bounds of the loops to be generated.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
                  SmallVector<OpFoldResult>>
getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
              ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> lbs, ubs, steps;
  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
    // No loop if the tile size is 0.
    if (isConstantIntValue(tileSize, 0))
      continue;
    lbs.push_back(loopRange.offset);
    ubs.push_back(loopRange.size);
    steps.push_back(tileSize);
  }
  return {lbs, ubs, steps};
}

/// A function that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`.
/// - `ivs` induction variable for the loop.
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
/// - `tiledValues` the tiled values to return. Must be of same size as
///   `newbbArgs`, each element of this array is inserted into the corresponding
///   element in `newbbArgs`.
/// - `resultOffsets` is of the same size as `tiledValues` and represents
///   the offsets to use when inserting corresponding element from `tiledValues`
///   into the element from `newBbArgs`.
/// - `resultSizes` is of the same size as `tiledValues` and represents
///   the size of the corresponding element from `tiledValues` inserted into
///   the element from `newBbArgs`.
/// In case the method needs to return `failure()` the method is expected
/// to clean up any inserted operations.
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;

/// Clones the operation and updates the destination if the operation
/// implements the `DestinationStyleOpInterface`.
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange newDestArgs) {
  Operation *clonedOp = rewriter.clone(*op);
  if (newDestArgs.empty())
    return clonedOp;
  // Only destination-style ops have inits to rewire; other ops are returned
  // as a plain clone.
  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
    destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
  return clonedOp;
}

/// Generate the tile-loop nest using `scf.for` operation.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - `destinationTensors` are the init values to use for the outer most loop.
/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
/// most
///   loop.
/// - `loops` is an in-out parameter into which the generated loops are
///   populated.
static LogicalResult generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
    YieldTiledValuesFn yieldTiledValuesFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Zero-tile-size dims are dropped here, so only materialized loops get
  // bounds.
  SmallVector<OpFoldResult> lbs, ubs, steps;
  std::tie(lbs, ubs, steps) =
      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
  SmallVector<Value> lbVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
  SmallVector<Value> ubVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
  SmallVector<Value> stepVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, steps);

  // Build the nest outside-in; each iteration nests the insertion point into
  // the loop just created and threads the iter_args through as the next
  // level's destination tensors.
  SmallVector<Value> ivs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
    auto loop =
        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
                                       Value iv, ValueRange /*iterArgs*/) {});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPointToEnd(loop.getBody());
    destinationTensors = loop.getRegionIterArgs();
  }

  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
                                tiledResults, resultOffsets, resultSizes))) {
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");
  }
  if (loops.empty())
    return success();

  assert(tiledResults.size() == destinationTensors.size() &&
         "Number of results of body should be equal to number of iter args");

  // 6. Yield all the results of the tiled operation.
  SmallVector<Value> yieldedValues;
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
    yieldedValues.push_back(insertSlice);
  }
  rewriter.create<scf::YieldOp>(loc, yieldedValues);

  // Add the scf.yield operations for all the outer loops.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                       MutableArrayRef(loops).drop_front())) {
    rewriter.setInsertionPointToEnd(
        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
  }
  return success();
}

/// Generate the tile-loop nest using `scf.forall` operation.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 469 /// - `destinationTensors` are the init values to use for the outer most loop. 470 /// - `mappingVector` is the mapping attributes to use for loop construction. 471 /// Can be empty. 472 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 473 /// most 474 /// loop. 475 /// - `loops` is an in-out parameter into which the generated loops are 476 /// populated. 477 static LogicalResult generateLoopNestUsingForallOp( 478 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 479 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, 480 ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, 481 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 482 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 483 assert(loopRanges.size() == tileSizes.size() && 484 "expected as many tile sizes as loop ranges"); 485 OpBuilder::InsertionGuard guard(rewriter); 486 SmallVector<OpFoldResult> offsets(loopRanges.size()), 487 sizes(loopRanges.size()); 488 489 std::optional<ArrayAttr> mappingAttr; 490 if (!mappingVector.empty()) 491 mappingAttr = rewriter.getArrayAttr(mappingVector); 492 493 scf::ForallOp forallOp; 494 bool useNumThreads = !numThreads.empty(); 495 496 if (useNumThreads) { 497 // Prune the zero numthreads. 
498 SmallVector<OpFoldResult> nonZeroNumThreads; 499 for (auto nt : numThreads) { 500 if (isConstantIntValue(nt, 0)) 501 continue; 502 nonZeroNumThreads.push_back(nt); 503 } 504 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, 505 destinationTensors, mappingAttr); 506 } else { 507 SmallVector<OpFoldResult> lbs, ubs, steps; 508 std::tie(lbs, ubs, steps) = 509 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 510 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, 511 destinationTensors, mappingAttr); 512 } 513 loops.push_back(forallOp); 514 515 rewriter.setInsertionPoint(forallOp.getTerminator()); 516 destinationTensors = forallOp.getRegionOutArgs(); 517 518 SmallVector<Value> tiledResults; 519 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 520 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 521 destinationTensors, tiledResults, resultOffsets, 522 resultSizes))) 523 return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 524 525 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 526 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 527 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 528 resultSizes)) { 529 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 530 rewriter.getIndexAttr(1)); 531 532 rewriter.create<tensor::ParallelInsertSliceOp>( 533 loc, tiledValue, destinationTensor, resultOffset, resultSize, 534 resultStride); 535 } 536 return success(); 537 } 538 539 /// Generate the tile-loop nest using the loop construct specifed in `options`. 540 /// - `options`: Tiling options specified. 541 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 542 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 543 /// - `destinationTensors` are the init values to use for the outer most loop. 
544 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 545 /// most 546 /// loop. 547 /// - `loops` is an in-out parameter into which the generated loops are 548 /// populated. 549 static LogicalResult generateLoopNest( 550 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, 551 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, 552 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, 553 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 554 // If the tile sizes are all zero, no loops are generated. Just call the 555 // callback function to handle untiled case. 556 if (llvm::all_of(tileSizes, isZeroIndex)) { 557 SmallVector<Value> tiledResults; 558 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 559 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 560 tiledResults, resultOffsets, resultSizes); 561 } 562 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 563 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 564 destinationTensors, tiledBodyFn, loops); 565 } 566 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 567 return generateLoopNestUsingForallOp( 568 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, 569 destinationTensors, tiledBodyFn, loops); 570 } 571 return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 572 } 573 574 static FailureOr<SmallVector<Value>> 575 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, 576 ArrayRef<OpFoldResult> tileSizes, 577 const scf::SCFTilingOptions &options) { 578 SmallVector<Value> initTensors; 579 Location loc = op->getLoc(); 580 switch (options.reductionStrategy) { 581 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 582 if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors))) 583 return failure(); 584 return initTensors; 585 case 
scf::SCFTilingOptions::ReductionTilingStrategy:: 586 PartialReductionOuterReduction: { 587 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 588 if (!redOp) { 589 return rewriter.notifyMatchFailure( 590 op, "PartialReductionOuterReduction tiling strategy is only supported" 591 "for operations implementing PartialReductionOpInterface"); 592 } 593 // Get reduction dimensions. 594 // TODO: PartialReductionOpInterface should really query TilingInterface 595 // itself and find reduction dimensions. 596 SmallVector<int> reductionDims; 597 for (auto [idx, iteratorType] : 598 llvm::enumerate(op.getLoopIteratorTypes())) { 599 if (iteratorType == utils::IteratorType::reduction) 600 reductionDims.push_back(idx); 601 } 602 return redOp.generateInitialTensorForPartialReduction( 603 rewriter, loc, tileSizes, reductionDims); 604 } 605 default: 606 return rewriter.notifyMatchFailure(op, 607 "unhandled reduction tiling strategy"); 608 } 609 } 610 611 static FailureOr<TilingResult> 612 getTiledImplementation(RewriterBase &rewriter, TilingInterface op, 613 ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets, 614 ArrayRef<OpFoldResult> sizes, 615 const scf::SCFTilingOptions &options) { 616 switch (options.reductionStrategy) { 617 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 618 return op.getTiledImplementation(rewriter, offsets, sizes); 619 case scf::SCFTilingOptions::ReductionTilingStrategy:: 620 PartialReductionOuterReduction: { 621 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 622 if (!redOp) { 623 return rewriter.notifyMatchFailure( 624 op, "PartialReductionOuterReduction tiling strategy is only " 625 "supported for operations " 626 "implementing PartialReductionOpInterface"); 627 } 628 // Get reduction dimensions. 629 // TODO: PartialReductionOpInterface should really query TilingInterface 630 // itself and find reduction dimensions. 
631 SmallVector<int> reductionDims; 632 for (auto [idx, iteratorType] : 633 llvm::enumerate(op.getLoopIteratorTypes())) { 634 if (iteratorType == utils::IteratorType::reduction) 635 reductionDims.push_back(idx); 636 } 637 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg, 638 offsets, sizes, reductionDims); 639 } 640 default: 641 return rewriter.notifyMatchFailure(op, 642 "unhandled reduction tiling strategy"); 643 } 644 } 645 646 static LogicalResult 647 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, 648 TilingInterface op, ArrayRef<OpFoldResult> offsets, 649 ArrayRef<OpFoldResult> sizes, 650 SmallVector<OpFoldResult> &resultOffset, 651 SmallVector<OpFoldResult> &resultSize, 652 const scf::SCFTilingOptions &options) { 653 654 switch (options.reductionStrategy) { 655 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 656 return op.getResultTilePosition(rewriter, index, offsets, sizes, 657 resultOffset, resultSize); 658 case scf::SCFTilingOptions::ReductionTilingStrategy:: 659 PartialReductionOuterReduction: { 660 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 661 if (!redOp) { 662 return rewriter.notifyMatchFailure( 663 op, "PartialReductionOuterReduction tiling strategy is only supported" 664 "for operations implementing PartialReductionOpInterface"); 665 } 666 // Get reduction dimensions. 667 // TODO: PartialReductionOpInterface should really query TilingInterface 668 // itself and find reduction dimensions. 
669 SmallVector<int> reductionDims; 670 for (auto [idx, iteratorType] : 671 llvm::enumerate(op.getLoopIteratorTypes())) { 672 if (iteratorType == utils::IteratorType::reduction) 673 reductionDims.push_back(idx); 674 } 675 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, 676 resultOffset, resultSize, 677 reductionDims); 678 } 679 default: 680 return rewriter.notifyMatchFailure(op, 681 "unhandled reduction tiling strategy"); 682 } 683 } 684 685 static FailureOr<MergeResult> 686 mergeTilingResults(RewriterBase &rewriter, TilingInterface op, 687 ValueRange partialResults, 688 const scf::SCFTilingOptions &options) { 689 switch (options.reductionStrategy) { 690 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 691 // No need to merge results for reduction tiling strategy. 692 return MergeResult{{}, partialResults}; 693 case scf::SCFTilingOptions::ReductionTilingStrategy:: 694 PartialReductionOuterReduction: { 695 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 696 if (!redOp) { 697 return rewriter.notifyMatchFailure( 698 op, "PartialReductionOuterReduction tiling strategy is only " 699 "supported for operations " 700 "implementing PartialReductionOpInterface"); 701 } 702 // Get reduction dimensions. 703 // TODO: PartialReductionOpInterface should really query TilingInterface 704 // itself and find reduction dimensions. 
705 SmallVector<int> reductionDims; 706 for (auto [idx, iteratorType] : 707 llvm::enumerate(op.getLoopIteratorTypes())) { 708 if (iteratorType == utils::IteratorType::reduction) 709 reductionDims.push_back(idx); 710 } 711 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, 712 reductionDims); 713 } 714 default: 715 return rewriter.notifyMatchFailure(op, 716 "unhandled reduction tiling strategy"); 717 } 718 } 719 720 /// Append the specified additional `newInitOperands` operands to the 721 /// loops existing `init` operands (or similar), and replace `loopOp` with 722 /// the new loop that has the additional init operands. The loop body of 723 /// this loop is moved over to the new loop. `yieldTiledValuesFn` 724 /// is called to get the new tiled values returned, and the offset 725 /// and sizes at which the tiled value is inserted into the 726 /// new region iter_args that correspond to the newly added init operands. 727 template <typename LoopType> 728 FailureOr<LoopLikeOpInterface> 729 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 730 ValueRange newInitOperands, 731 YieldTiledValuesFn yieldTiledValuesFn) { 732 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 733 } 734 735 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Recreate the loop with the original inits plus the new ones; the body
  // region is transplanted below rather than regenerated.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // The trailing iter_args are the ones corresponding to `newInitOperands`.
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // On failure, undo the loop creation before reporting.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }

  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // Only the original loop's results are replaced; the extra results of the
  // new loop are for the caller to consume.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}

/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Recreate the forall with the original outputs plus the new inits.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // The trailing region iter_args correspond to `newInitOperands`.
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // On failure, undo the loop creation before reporting.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }

  // Only the original loop's results are replaced; the extra results of the
  // new loop are for the caller to consume.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}

/// Implementation of `yieldTiledValuesAndReplaceLoop` for
/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
/// supported loop type.
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
    LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
    ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
  return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
             loopLikeOp.getOperation())
      .Case<scf::ForOp, scf::ForallOp>(
          [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
            return yieldTiledValuesAndReplaceLoop(
                loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
          })
      .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
        return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
      });
}

/// Method to add new init values to a loop nest. Updates `loops` in-place
/// with new loops that use the `newInitValues`. The outer-loops are updated
/// to yield the new result values of the inner loop. For the innermost loop,
/// the call back `getNewYields` is invoked to get the additional values to
/// yield from the innermost loop.
865 static LogicalResult addInitOperandsToLoopNest( 866 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 867 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 868 SmallVector<scf::ForOp> newLoops; 869 if (loops.empty()) 870 return success(); 871 OpBuilder::InsertionGuard g(rewriter); 872 rewriter.setInsertionPoint(loops.front()); 873 874 SmallVector<Value> ivs; 875 for (auto &loop : loops.drop_back()) { 876 rewriter.setInsertionPoint(loop); 877 878 // if loops.size() > 1 we assume that scf.for is used for the loops. 879 auto forLoop = cast<scf::ForOp>(loop.getOperation()); 880 881 // Create a new loop with the new init values for this loop. 882 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 883 newInits.append(newInitValues.begin(), newInitValues.end()); 884 auto newLoop = rewriter.create<scf::ForOp>( 885 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 886 forLoop.getStep(), newInits, 887 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 888 889 // Merge the body of the new loop with the body of the old loops. 890 SmallVector<Value> sourceBlockArgs; 891 sourceBlockArgs.push_back(newLoop.getInductionVar()); 892 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 893 sourceBlockArgs.append( 894 newRegionIterArgs.begin(), 895 std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 896 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 897 rewriter.replaceOp( 898 forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 899 loop = newLoop; 900 ivs.push_back(newLoop.getInductionVar()); 901 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 902 } 903 904 // Update the loop body of the innermost loop to get new yield values. 
905 LoopLikeOpInterface innerMostLoop = loops.back(); 906 FailureOr<LoopLikeOpInterface> newInnerMostLoop = 907 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 908 getNewTiledYieldsFn); 909 910 if (failed(newInnerMostLoop)) 911 return innerMostLoop.emitOpError("failed to return additional yields"); 912 loops.back() = newInnerMostLoop.value(); 913 914 // Make all other loops except the innermost loops yield the values returned 915 // by the inner loop. 916 for (auto [outerLoop, innerLoop] : 917 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 918 // Again assume that all the outer loops are scf.for operations. 919 auto outerForLoop = cast<scf::ForOp>(outerLoop); 920 auto outerLoopYield = 921 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 922 SmallVector<Value> newYields = 923 llvm::to_vector(outerLoopYield.getOperands()); 924 ValueRange additionalYields = 925 innerLoop->getResults().take_back(newInitValues.size()); 926 newYields.append(additionalYields.begin(), additionalYields.end()); 927 rewriter.setInsertionPoint(outerLoopYield); 928 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 929 } 930 return success(); 931 } 932 933 /// Implementation of tiling transformation of `op` that implements the 934 /// `TilingInterface` using `scf.for` to iterate over the tiles. 935 FailureOr<scf::SCFTilingResult> 936 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 937 const scf::SCFTilingOptions &options) { 938 if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { 939 return failure(); 940 } 941 942 OpBuilder::InsertionGuard guard(rewriter); 943 rewriter.setInsertionPointAfter(op); 944 945 // 1. Get the range of the loops that are represented by the operation. 946 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 947 948 // 2. 
Materialize the tile sizes and/or number of threads; 949 SmallVector<OpFoldResult> tileSizes, numThreads; 950 std::tie(tileSizes, numThreads) = 951 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); 952 953 // Check if it is safe to tile. This is hold over from previous iterations 954 // of tile to for-all. Consider dropping it. 955 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 956 checkSafeToTileToForall(op, tileSizes, numThreads); 957 } 958 959 // 3. If there is an interchange specified, permute the iteration domain and 960 // the tile sizes. 961 SmallVector<int64_t> interchangeVector; 962 if (!options.interchangeVector.empty()) { 963 interchangeVector = fillInterchangeVector(options.interchangeVector, 964 iterationDomain.size()); 965 assert(isPermutationVector(interchangeVector) && 966 "expected interchange vector to be a permutation"); 967 968 applyPermutationToVector(iterationDomain, interchangeVector); 969 applyPermutationToVector(tileSizes, interchangeVector); 970 if (!numThreads.empty()) 971 applyPermutationToVector(numThreads, interchangeVector); 972 } 973 974 FailureOr<TilingResult> tilingResult; 975 // 4. Define the lambda function used later to generate the body of the 976 // innermost tiled loop. 977 YieldTiledValuesFn innerYieldTiledValuesFn = 978 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 979 ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 980 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 981 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 982 -> LogicalResult { 983 // 4a. Compute the `offsets` and `sizes` to use for tiling. 984 SmallVector<OpFoldResult> offsets, sizes; 985 std::tie(offsets, sizes) = getTileOffsetAndSizes( 986 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); 987 988 // 4b. If interchange was provided, apply inverse of the interchange 989 // to get back the offsets/sizes in the order to be specified. 
990 if (!interchangeVector.empty()) { 991 auto inversePermutation = invertPermutationVector(interchangeVector); 992 applyPermutationToVector(offsets, inversePermutation); 993 applyPermutationToVector(sizes, inversePermutation); 994 } 995 996 // 5. Generate the tiled implementation within the inner most loop. 997 998 // 5a. Clone the operation within the loop body. 999 auto clonedOp = cast<TilingInterface>( 1000 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 1001 1002 // 5b. Early return cloned op if tiling is not happening. We can not 1003 // return the original op because it could lead to `rewriter.replaceOp(op, 1004 // op->getResults())` and users would get crash. 1005 if (llvm::all_of(tileSizes, isZeroIndex)) { 1006 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 1007 tilingResult = 1008 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), 1009 /*generatedSlices=*/{}}; 1010 return success(); 1011 } 1012 1013 // 5c. Tile the cloned operation. 1014 tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, 1015 offsets, sizes, options); 1016 if (failed(tilingResult)) { 1017 rewriter.eraseOp(clonedOp); 1018 return op.emitOpError("faild to tile operation"); 1019 } 1020 1021 // 5d. Delete the cloned operation. 1022 rewriter.eraseOp(clonedOp); 1023 1024 // 5e. Compute the offsets at which the result values are to be inserted 1025 // back into its destinations. 
1026 for (auto [index, tiledValue] : 1027 llvm::enumerate(tilingResult->tiledValues)) { 1028 tiledResults.push_back(tiledValue); 1029 SmallVector<OpFoldResult> resultOffset, resultSize; 1030 if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, 1031 sizes, resultOffset, resultSize, 1032 options))) { 1033 for (auto op : tilingResult->tiledOps) { 1034 rewriter.eraseOp(op); 1035 } 1036 return rewriter.notifyMatchFailure( 1037 op, "failed to get slice of result produced"); 1038 } 1039 resultOffsets.emplace_back(std::move(resultOffset)); 1040 resultSizes.emplace_back(std::move(resultSize)); 1041 } 1042 1043 return success(); 1044 }; 1045 1046 // 6. Find the destination tensors to use for the operation. 1047 FailureOr<SmallVector<Value>> maybeInits = 1048 createInitialTensorsForTiling(rewriter, op, tileSizes, options); 1049 if (failed(maybeInits)) { 1050 return rewriter.notifyMatchFailure( 1051 op, "unable to create initial tensors for tiling"); 1052 } 1053 SmallVector<Value> &initTensors = maybeInits.value(); 1054 1055 // 7. Generate the tiled loops nest using the callback defined above. 1056 SmallVector<LoopLikeOpInterface> loops; 1057 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 1058 tileSizes, numThreads, initTensors, 1059 innerYieldTiledValuesFn, loops))) 1060 return op.emitOpError("failed to generate tiling loops"); 1061 assert(succeeded(tilingResult) && 1062 "expected tiling result to be computed after loop generation"); 1063 1064 SmallVector<Value> partialResults; 1065 if (loops.empty()) { 1066 // If loops are empty, the tiled op is used as the replacement for the 1067 // untiled op. 
1068 partialResults = tilingResult->tiledValues; 1069 } else { 1070 partialResults = llvm::map_to_vector(loops.front()->getResults(), 1071 [](OpResult r) -> Value { return r; }); 1072 } 1073 1074 FailureOr<MergeResult> mergeResult = 1075 mergeTilingResults(rewriter, op, partialResults, options); 1076 if (failed(mergeResult)) { 1077 return rewriter.notifyMatchFailure( 1078 op, "Failed to merge partial results from tiling"); 1079 } 1080 1081 return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, 1082 mergeResult.value(), 1083 tilingResult->generatedSlices}; 1084 } 1085 1086 FailureOr<scf::SCFTilingResult> 1087 mlir::scf::tileReductionUsingScf(RewriterBase &b, 1088 PartialReductionOpInterface op, 1089 ArrayRef<OpFoldResult> tileSizes) { 1090 SCFTilingOptions options; 1091 options.setLoopType(SCFTilingOptions::LoopType::ForOp); 1092 options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy:: 1093 PartialReductionOuterReduction); 1094 options.setTileSizes(tileSizes); 1095 1096 TilingInterface tilingInterfaceOp = 1097 dyn_cast<TilingInterface>(op.getOperation()); 1098 if (!tilingInterfaceOp) { 1099 return b.notifyMatchFailure( 1100 op, 1101 "Operation implementing PartialReductionOpInterface should implement " 1102 "TilingInterface"); 1103 } 1104 1105 return tileUsingSCF(b, tilingInterfaceOp, options); 1106 } 1107 1108 //===----------------------------------------------------------------------===// 1109 // tileConsumerAndFuseProducersUsingSCF implementation. 1110 //===----------------------------------------------------------------------===// 1111 1112 /// Return the untiled producer whose slice is used in a tiled consumer. The 1113 /// method traverses the tile loop nest (`loops`) if needed, and returns the 1114 /// `iter_args` of the outer most that is encountered. Traversing the 1115 /// iter_args indicates that this is a destination operand of the consumer. 
If 1116 /// there was no loop traversal needed, the second value of the returned tuple 1117 /// is empty. 1118 static std::tuple<OpResult, std::optional<OpOperand *>> 1119 getUntiledProducerFromSliceSource(OpOperand *source, 1120 ArrayRef<LoopLikeOpInterface> loops) { 1121 std::optional<OpOperand *> destinationIterArg; 1122 auto loopIt = loops.rbegin(); 1123 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 1124 auto loop = *loopIt; 1125 if (iterArg.getOwner()->getParentOp() != loop) 1126 break; 1127 source = loop.getTiedLoopInit(iterArg); 1128 loopIt++; 1129 } 1130 if (loopIt == loops.rend()) 1131 destinationIterArg = source; 1132 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 1133 } 1134 1135 /// Implementation of fusing producer of a single slice by computing the 1136 /// slice of the producer in-place. 1137 std::optional<scf::SCFFuseProducerOfSliceResult> 1138 mlir::scf::tileAndFuseProducerOfSlice( 1139 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 1140 MutableArrayRef<LoopLikeOpInterface> loops) { 1141 // 1. Get the producer of the source (potentially walking through 1142 // `iter_args` of nested `scf.for`) 1143 auto [fusableProducer, destinationInitArg] = 1144 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1145 loops); 1146 if (!fusableProducer) 1147 return std::nullopt; 1148 unsigned resultNumber = fusableProducer.getResultNumber(); 1149 1150 OpBuilder::InsertionGuard g(rewriter); 1151 rewriter.setInsertionPoint(candidateSliceOp); 1152 1153 // 2. Clone the fused producer 1154 // 2a. Compute the destination operands to use for the cloned operation. 
1155 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 1156 Operation *fusableProducerOp = fusableProducer.getOwner(); 1157 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 1158 failed(tensor::getOrCreateDestinations( 1159 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 1160 origDestinationTensors))) 1161 return std::nullopt; 1162 1163 clonedOpDestinationTensors = origDestinationTensors; 1164 if (destinationInitArg && 1165 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 1166 // 2b. If the producer is also destination style, then to maintain the 1167 // destination passing style, update the destination of the producer to be 1168 // the source of the slice. 1169 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 1170 } 1171 // 2c. Clone the fused producer. 1172 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 1173 rewriter, fusableProducerOp, clonedOpDestinationTensors); 1174 // 2d. Update the source of the candidateSlice to be the cloned producer. 1175 // Easier to just clone the slice with different source since 1176 // replacements and DCE of cloned ops becomes easier 1177 SmallVector<Value> candidateSliceOpOperands = 1178 llvm::to_vector(candidateSliceOp->getOperands()); 1179 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 1180 tensor::ExtractSliceOp clonedCandidateSliceOp = 1181 mlir::clone(rewriter, candidateSliceOp, 1182 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 1183 1184 // 3. Generate the tiled implementation of the producer of the source 1185 FailureOr<TilingResult> tileAndFuseResult = 1186 tensor::replaceExtractSliceWithTiledProducer( 1187 rewriter, clonedCandidateSliceOp, 1188 clonedProducerOp->getResult(resultNumber)); 1189 if (failed(tileAndFuseResult)) 1190 return std::nullopt; 1191 // Note: Do not delete the candidateSliceOp, since its passed in from the 1192 // caller. 
1193 rewriter.replaceAllUsesWith(candidateSliceOp, 1194 tileAndFuseResult->tiledValues[0]); 1195 rewriter.eraseOp(clonedCandidateSliceOp); 1196 rewriter.eraseOp(clonedProducerOp); 1197 1198 // 3. If the slice is for a destination operand, for example, 1199 // 1200 // ```mlir 1201 // %0 = linalg.init 1202 // %1 = linalg.fill .. outs(%0 : ) 1203 // %2 = scf.for .. iter_args(%arg0 = %1) { 1204 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1205 // %4 = tensor.extract_slice %arg1 [..] 1206 // .. = linalg.matmul .. outs(%4 : ) 1207 // } 1208 // } 1209 // ``` 1210 // 1211 // the IR is currently 1212 // 1213 // ``` 1214 // %0 = linalg.init 1215 // %1 = linalg.fill 1216 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 1217 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1218 // %4 = tensor.extract_slice %arg1[..] 1219 // %5 = linalg.fill .. outs(%4 : ) 1220 // .. = linalg.matmul .. outs(%5 : ) 1221 // } 1222 // } 1223 // ``` 1224 // 1225 // The untiled `linalg.fill` is still used as the `init_value` since it 1226 // was originally a destination operand of the untiled `linalg.matmul`. 1227 // When fusing an operand that is a destination operand, the iter_arg of 1228 // the outer most loop should be changed to use the destination of the 1229 // fused operation. With this the IR will be. 1230 // 1231 // ``` 1232 // %0 = linalg.init 1233 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 1234 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 1235 // %3 = tensor.extract_slice %arg1[..] 1236 // %4 = linalg.fill .. outs(%3 : ) 1237 // .. = linalg.matmul .. 
outs(%4 : ) 1238 // } 1239 // } 1240 // ``` 1241 if (destinationInitArg && 1242 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 1243 loops.front() 1244 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 1245 .set(origDestinationTensors[resultNumber]); 1246 } 1247 return scf::SCFFuseProducerOfSliceResult{ 1248 fusableProducer, tileAndFuseResult->tiledValues[0], 1249 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; 1250 } 1251 1252 /// Reconstruct the fused producer from within the tiled-and-fused code. 1253 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( 1254 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 1255 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 1256 MutableArrayRef<LoopLikeOpInterface> loops, 1257 ArrayRef<unsigned> yieldResultNumber) { 1258 if (loops.empty()) 1259 return success(); 1260 1261 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), 1262 *tiledOwner = fusedProducerInfo.tiledOps[0]; 1263 1264 Location loc = originalOwner->getLoc(); 1265 // a. collect all init Value to be appended 1266 SmallVector<unsigned> initNumberList = 1267 yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq<unsigned>( 1268 0, originalOwner->getNumResults())) 1269 : llvm::to_vector(yieldResultNumber); 1270 SmallVector<Value> initValueList; 1271 for (const auto &resultNumber : initNumberList) { 1272 FailureOr<Value> initValue = tensor::getOrCreateDestination( 1273 rewriter, loc, originalOwner->getResult(resultNumber)); 1274 if (succeeded(initValue)) { 1275 initValueList.push_back(initValue.value()); 1276 } else { 1277 return failure(); 1278 } 1279 } 1280 1281 SmallVector<Operation *> generatedSlices; 1282 YieldTiledValuesFn newYieldValuesFn = 1283 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 1284 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 1285 SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 1286 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 1287 OpBuilder::InsertionGuard g(innerRewriter); 1288 1289 // get sliceOp tile information 1290 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), 1291 sliceSizes = sliceOp.getMixedSizes(); 1292 1293 // expect all strides of sliceOp being 1 1294 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 1295 return !isConstantIntValue(ofr, 1); 1296 })) 1297 return failure(); 1298 1299 unsigned sliceResultNumber = 1300 fusedProducerInfo.origProducer.getResultNumber(); 1301 1302 auto tilableOp = cast<TilingInterface>(originalOwner); 1303 // b. get iterDomain Offset and Sizes based on sliceOp tile 1304 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; 1305 // skip tensor.pack/unpack/pad, which expects single opResult 1306 if (tilableOp->getNumResults() > 1 && 1307 failed(tilableOp.getIterationDomainTileFromResultTile( 1308 rewriter, sliceResultNumber, sliceOffset, sliceSizes, 1309 iterDomainOffset, iterDomainSizes))) { 1310 // In theory, it is unnecessary to raise an error here. Actually 1311 // although it fails to reconstruct the result tensor, it should not 1312 // broke current fusion anyway. 
The reason why we must return failure 1313 // currently is that the callback function `newYieldValuesFn` will be 1314 // called after new init operand(s) has already been appended. It will 1315 // take more refactoring to make sure the init operands are added 1316 // consistently in the future. For more details, please refer to: 1317 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 1318 return failure(); 1319 } 1320 1321 // c. calculate offsets and sizes info of all OpResults respectively based 1322 // on iteration Domain Tile 1323 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; 1324 for (const auto &resultNumber : initNumberList) { 1325 if (resultNumber == sliceResultNumber) { 1326 offsetList.push_back(sliceOffset); 1327 sizesList.push_back(sliceSizes); 1328 } else { 1329 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); 1330 // infer result tile according to the iteration domain tile 1331 SmallVector<OpFoldResult> offset, sizes; 1332 if (failed(tilableOp.getResultTilePosition( 1333 rewriter, resultNumber, iterDomainOffset, iterDomainSizes, 1334 offset, sizes))) { 1335 return failure(); 1336 } 1337 offsetList.push_back(offset); 1338 sizesList.push_back(sizes); 1339 } 1340 } 1341 1342 // d. 
create `extract_slice` for `iter_args` for DPS operation if 1343 // necessary 1344 if (auto tiledDestStyleOp = 1345 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { 1346 rewriter.setInsertionPoint(tiledDestStyleOp); 1347 for (const auto &&[index, newRegionArg] : 1348 llvm::enumerate(newRegionIterArgs)) { 1349 auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 1350 loc, newRegionArg, offsetList[index], sizesList[index], 1351 SmallVector<OpFoldResult>(offsetList[index].size(), 1352 rewriter.getIndexAttr(1))); 1353 generatedSlices.push_back(destSlice); 1354 unsigned resultNumber = initNumberList[index]; 1355 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 1356 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 1357 }); 1358 } 1359 } 1360 1361 // e. prepare tiled offset and sizes for later `insert_slice` creation by 1362 // caller 1363 Block *block = rewriter.getInsertionPoint()->getBlock(); 1364 rewriter.setInsertionPoint(block->getTerminator()); 1365 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { 1366 tiledResult.push_back(tiledOwner->getResult(resultNumber)); 1367 tiledOffset.emplace_back(offsetList[index]); 1368 tiledSizes.emplace_back(sizesList[index]); 1369 } 1370 return success(); 1371 }; 1372 1373 if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, 1374 newYieldValuesFn))) { 1375 return failure(); 1376 } 1377 return generatedSlices; 1378 } 1379 1380 namespace { 1381 1382 //===----------------------------------------------------------------------===// 1383 // SliceTrackingListener 1384 //===----------------------------------------------------------------------===// 1385 1386 /// This class is a listener for tracking the insertion and removal of 1387 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy 1388 /// fusion algorithm to apply cleanup patterns in between fusion steps. 
1389 class SliceTrackingListener : public RewriterBase::Listener { 1390 public: 1391 explicit SliceTrackingListener( 1392 std::optional<FrozenRewritePatternSet> patterns); 1393 SliceTrackingListener() = default; 1394 1395 /// Adds the given list of operations to the worklist, and if present, 1396 /// applies the list of `patterns` to the newly added operations. This only 1397 /// processes the given operations and any newly inserted ones by the 1398 /// pattern set. 1399 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); 1400 1401 /// Add to the new operation worklist if it is an extract_slice. 1402 void notifyOperationInserted(Operation *op, 1403 OpBuilder::InsertPoint previous) override; 1404 1405 /// Shared helper for operation removal from the worklist. 1406 void removeOp(Operation *op); 1407 1408 /// Remove the operation from the worklist. 1409 void notifyOperationErased(Operation *op) override; 1410 1411 /// Remove the operation from the worklist. 1412 void notifyOperationReplaced(Operation *op, ValueRange replacement) override; 1413 1414 /// The worklist for this transformation keeps track of the slices to visit 1415 /// next for fusion. 1416 std::deque<tensor::ExtractSliceOp> worklist; 1417 1418 private: 1419 /// Optional pattern set to apply when adding new operations to the 1420 /// worklist. 
1421 std::optional<FrozenRewritePatternSet> patterns = std::nullopt; 1422 }; 1423 1424 SliceTrackingListener::SliceTrackingListener( 1425 std::optional<FrozenRewritePatternSet> p) { 1426 patterns = std::move(p); 1427 } 1428 1429 LogicalResult 1430 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { 1431 for (Operation *op : ops) { 1432 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) 1433 worklist.push_back(slice); 1434 } 1435 1436 if (!patterns) 1437 return success(); 1438 1439 GreedyRewriteConfig config; 1440 config.listener = this; 1441 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 1442 return applyOpPatternsGreedily(ops, patterns.value(), config); 1443 } 1444 1445 void SliceTrackingListener::notifyOperationInserted( 1446 Operation *op, OpBuilder::InsertPoint previous) { 1447 auto slice = dyn_cast<tensor::ExtractSliceOp>(op); 1448 if (!slice) 1449 return; 1450 worklist.push_back(slice); 1451 } 1452 1453 // Scan the worklist for the given op and remove it if present. The 1454 // expectation is for the worklist to be small and for removal to be 1455 // relatively rare. 1456 void SliceTrackingListener::removeOp(Operation *op) { 1457 if (!isa<tensor::ExtractSliceOp>(op)) 1458 return; 1459 auto iter = worklist.begin(); 1460 while (iter != worklist.end()) { 1461 if (*iter == op) 1462 break; 1463 iter++; 1464 } 1465 if (iter == worklist.end()) 1466 return; 1467 1468 worklist.erase(iter); 1469 } 1470 1471 void SliceTrackingListener::notifyOperationErased(Operation *op) { 1472 removeOp(op); 1473 } 1474 1475 void SliceTrackingListener::notifyOperationReplaced(Operation *op, 1476 ValueRange replacement) { 1477 removeOp(op); 1478 } 1479 1480 //===----------------------------------------------------------------------===// 1481 // ReplacementListener 1482 //===----------------------------------------------------------------------===// 1483 1484 /// Listener that tracks updates replacements for values which can be mutated. 
1485 /// This listener runs on top of the existing listener for the rewriter, 1486 /// to make sure external users can still run listeners. 1487 class ReplacementListener : public RewriterBase::ForwardingListener { 1488 public: 1489 ReplacementListener(DenseMap<Value, Value> &replacements, 1490 OpBuilder::Listener *listener) 1491 : ForwardingListener(listener), replacements(replacements) {} 1492 1493 void updateReplacementValues(ValueRange origValues, 1494 ValueRange replaceValues) { 1495 // This can probably be written better, but just iterates over the map 1496 // and the new replacements for now. 1497 for (auto &[key, val] : replacements) { 1498 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) { 1499 if (val == orig) { 1500 val = replace; 1501 } 1502 } 1503 } 1504 } 1505 1506 void notifyOperationReplaced(Operation *op, Operation *newOp) override { 1507 ForwardingListener::notifyOperationReplaced(op, newOp); 1508 updateReplacementValues(op->getResults(), newOp->getResults()); 1509 } 1510 1511 void notifyOperationReplaced(Operation *op, ValueRange values) override { 1512 ForwardingListener::notifyOperationReplaced(op, values); 1513 updateReplacementValues(op->getResults(), values); 1514 } 1515 1516 private: 1517 DenseMap<Value, Value> &replacements; 1518 }; 1519 1520 } // namespace 1521 1522 /// Implementation of tile consumer and fuse producer greedily. 1523 FailureOr<scf::SCFTileAndFuseResult> 1524 mlir::scf::tileConsumerAndFuseProducersUsingSCF( 1525 RewriterBase &rewriter, TilingInterface consumer, 1526 const scf::SCFTileAndFuseOptions &options) { 1527 // This transformation is only valid for ops that return values (i.e. not 1528 // valid to use with operations that have memref operands). 1529 if (!consumer->getNumResults()) { 1530 return rewriter.notifyMatchFailure( 1531 consumer, "invalid pattern for op with no results"); 1532 } 1533 1534 // 1. First tile the consumer. 
1535 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 1536 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 1537 1538 FailureOr<scf::SCFTilingResult> tilingResult = 1539 tileUsingSCF(rewriter, consumer, options.tilingOptions); 1540 1541 if (failed(tilingResult)) 1542 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1543 for (auto *tiledOp : tilingResult->tiledOps) 1544 tiledAndFusedOps.insert(tiledOp); 1545 1546 DenseMap<Value, Value> replacements; 1547 for (auto [origVal, replacement] : llvm::zip_equal( 1548 consumer->getResults(), tilingResult->mergeResult.replacements)) { 1549 replacements[origVal] = replacement; 1550 } 1551 1552 // If there are no loops generated, fusion is immaterial. 1553 auto &loops = tilingResult->loops; 1554 if (loops.empty()) { 1555 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1556 replacements}; 1557 } 1558 1559 // Since the loop gets potentially replaced during fusion, we need to track 1560 // the mutation of replacement values. To do this, we attach a listener to 1561 // update the replacements as they happen. 1562 OpBuilder::Listener *previousListener = rewriter.getListener(); 1563 auto resetListener = 1564 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); 1565 ReplacementListener replaceListener(replacements, previousListener); 1566 rewriter.setListener(&replaceListener); 1567 1568 // 2. Typically, the operands of the tiled operation are slices of the 1569 // operands of the untiled operation. These are expressed in IR using 1570 // `tensor.extract_slice` operations with source being the operands of 1571 // the untiled operation. Create a worklist of these 1572 // `tensor.extract_slice` operations. If the producers of the source of 1573 // the `tensor.extract_slice` can be tiled such that the tiled value is 1574 // generated in-place, that effectively tiles + fuses the operations. 
1575 struct WorklistItem { 1576 tensor::ExtractSliceOp candidateSlice; 1577 SCFTileAndFuseOptions::ControlFnResult controlFnResult; 1578 }; 1579 1580 SliceTrackingListener sliceTracker = 1581 SliceTrackingListener(options.cleanupPatterns); 1582 1583 if (failed( 1584 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { 1585 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 1586 } 1587 OpBuilder::InsertionGuard g(rewriter); 1588 while (!sliceTracker.worklist.empty()) { 1589 auto candidateSlice = sliceTracker.worklist.front(); 1590 sliceTracker.worklist.pop_front(); 1591 1592 auto [fusableProducer, destinationInitArg] = 1593 getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), 1594 loops); 1595 if (!fusableProducer) 1596 continue; 1597 1598 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = 1599 options.fusionControlFn(candidateSlice, fusableProducer, 1600 destinationInitArg.has_value()); 1601 if (!controlFnResult) 1602 continue; 1603 1604 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; 1605 1606 // The operands of the fused producer might themselved be slices of 1607 // values produced by operations that implement the `TilingInterface`. 1608 // Add these operations to the worklist. 1609 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 1610 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, 1611 loops); 1612 if (!fusedResult) 1613 continue; 1614 1615 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; 1616 1617 if (worklistItem.controlFnResult.yieldProducerReplacement) { 1618 // Reconstruct and yield all opResult of fusableProducerOp by default. 1619 // The caller can specific which one to yield by designating optional 1620 // argument named `yieldResultNumber` of 1621 // `yieldReplacementForFusedProducer`. 
1622 Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); 1623 FailureOr<SmallVector<Operation *>> newSlices = 1624 yieldReplacementForFusedProducer(rewriter, 1625 worklistItem.candidateSlice, 1626 fusedResult.value(), loops); 1627 if (failed(newSlices)) { 1628 return rewriter.notifyMatchFailure( 1629 fusableProducerOp, "failed to replacement value for this " 1630 "operation from within the tiled loop"); 1631 } 1632 worklistCandidates.append(newSlices.value()); 1633 for (auto [index, result] : 1634 llvm::enumerate(fusableProducerOp->getResults())) { 1635 replacements[result] = loops.front()->getResult( 1636 loops.front()->getNumResults() - 1637 fusableProducerOp->getNumResults() + index); 1638 } 1639 } 1640 if (Operation *tiledAndFusedOp = 1641 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 1642 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 1643 tiledAndFusedOps.insert(tiledAndFusedOp); 1644 } 1645 1646 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { 1647 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 1648 } 1649 } 1650 1651 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1652 replacements}; 1653 } 1654 1655 //===----------------------------------------------------------------------===// 1656 // tileAndFuseConsumerUsingSCF implementation. 1657 //===----------------------------------------------------------------------===// 1658 1659 /// A utility function that checks whether the only use of the result of a 1660 /// tensor.insert_slice op is in a scf.yield op. 
/// Checks that the result of `candidateSliceOp` has exactly one use, that
/// this use is a `scf.yield`, and that both ops live in the same block.
/// Returns failure (with a debug message) otherwise.
static LogicalResult
checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
  Value result = candidateSliceOp.getResult();
  Value::use_range uses = result.getUses();
  if (!llvm::hasSingleElement(uses)) {
    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
    return failure();
  }
  OpOperand &operandUse = (*uses.begin());
  Operation *userOp = operandUse.getOwner();
  if (!isa<scf::YieldOp>(userOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << "Expected scf.yield to be the only user, but got -> "
               << (*userOp));
    return failure();
  }
  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
                               "be in the same block\n");
    return failure();
  }
  return success();
}

/// A utility to get the first user of the given `loopOp` (in block order).
/// If any user is in a different block than `loopOp`, returns failure.
static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
  if (!isa<LoopLikeOpInterface>(loopOp))
    return failure();
  Operation *firstUserOfLoop = nullptr;
  for (Operation *userOp : loopOp->getUsers()) {
    // A `ParallelInsertSlice` located inside an `InParallelOp` does not share
    // a parent block with any other type of operation. Thus, just redirect to
    // its parent `InParallelOp`. E.g.
    //
    // ```
    // %1 = scf.for {
    //   ...
    // }
    // %2 = consumerOp ins(%1, ...)
    // scf.forall.in_parallel {
    //    tensor.parallel_insert_slice %1
    // }
    // ```
    // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
    // block with `consumerOp`.
    if (isa<tensor::ParallelInsertSliceOp>(userOp))
      userOp = userOp->getParentOfType<scf::InParallelOp>();

    if (loopOp->getBlock() != userOp->getBlock())
      return failure();

    // Keep the user that appears earliest in the block.
    if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
      firstUserOfLoop = userOp;
  }
  return firstUserOfLoop;
}

/// This utility currently checks whether the first userOp of loop is NOT
/// before the last defineOp of consumer operand. Because we need to move
/// the whole loop structure right before the `firstUserOfLoop`, this utility
/// thus helps ensure that no invalid IR is formed, i.e. no backward slice
/// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
///
/// ```
/// %0 = scf.for() {
///   ...
/// }
/// ...
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ...
/// %3 = consumerOp(%2)
/// ```
///
/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
/// a.k.a. use-def chain violation:
///
/// ```
/// %0:2 = scf.for() {
///    // use before define error
///    %3 = tiledConsumerOp(%2)
/// }
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ```
///
/// @param loopOp: loop operation
/// @param consumerOp: consumer operation
/// @param reorderOperations: the flag controls whether to reorder the
/// backward slice w.r.t. the defineOp of `consumerOp` operands.
/// @return: computed backward slice of consumerOp, but excluding those
/// already dominating `firstUserOfLoop`.
// See the doc comment immediately above for the full contract of this
// function (use-def-chain legality when moving `loopOp` before its first
// user, and the meaning of `reorderOperations`).
static FailureOr<llvm::SetVector<Operation *>>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
                       bool reorderOperations) {
  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
  if (failed(firstUserOfLoop))
    return failure();

  BackwardSliceOptions options;
  DominanceInfo dominanceInfo;
  options.inclusive = true;
  options.omitBlockArguments = true;
  // Set as a side effect of the filter below when the slice reaches `loopOp`
  // itself.
  bool includeLoopOp = false;
  options.filter = [&](Operation *op) {
    if (op == loopOp) {
      includeLoopOp = true;
      return false;
    }
    // Cut off the slice to not include any operation that already dominates
    // firstUserOfLoop.
    return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
  };
  llvm::SetVector<Operation *> slice;
  for (auto operand : consumerOp->getOperands()) {
    getBackwardSlice(operand, &slice, options);
  }

  if (!slice.empty()) {
    // If consumerOp has one producer, which is also the user of loopOp.
    // E.g.
    // ```
    //  %0 = %loopOp
    //  %1 = consumerOp1 ins(%0)
    //  %2 = consumerOp2 ins(%0, %1)
    // ```
    // We can not fuse consumerOp2 into loopOp due to UD chain, unless
    // consumerOp1 has already been fused into loopOp before.
    if (includeLoopOp || !reorderOperations)
      return failure();
  }

  return slice;
}

/// Fetches the OpOperand of the first valid user (and use) of the value `val`
/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
/// Returns failure otherwise.
// Scans the uses of `loopOp`'s result number `resultNumber` and returns the
// first use whose owner can legally be fused as a consumer. May move the
// backward slice of that consumer before the loop's first user to make the
// fusion legal (see `checkAssumptionForLoop`).
static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
                                                      Operation *loopOp,
                                                      unsigned resultNumber) {
  if (!isa<LoopLikeOpInterface>(loopOp))
    return failure();
  Value val = loopOp->getResult(resultNumber);
  Block *loopBlock = loopOp->getBlock();
  for (OpOperand &opOperand : val.getUses()) {
    Operation *consumerOp = opOperand.getOwner();
    // Step 1. Check if the user is tilable.
    if (!isa<TilingInterface>(consumerOp) ||
        !isa<DestinationStyleOpInterface>(consumerOp)) {
      // TODO: We have to init result of consumer before scf.for, use
      //       DestinationStyleOpInterface to get result shape from init for
      //       now. Add support for other op such as op has
      //       InferTypeOpInterface.
      continue;
    }
    // Step 2. Check if the user stays in the same block.
    if (loopBlock != consumerOp->getBlock())
      continue;
    // Step 3. Check if the user has a succeeding user. Otherwise, it usually
    // represents an already tiled op.
    if (consumerOp->use_empty())
      continue;
    // Step 4. Check assumption for loop with `reorderOperations` enabled.
    FailureOr<llvm::SetVector<Operation *>> slice =
        checkAssumptionForLoop(loopOp, consumerOp, true);
    if (failed(slice))
      continue;
    // Step 5. If the backward slice is not empty, move its ops before
    // firstUserOfLoop (in topological order, to preserve def-use order).
    if (!slice->empty()) {
      mlir::topologicalSort(*slice);
      FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
      assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
      for (auto op : *slice) {
        rewriter.moveOpBefore(op, *firstUserOfLoop);
      }
    }
    return &opOperand;
  }
  return failure();
}

/// Find the perfectly nested loops outside of given loop(included) sorted
/// from outer to inner.
///
/// E.g.
///
/// ```
///  %0 = scf.for()
///    %1 = scf.for()
///      %2 = scf.for()
///         %3 = ...
///         yield %3
///      yield %2
///   yield %1
/// ```
///
/// This function will return three perfectly nested loops: %0 + %1 + %2, when
/// target inner loop is %2.
static SmallVector<scf::ForOp>
getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
  SmallVector<scf::ForOp> nestLoops = {loop};
  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());

  // Check if it is the ForOp that yields the result of the inner loop.
  auto isForOpYieldResultOfInnerLoop =
      [](scf::ForOp outerLoop) -> LogicalResult {
    Block *body = outerLoop.getBody();
    // The inner loop must be the only op in the body besides the terminator.
    if (!llvm::hasSingleElement(body->without_terminator()))
      return failure();
    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
    if (!innerForOp)
      return failure();
    // All of innerForOp results should be yielded.
    return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
  };

  // Walk outwards while the nesting stays perfect.
  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
    nestLoops.push_back(outerLoop);
    outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
  }
  // sorted from outer to inner
  return {nestLoops.rbegin(), nestLoops.rend()};
}

/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
///   1.  tensor.insert_slice has scf.yield as its only user.
///   2.  scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::InsertSliceOp candidateSliceOp) {
  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
    return failure();
  Value sliceResult = candidateSliceOp.getResult();
  // Step 1. Fetch the corresponding output.
  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
  unsigned resultNumber = yieldOpOperand.getOperandNumber();
  // Step 2. Check containing op is scf.for.
  Operation *containingOp = candidateSliceOp->getParentOp();
  auto forOp = dyn_cast<scf::ForOp>(containingOp);
  if (!forOp)
    return failure();
  // Look up the consumer on the outermost loop of the perfect nest, since
  // that is the loop whose results are visible to the consumer.
  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();

  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}

/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::ParallelInsertSliceOp candidateSliceOp) {
  // Step 1. Fetch the corresponding output.
  Value sliceDest = candidateSliceOp.getDest();
  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
  if (!iterArg)
    return failure();
  Operation *containingOp = iterArg.getOwner()->getParentOp();
  // The insert must be in the terminator region directly below `containingOp`.
  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
    return failure();
  // Step 2. Check that the containing op is scf.forall.
  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
  if (!forallOp)
    return failure();
  unsigned resultNumber =
      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
          .getResultNumber();

  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
// Dispatch on the concrete slice op kind; any other op kind is rejected.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(rewriter, insertSlice);
  } else if (auto parallelInsertSlice =
                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
  } else {
    return failure();
  }
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  // There are two possible cases regarding `oldLoopOp` here:
  // 1. single `scf.forall` or `scf.for`.
  // 2. inner-most `scf.for` inside a nested `scf.for` structure, where the
  //    top-level loop is the outer-most one of these nested loops.
  LoopLikeOpInterface innerMostLoop =
      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
  SmallVector<LoopLikeOpInterface> nestedLoops;
  if (isInsertSliceOp) {
    nestedLoops = llvm::map_to_vector(
        getPerfectlyNestedLoopsOutsideOf(
            cast<scf::ForOp>(innerMostLoop.getOperation())),
        [](scf::ForOp forOp) {
          return cast<LoopLikeOpInterface>(forOp.getOperation());
        });
  } else {
    nestedLoops = {innerMostLoop};
  }

  LoopLikeOpInterface outerMostLoop = nestedLoops.front();

  // Check assumption for loop with `reorderOperations` disabled.
  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "the first user of loop should not dominate any define "
                       "of consumer operand(s)");
  }

  OpBuilder::InsertionGuard g(rewriter);

  // 2. Check consumer is not using scf loop's output as init.
  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
  if (!dstOp)
    return rewriter.notifyMatchFailure(consumerOp,
                                       "consumer op is not DPS operation");
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  SmallVector<Value> newInits = dpsInits;

  Location loc = outerMostLoop->getLoc();

  // 3. Move the whole loop structure right before firstUserOfLoop, the
  // dominance should be already ensured by `checkAssumptionForLoop`.
  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
  if (failed(firstUserOfLoop)) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "could not find the first user of outer most loop");
  }
  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);

  // 4. Set insertion point before terminator op of the loop and create a new
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 5.a. Clone consumer op.
  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

  // 5.b. Replace all uses of the loop result with the result of the cloned
  // tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 6. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
                              clonedInsertSliceOp.getSource());

  // 7. Reconstruct [nested] loop with new inits.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);
    // 8. Set inner insertPoint right before tiled consumer op.
    innerRewriter.setInsertionPoint(tiledConsumerOp);

    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

    // 9. Check all insert stride is 1.
    if (llvm::any_of(strides, [](OpFoldResult stride) {
          return !isConstantIntValue(stride, 1);
        })) {
      return rewriter.notifyMatchFailure(
          candidateSliceOp, "containingOp's result yield with stride");
    }

    // 10. Try to get iter domain position from input position. Use
    // clonedConsumerOp instead of tiledConsumerOp, because the iteration
    // domain may require index computation based on the result size. The
    // sizes and offsets should be the same either way, but using
    // tiledConsumerOp could lead to some chained unnecessary extra index
    // computation.
    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
            iterDomainSizes))) {
      return rewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get iter domain position from input position");
    }

    // 11. Try to fetch the offset and size for all results of the cloned
    // consumer. This would then be used to form the corresponding
    // tensor.insert_slice/parallel_insert_slice later.
    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
        totalNumResultsOfConsumer);
    SmallVector<SmallVector<OpFoldResult>> resultSizes(
        totalNumResultsOfConsumer);
    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
      if (failed(tiledConsumerOp.getResultTilePosition(
              rewriter, idx, iterDomainOffsets, iterDomainSizes,
              resultOffsets[idx], resultSizes[idx]))) {
        return rewriter.notifyMatchFailure(
            tiledConsumerOp,
            "can't get result domain position from iter domain position");
      }
    }

    // 12. Create `extract_slice` for `iter_args` for DPS operation if
    // necessary.
    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
            tiledConsumerOp.getOperation())) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, resultOffsets[index], resultSizes[index],
            SmallVector<OpFoldResult>(resultOffsets[index].size(),
                                      rewriter.getIndexAttr(1)));
        // Make a copy of index to avoid a capturing structured binding, which
        // is a C++20 extension.
        auto dstNumber = index;
        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
        });
      }
    }

    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
    // caller.
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, result] :
         llvm::enumerate(tiledConsumerOp->getResults())) {
      tiledResult.push_back(result);
      tiledOffset.emplace_back(resultOffsets[index]);
      tiledSizes.emplace_back(resultSizes[index]);
    }
    return success();
  };
  // 14. Add new inits to [nested] loops.
  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
                                       newYieldValuesFn))) {
    return rewriter.notifyMatchFailure(tiledConsumerOp,
                                       "unable to add new inits to nest loop");
  }

  // 15. Replace the result of scf loop and consumer op with new loop's
  // results.

  for (auto &&[oldResult, newResult] : llvm::zip(
           consumerOp->getResults(),
           nestedLoops.front()->getResults().take_back(newInits.size()))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 16. Need to erase the old scf loop and the cloned consumer op.
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  // Build one scf.for per dimension of the iteration domain, from outermost
  // to innermost, collecting the induction variables along the way.
  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    // Materialize the range's offset/size/stride as index values.
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    // Nest the next loop inside the body of the one just created.
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  // Emit the scalar body of `op` at the innermost insertion point.
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}