1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Analysis/TopologicalSortUtils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Arith/Utils/Utils.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/SCF/Utils/Utils.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Dialect/Utils/IndexingUtils.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 28 #include "mlir/Interfaces/TilingInterface.h" 29 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 #include "llvm/ADT/TypeSwitch.h" 32 #include "llvm/Support/Debug.h" 33 #include <optional> 34 35 #define DEBUG_TYPE "tile-using-interface" 36 37 using namespace mlir; 38 39 scf::SCFTilingOptions & 40 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 41 assert(!tileSizeComputationFunction && "tile sizes already set"); 42 auto tileSizes = llvm::to_vector(ts); 43 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 44 return tileSizes; 45 }; 46 return *this; 47 } 48 49 scf::SCFTilingOptions & 50 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { 51 assert(!numThreadsComputationFunction && 
"num tiles already set"); 52 auto numThreads = llvm::to_vector(nt); 53 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { 54 return numThreads; 55 }; 56 return *this; 57 } 58 59 /// Helper method to adjust the interchange vector to match the iteration 60 /// domain. 61 static SmallVector<int64_t> 62 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 63 size_t iterationDomainSize) { 64 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 65 if (filledVector.size() < iterationDomainSize) { 66 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 67 filledVector.append(range.begin(), range.end()); 68 } 69 if (filledVector.size() > iterationDomainSize) 70 filledVector.resize(iterationDomainSize); 71 return filledVector; 72 } 73 74 //===----------------------------------------------------------------------===// 75 // tileUsingSCF implementation. 76 //===----------------------------------------------------------------------===// 77 78 /// Verify the tile size options are set in a consistent manner. 79 static LogicalResult 80 verifyTileSizeOptions(RewriterBase &rewriter, Location loc, 81 const scf::SCFTilingOptions &options) { 82 // Specifying number of threads is only supported on `scf.forall` op. 83 if (options.numThreadsComputationFunction && 84 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { 85 return rewriter.notifyMatchFailure( 86 loc, "number of threads can only by specified when loop type is " 87 "set to use `scf.forall`"); 88 } 89 90 // If specified, check that the interchange vector is a permutation. 91 if (!options.interchangeVector.empty()) { 92 if (!isPermutationVector(options.interchangeVector)) { 93 return rewriter.notifyMatchFailure( 94 loc, "invalid interchange vector, not a permutation of the entire " 95 "iteration space"); 96 } 97 } 98 return success(); 99 } 100 101 /// Method to instantiate the tile sizes and/or number of threads specified 102 /// by the user. 
103 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 104 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, 105 ArrayRef<Range> iterationDomain, 106 const scf::SCFTilingOptions &options) { 107 OpFoldResult zero = rewriter.getIndexAttr(0); 108 SmallVector<OpFoldResult> tileSizes, numThreads; 109 size_t numLoops = iterationDomain.size(); 110 111 // Check whether the number of tiles to use is specified. 112 if (options.numThreadsComputationFunction) { 113 numThreads = options.numThreadsComputationFunction(rewriter, op); 114 numThreads.resize(numLoops, zero); 115 116 // If the number of tiles is also specified, use that. 117 if (options.tileSizeComputationFunction) { 118 tileSizes = options.tileSizeComputationFunction(rewriter, op); 119 tileSizes.resize(numLoops, zero); 120 return {tileSizes, numThreads}; 121 } 122 123 // Compute the tile sizes from the iteration domain and number 124 // of tiles as follows 125 // - niters = ceilDiv(ub - lb, step) 126 // - tileSize = ceilDiv(niters, numThreads) 127 AffineExpr s0, s1, s2; 128 bindSymbols(rewriter.getContext(), s0, s1, s2); 129 // TODO: The step here is assumed to be 1. 130 AffineExpr numItersExpr = (s1 - s0); 131 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); 132 tileSizes.resize(numLoops, zero); 133 for (auto [index, range, nt] : 134 llvm::enumerate(iterationDomain, numThreads)) { 135 if (isConstantIntValue(nt, 0)) 136 continue; 137 138 tileSizes[index] = affine::makeComposedFoldedAffineApply( 139 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); 140 } 141 tileSizes.resize(numLoops, zero); 142 return {tileSizes, numThreads}; 143 } 144 145 // Enforce the convention that "tiling by zero" 146 // skips tiling a particular dimension. This convention is significantly 147 // simpler to handle instead of adjusting affine maps to account for missing 148 // dimensions. 
149 assert(options.tileSizeComputationFunction && 150 "expected tile sizes to be specified"); 151 tileSizes = options.tileSizeComputationFunction(rewriter, op); 152 tileSizes.resize(numLoops, zero); 153 154 return {tileSizes, numThreads}; 155 } 156 157 /// Checks if any of the tiled loops are not parallel. 158 static void checkSafeToTileToForall(TilingInterface op, 159 ArrayRef<OpFoldResult> tileSizes, 160 ArrayRef<OpFoldResult> numThreads) { 161 auto iterators = op.getLoopIteratorTypes(); 162 assert(iterators.size() == tileSizes.size() && 163 "expected as many tile size values as number of loops"); 164 assert((numThreads.empty() || (numThreads.size() == iterators.size())) && 165 "when specified, expected number of threads to use for each loop"); 166 167 for (auto [index, iterator, tileSize] : 168 llvm::enumerate(iterators, tileSizes)) { 169 // If num threads is specified, check that it is greater than one only for 170 // parallel dimensions. 171 if (!numThreads.empty()) { 172 if (std::optional<int64_t> constNumThreads = 173 getConstantIntValue(numThreads[index])) { 174 if (constNumThreads.value() > 1 && 175 iterator != utils::IteratorType::parallel) { 176 op.emitWarning() << "tiling is not thread safe at axis #" << index; 177 } 178 } 179 continue; 180 } 181 182 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { 183 if (constTileSize.value() > 0 && 184 iterator != utils::IteratorType::parallel) { 185 op.emitWarning() << "tiling is not thread safe at axis #" << index; 186 } 187 } 188 } 189 } 190 191 /// Check if `stride` evenly divides the trip count `size - offset`. 
192 static bool tileDividesIterationDomain(Range loopRange) { 193 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 194 if (!offsetAsInt) 195 return false; 196 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 197 if (!sizeAsInt) 198 return false; 199 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 200 if (!strideAsInt) 201 return false; 202 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 203 } 204 205 /// Returns the bounded tile size given the current `offset`, `loopRange` and 206 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 207 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 208 Range loopRange, OpFoldResult offset, 209 OpFoldResult tileSize) { 210 std::optional<int64_t> ts = getConstantIntValue(tileSize); 211 if (ts && ts.value() == 1) 212 return tileSize; 213 214 if (tileDividesIterationDomain( 215 Range{loopRange.offset, loopRange.size, tileSize})) 216 return tileSize; 217 218 // The tile size to use (to avoid out of bounds access) is minimum of 219 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 220 // loop. 221 AffineExpr s0, s1, d0; 222 bindDims(b.getContext(), d0); 223 bindSymbols(b.getContext(), s0, s1); 224 AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); 225 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 226 return affine::makeComposedFoldedAffineMin( 227 b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize}); 228 } 229 230 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less 231 /// than `iterationSize`. 
232 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, 233 OpFoldResult numThreads, 234 OpFoldResult iterationSize) { 235 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize); 236 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads); 237 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize); 238 if (!tileSizeConst || !numThreadsConst || !iterSizeConst) 239 return false; 240 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; 241 } 242 243 /// Compute the `OpFoldResult`s that represents the multi-dimensional 244 /// `offset`s and `size`s of the tile of the iteration space that the 245 /// innermost loop body of the generated tiled loops corresponds to. 246 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 247 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, 248 ArrayRef<Range> iterationDomain, 249 ArrayRef<OpFoldResult> tileSizes, 250 ArrayRef<OpFoldResult> numThreads) { 251 SmallVector<OpFoldResult> offsets, sizes; 252 int materializedLoopNum = 0; 253 254 if (!numThreads.empty()) { 255 AffineExpr d0, d1, s0, s1; 256 AffineExpr offsetExpr, residualTileSizeExpr; 257 bindDims(rewriter.getContext(), d0, d1); 258 bindSymbols(rewriter.getContext(), s0, s1); 259 offsetExpr = d0 + d1 * s0; 260 residualTileSizeExpr = s1 - (d0 + d1 * s0); 261 262 for (auto [nt, tileSize, loopRange] : 263 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { 264 265 // Non-tiled cases, set the offset and size to the 266 // `loopRange.offset/size`. 
267 if (isConstantIntValue(nt, 0)) { 268 offsets.push_back(loopRange.offset); 269 sizes.push_back(loopRange.size); 270 continue; 271 } 272 273 Value iv = ivs[materializedLoopNum++]; 274 OpFoldResult offset = affine::makeComposedFoldedAffineApply( 275 rewriter, loc, offsetExpr, 276 ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); 277 OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( 278 rewriter, loc, residualTileSizeExpr, 279 {loopRange.offset, nt, tileSize, loopRange.size}); 280 281 OpFoldResult size = tileSize; 282 if (!isConstantIntValue(residualTileSize, 0)) { 283 OpFoldResult sizeMinusOffsetPerThread = 284 affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, 285 {offset, loopRange.size}); 286 size = affine::makeComposedFoldedAffineMin( 287 rewriter, loc, 288 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), 289 {sizeMinusOffsetPerThread, tileSize}); 290 } 291 292 // Consider the case where the original loop was `[0, 100)`. 293 // If number of threads are `7`, the tile size would be computed as 294 // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) 295 // - `offset = 0 + 6 * 15 = 105` 296 // - `tileSize = min(15, 100 - 105) = -5` 297 // To avoid negative tile sizes, we need to do a further 298 // `nonNegativeTileSize = affine.max(0, tileSize)`. 299 // This `max` can be avoided if 300 // `offset + tileSize * (numThreads - 1) < (ub - lb)` 301 if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { 302 AffineMap maxMap = 303 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); 304 size = affine::makeComposedFoldedAffineMax( 305 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); 306 } 307 308 offsets.push_back(offset); 309 sizes.push_back(size); 310 } 311 return {offsets, sizes}; 312 } else { 313 for (auto [tileSize, loopRange] : 314 llvm::zip_equal(tileSizes, iterationDomain)) { 315 316 // Non-tiled cases, set the offset and size to the 317 // `loopRange.offset/size`. 
318 if (isConstantIntValue(tileSize, 0)) { 319 offsets.push_back(loopRange.offset); 320 sizes.push_back(loopRange.size); 321 continue; 322 } 323 324 Value iv = ivs[materializedLoopNum++]; 325 OpFoldResult offset = getAsOpFoldResult(iv); 326 offsets.push_back(offset); 327 OpFoldResult size = 328 getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); 329 sizes.push_back(size); 330 } 331 return {offsets, sizes}; 332 } 333 } 334 335 /// Function to return the bounds of the loops to be generated. 336 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 337 SmallVector<OpFoldResult>> 338 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 339 ArrayRef<OpFoldResult> tileSizes) { 340 SmallVector<OpFoldResult> lbs, ubs, steps; 341 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { 342 // No loop if the tile size is 0. 343 if (isConstantIntValue(tileSize, 0)) 344 continue; 345 lbs.push_back(loopRange.offset); 346 ubs.push_back(loopRange.size); 347 steps.push_back(tileSize); 348 } 349 return {lbs, ubs, steps}; 350 } 351 352 /// A function that allows returning additional yielded values during 353 /// `yieldTiledValuesAndReplace`. 354 /// - `ivs` induction variable for the loop. 355 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 356 /// - `tiledValues` the tiled values to return. Must be of same size as 357 /// `newbbArgs`, each element of this array is inserted into the corresponding 358 /// element in `newbbArgs`. 359 /// - `resultOffsets` is of the same size as `tiledValues` and represents 360 /// the offsets to use when inserting corresponding element from `tiledValues` 361 /// into the element from `newBbArgs`. 362 /// - `resultSizes` is of the same size as `tiledValues` and represents 363 /// the size of the corresponding element from `tiledValues` inserted into 364 /// the element from `newBbArgs`. 
365 /// In case the method needs to return `failure()` the method is expected 366 /// to clean up any inserted operations. 367 using YieldTiledValuesFn = std::function<LogicalResult( 368 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, 369 SmallVector<Value> &tiledValues, 370 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 371 SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; 372 373 /// Clones the operation and updates the destination if the operation 374 /// implements the `DestinationStyleOpInterface`. 375 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 376 Operation *op, 377 ValueRange newDestArgs) { 378 Operation *clonedOp = rewriter.clone(*op); 379 if (newDestArgs.empty()) 380 return clonedOp; 381 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) 382 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 383 return clonedOp; 384 } 385 386 /// Generate the tile-loop nest using `scf.for` operation. 387 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 388 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 389 /// - `destinationTensors` are the init values to use for the outer most loop. 390 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 391 /// most 392 /// loop. 393 /// - `loops` is an in-out parameter into which the generated loops are 394 /// populated. 
395 static LogicalResult generateLoopNestUsingForOp( 396 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 397 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, 398 YieldTiledValuesFn yieldTiledValuesFn, 399 SmallVector<LoopLikeOpInterface> &loops) { 400 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 401 assert(loopRanges.size() == tileSizes.size() && 402 "expected as many tile sizes as loop ranges"); 403 OpBuilder::InsertionGuard guard(rewriter); 404 405 SmallVector<OpFoldResult> lbs, ubs, steps; 406 std::tie(lbs, ubs, steps) = 407 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 408 SmallVector<Value> lbVals = 409 getValueOrCreateConstantIndexOp(rewriter, loc, lbs); 410 SmallVector<Value> ubVals = 411 getValueOrCreateConstantIndexOp(rewriter, loc, ubs); 412 SmallVector<Value> stepVals = 413 getValueOrCreateConstantIndexOp(rewriter, loc, steps); 414 415 SmallVector<Value> ivs; 416 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { 417 auto loop = 418 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, 419 [](OpBuilder &bodyBuilder, Location bodyLoc, 420 Value iv, ValueRange /*iterArgs*/) {}); 421 loops.push_back(loop); 422 ivs.push_back(loop.getInductionVar()); 423 rewriter.setInsertionPointToEnd(loop.getBody()); 424 destinationTensors = loop.getRegionIterArgs(); 425 } 426 427 SmallVector<Value> tiledResults; 428 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 429 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, 430 tiledResults, resultOffsets, resultSizes))) { 431 return rewriter.notifyMatchFailure( 432 loc, "failed to generate inner tile loop body"); 433 } 434 if (loops.empty()) 435 return success(); 436 437 assert(tiledResults.size() == destinationTensors.size() && 438 "Number of results of body should be equal to number of iter args"); 439 440 // 6. Yield all the results of the tiled operation. 
441 SmallVector<Value> yieldedValues; 442 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 443 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 444 resultSizes)) { 445 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 446 rewriter.getIndexAttr(1)); 447 auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 448 loc, tiledValue, destinationTensor, resultOffset, resultSize, 449 resultStride); 450 yieldedValues.push_back(insertSlice); 451 } 452 rewriter.create<scf::YieldOp>(loc, yieldedValues); 453 454 // Add the scf.yield operations for all the outer loops. 455 for (auto [outerLoop, innerLoop] : 456 llvm::zip_equal(MutableArrayRef(loops).drop_back(), 457 MutableArrayRef(loops).drop_front())) { 458 rewriter.setInsertionPointToEnd( 459 cast<scf::ForOp>(outerLoop.getOperation()).getBody()); 460 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); 461 } 462 return success(); 463 } 464 465 /// Generate the tile-loop nest using `scf.forall` operation. 466 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 467 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 468 /// - `destinationTensors` are the init values to use for the outer most loop. 469 /// - `mappingVector` is the mapping attributes to use for loop construction. 470 /// Can be empty. 471 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 472 /// most 473 /// loop. 474 /// - `loops` is an in-out parameter into which the generated loops are 475 /// populated. 
476 static LogicalResult generateLoopNestUsingForallOp( 477 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 478 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, 479 ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, 480 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 481 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 482 assert(loopRanges.size() == tileSizes.size() && 483 "expected as many tile sizes as loop ranges"); 484 OpBuilder::InsertionGuard guard(rewriter); 485 SmallVector<OpFoldResult> offsets(loopRanges.size()), 486 sizes(loopRanges.size()); 487 488 std::optional<ArrayAttr> mappingAttr; 489 if (!mappingVector.empty()) 490 mappingAttr = rewriter.getArrayAttr(mappingVector); 491 492 scf::ForallOp forallOp; 493 bool useNumThreads = !numThreads.empty(); 494 495 if (useNumThreads) { 496 // Prune the zero numthreads. 497 SmallVector<OpFoldResult> nonZeroNumThreads; 498 for (auto nt : numThreads) { 499 if (isConstantIntValue(nt, 0)) 500 continue; 501 nonZeroNumThreads.push_back(nt); 502 } 503 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, 504 destinationTensors, mappingAttr); 505 } else { 506 SmallVector<OpFoldResult> lbs, ubs, steps; 507 std::tie(lbs, ubs, steps) = 508 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 509 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, 510 destinationTensors, mappingAttr); 511 } 512 loops.push_back(forallOp); 513 514 rewriter.setInsertionPoint(forallOp.getTerminator()); 515 destinationTensors = forallOp.getRegionOutArgs(); 516 517 SmallVector<Value> tiledResults; 518 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 519 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 520 destinationTensors, tiledResults, resultOffsets, 521 resultSizes))) 522 return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 523 524 
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 525 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 526 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 527 resultSizes)) { 528 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 529 rewriter.getIndexAttr(1)); 530 531 rewriter.create<tensor::ParallelInsertSliceOp>( 532 loc, tiledValue, destinationTensor, resultOffset, resultSize, 533 resultStride); 534 } 535 return success(); 536 } 537 538 /// Generate the tile-loop nest using the loop construct specifed in `options`. 539 /// - `options`: Tiling options specified. 540 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 541 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 542 /// - `destinationTensors` are the init values to use for the outer most loop. 543 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 544 /// most 545 /// loop. 546 /// - `loops` is an in-out parameter into which the generated loops are 547 /// populated. 548 static LogicalResult generateLoopNest( 549 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, 550 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, 551 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, 552 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 553 // If the tile sizes are all zero, no loops are generated. Just call the 554 // callback function to handle untiled case. 
555 if (llvm::all_of(tileSizes, isZeroIndex)) { 556 SmallVector<Value> tiledResults; 557 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 558 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 559 tiledResults, resultOffsets, resultSizes); 560 } 561 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 562 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 563 destinationTensors, tiledBodyFn, loops); 564 } 565 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 566 return generateLoopNestUsingForallOp( 567 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, 568 destinationTensors, tiledBodyFn, loops); 569 } 570 return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 571 } 572 573 /// Append the specified additional `newInitOperands` operands to the 574 /// loops existing `init` operands (or similar), and replace `loopOp` with 575 /// the new loop that has the additional init operands. The loop body of 576 /// this loop is moved over to the new loop. `yieldTiledValuesFn` 577 /// is called to get the new tiled values returned, and the offset 578 /// and sizes at which the tiled value is inserted into the 579 /// new region iter_args that correspond to the newly added init operands. 580 template <typename LoopType> 581 FailureOr<LoopLikeOpInterface> 582 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 583 ValueRange newInitOperands, 584 YieldTiledValuesFn yieldTiledValuesFn) { 585 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 586 } 587 588 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 
589 template <> 590 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( 591 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 592 YieldTiledValuesFn yieldTiledValuesFn) { 593 OpBuilder::InsertionGuard g(rewriter); 594 Location loc = loopOp.getLoc(); 595 rewriter.setInsertionPoint(loopOp); 596 597 auto inits = llvm::to_vector(loopOp.getInitArgs()); 598 inits.append(newInitOperands.begin(), newInitOperands.end()); 599 auto newLoop = rewriter.create<scf::ForOp>( 600 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), 601 inits, [](OpBuilder &, Location, Value, ValueRange) {}); 602 603 // Move the loop body to the new op. 604 Block *loopBody = loopOp.getBody(); 605 Block *newLoopBody = newLoop.getBody(); 606 rewriter.mergeBlocks( 607 loopBody, newLoopBody, 608 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 609 610 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); 611 rewriter.setInsertionPoint(yieldOp); 612 613 SmallVector<Value> tiledValues; 614 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 615 ValueRange newRegionIterArgs = 616 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 617 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), 618 newRegionIterArgs, tiledValues, resultOffsets, 619 resultSizes))) { 620 rewriter.eraseOp(newLoop); 621 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); 622 } 623 624 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); 625 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : 626 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, 627 resultSizes)) { 628 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 629 rewriter.getIndexAttr(1)); 630 Value insert = rewriter.create<tensor::InsertSliceOp>( 631 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, 632 resultStride); 633 
newYieldValues.push_back(insert); 634 } 635 636 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 637 rewriter.replaceOp(loopOp, 638 newLoop->getResults().take_front(loopOp.getNumResults())); 639 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 640 } 641 642 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` 643 template <> 644 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( 645 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 646 YieldTiledValuesFn yieldTiledValuesFn) { 647 OpBuilder::InsertionGuard g(rewriter); 648 Location loc = loopOp.getLoc(); 649 rewriter.setInsertionPoint(loopOp); 650 auto inits = llvm::to_vector(loopOp.getOutputs()); 651 inits.append(newInitOperands.begin(), newInitOperands.end()); 652 auto newLoop = rewriter.create<scf::ForallOp>( 653 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), 654 loopOp.getMixedStep(), inits, loopOp.getMapping(), 655 [](OpBuilder &, Location, ValueRange) {}); 656 657 // Move the region of the current block to the newly created op. 658 Block *loopBody = loopOp.getBody(); 659 Block *newLoopBody = newLoop.getBody(); 660 rewriter.mergeBlocks( 661 loopBody, newLoopBody, 662 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 663 664 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); 665 rewriter.setInsertionPoint(terminator); 666 SmallVector<Value> tiledValues; 667 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 668 ValueRange regionIterArgs = 669 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 670 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), 671 regionIterArgs, tiledValues, resultOffsets, 672 resultSizes))) { 673 rewriter.eraseOp(newLoop); 674 return rewriter.notifyMatchFailure(loopOp, 675 "failed to get yielded tiled values"); 676 } 677 678 // Update the terminator. 
679 rewriter.setInsertionPointToEnd(terminator.getBody()); 680 681 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( 682 tiledValues, regionIterArgs, resultOffsets, resultSizes)) { 683 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 684 rewriter.getIndexAttr(1)); 685 rewriter.create<tensor::ParallelInsertSliceOp>( 686 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, 687 resultStride); 688 } 689 690 rewriter.replaceOp(loopOp, 691 newLoop->getResults().take_front(loopOp.getNumResults())); 692 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 693 } 694 695 /// Implementation of `yieldTiledValuesAndReplaceLoop` for 696 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each 697 /// supported loop type. 698 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( 699 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, 700 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { 701 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( 702 loopLikeOp.getOperation()) 703 .Case<scf::ForOp, scf::ForallOp>( 704 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 705 return yieldTiledValuesAndReplaceLoop( 706 loopOp, rewriter, newInitOperands, yieldTiledValuesFn); 707 }) 708 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 709 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 710 }); 711 } 712 713 /// Method to add new init values to a loop nest. Updates `loops` in-place with 714 /// new loops that use the `newInitValues`. 715 /// The outer-loops are updated to yield the new result values of the inner 716 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get 717 /// the additional values to yield form the innermost loop. 
718 static LogicalResult addInitOperandsToLoopNest( 719 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 720 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 721 SmallVector<scf::ForOp> newLoops; 722 if (loops.empty()) 723 return success(); 724 OpBuilder::InsertionGuard g(rewriter); 725 rewriter.setInsertionPoint(loops.front()); 726 727 SmallVector<Value> ivs; 728 for (auto &loop : loops.drop_back()) { 729 rewriter.setInsertionPoint(loop); 730 731 // if loops.size() > 1 we assume that scf.for is used for the loops. 732 auto forLoop = cast<scf::ForOp>(loop.getOperation()); 733 734 // Create a new loop with the new init values for this loop. 735 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 736 newInits.append(newInitValues.begin(), newInitValues.end()); 737 auto newLoop = rewriter.create<scf::ForOp>( 738 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 739 forLoop.getStep(), newInits, 740 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 741 742 // Merge the body of the new loop with the body of the old loops. 743 SmallVector<Value> sourceBlockArgs; 744 sourceBlockArgs.push_back(newLoop.getInductionVar()); 745 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 746 sourceBlockArgs.append( 747 newRegionIterArgs.begin(), 748 std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 749 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 750 rewriter.replaceOp( 751 forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 752 loop = newLoop; 753 ivs.push_back(newLoop.getInductionVar()); 754 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 755 } 756 757 // Update the loop body of the innermost loop to get new yield values. 
758 LoopLikeOpInterface innerMostLoop = loops.back(); 759 FailureOr<LoopLikeOpInterface> newInnerMostLoop = 760 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 761 getNewTiledYieldsFn); 762 763 if (failed(newInnerMostLoop)) 764 return innerMostLoop.emitOpError("failed to return additional yields"); 765 loops.back() = newInnerMostLoop.value(); 766 767 // Make all other loops except the innermost loops yield the values returned 768 // by the inner loop. 769 for (auto [outerLoop, innerLoop] : 770 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 771 // Again assume that all the outer loops are scf.for operations. 772 auto outerForLoop = cast<scf::ForOp>(outerLoop); 773 auto outerLoopYield = 774 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 775 SmallVector<Value> newYields = 776 llvm::to_vector(outerLoopYield.getOperands()); 777 ValueRange additionalYields = 778 innerLoop->getResults().take_back(newInitValues.size()); 779 newYields.append(additionalYields.begin(), additionalYields.end()); 780 rewriter.setInsertionPoint(outerLoopYield); 781 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 782 } 783 return success(); 784 } 785 786 /// Implementation of tiling transformation of `op` that implements the 787 /// `TilingInterface` using `scf.for` to iterate over the tiles. 788 FailureOr<scf::SCFTilingResult> 789 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 790 const scf::SCFTilingOptions &options) { 791 if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { 792 return failure(); 793 } 794 795 OpBuilder::InsertionGuard guard(rewriter); 796 rewriter.setInsertionPointAfter(op); 797 798 // 1. Get the range of the loops that are represented by the operation. 799 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 800 801 // 2. 
Materialize the tile sizes and/or number of threads; 802 SmallVector<OpFoldResult> tileSizes, numThreads; 803 std::tie(tileSizes, numThreads) = 804 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); 805 806 // Check if it is safe to tile. This is hold over from previous iterations 807 // of tile to for-all. Consider dropping it. 808 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 809 checkSafeToTileToForall(op, tileSizes, numThreads); 810 } 811 812 // 3. If there is an interchange specified, permute the iteration domain and 813 // the tile sizes. 814 SmallVector<int64_t> interchangeVector; 815 if (!options.interchangeVector.empty()) { 816 interchangeVector = fillInterchangeVector(options.interchangeVector, 817 iterationDomain.size()); 818 assert(isPermutationVector(interchangeVector) && 819 "expected interchange vector to be a permutation"); 820 821 applyPermutationToVector(iterationDomain, interchangeVector); 822 applyPermutationToVector(tileSizes, interchangeVector); 823 if (!numThreads.empty()) 824 applyPermutationToVector(numThreads, interchangeVector); 825 } 826 827 FailureOr<TilingResult> tilingResult; 828 // 4. Define the lambda function used later to generate the body of the 829 // innermost tiled loop. 830 YieldTiledValuesFn innerYieldTiledValuesFn = 831 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 832 ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 833 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 834 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 835 -> LogicalResult { 836 // 4a. Compute the `offsets` and `sizes` to use for tiling. 837 SmallVector<OpFoldResult> offsets, sizes; 838 std::tie(offsets, sizes) = getTileOffsetAndSizes( 839 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); 840 841 // 4b. If interchange was provided, apply inverse of the interchange 842 // to get back the offsets/sizes in the order to be specified. 
843 if (!interchangeVector.empty()) { 844 auto inversePermutation = invertPermutationVector(interchangeVector); 845 applyPermutationToVector(offsets, inversePermutation); 846 applyPermutationToVector(sizes, inversePermutation); 847 } 848 849 // 5. Generate the tiled implementation within the inner most loop. 850 851 // 5a. Clone the operation within the loop body. 852 auto clonedOp = cast<TilingInterface>( 853 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 854 855 // 5b. Early return cloned op if tiling is not happening. We can not return 856 // the original op because it could lead to 857 // `rewriter.replaceOp(op, op->getResults())` and users would get crash. 858 if (llvm::all_of(tileSizes, isZeroIndex)) { 859 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 860 tilingResult = 861 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), 862 /*generatedSlices=*/{}}; 863 return success(); 864 } 865 866 // 5c. Tile the cloned operation. 867 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); 868 if (failed(tilingResult)) { 869 rewriter.eraseOp(clonedOp); 870 return op.emitOpError("faild to tile operation"); 871 } 872 873 // 5d. Delete the cloned operation. 874 rewriter.eraseOp(clonedOp); 875 876 // 5e. Compute the offsets at which the result values are to be inserted 877 // back into its destinations. 
878 for (auto [index, tiledValue] : 879 llvm::enumerate(tilingResult->tiledValues)) { 880 tiledResults.push_back(tiledValue); 881 SmallVector<OpFoldResult> resultOffset, resultSize; 882 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, 883 resultOffset, resultSize))) { 884 for (auto op : tilingResult->tiledOps) { 885 rewriter.eraseOp(op); 886 } 887 return rewriter.notifyMatchFailure( 888 op, "failed to get slice of result produced"); 889 } 890 resultOffsets.emplace_back(std::move(resultOffset)); 891 resultSizes.emplace_back(std::move(resultSize)); 892 } 893 894 return success(); 895 }; 896 897 // 6. Find the destination tensors to use for the operation. 898 SmallVector<Value> destinationTensors; 899 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 900 destinationTensors))) { 901 return rewriter.notifyMatchFailure(op, 902 "unable to create destination tensors"); 903 } 904 905 // 7. Generate the tiled loops nest using the callback defined above. 906 SmallVector<LoopLikeOpInterface> loops; 907 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 908 tileSizes, numThreads, destinationTensors, 909 innerYieldTiledValuesFn, loops))) 910 return op.emitOpError("failed to generate tiling loops"); 911 assert(succeeded(tilingResult) && 912 "expected tiling result to be computed after loop generation"); 913 914 // If loops are empty, the tiled op is used as the replacement for the untiled 915 // op. 
916 if (loops.empty()) { 917 return scf::SCFTilingResult{tilingResult->tiledOps, loops, 918 tilingResult->tiledValues, 919 tilingResult->generatedSlices}; 920 } 921 922 SmallVector<Value> replacements = llvm::map_to_vector( 923 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 924 return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements, 925 tilingResult->generatedSlices}; 926 } 927 928 FailureOr<scf::SCFReductionTilingResult> 929 mlir::scf::tileReductionUsingScf(RewriterBase &b, 930 PartialReductionOpInterface op, 931 ArrayRef<OpFoldResult> tileSizes) { 932 Location loc = op.getLoc(); 933 // Ops implementing PartialReductionOpInterface are expected to implement 934 // TilingInterface. 935 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 936 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 937 auto tileSizesVector = llvm::to_vector(tileSizes); 938 if (tileSizesVector.size() < iterationDomain.size()) { 939 auto zero = b.getIndexAttr(0); 940 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 941 zero); 942 } 943 SmallVector<utils::IteratorType> iterators = 944 tilingInterfaceOp.getLoopIteratorTypes(); 945 946 SmallVector<int> reductionDims; 947 for (auto [idx, iteratorType] : 948 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 949 if (iteratorType == utils::IteratorType::reduction) 950 reductionDims.push_back(idx); 951 } 952 953 // 2. create the inital tensor value. 954 FailureOr<SmallVector<Value>> maybeInitTensors = 955 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 956 reductionDims); 957 if (failed(maybeInitTensors)) { 958 return b.notifyMatchFailure(op, "Failed to create initial tensors."); 959 } 960 SmallVector<Value> &initTensors = maybeInitTensors.value(); 961 962 // 3. Define the callback to use for generating the inner most tile loop body. 
963 SmallVector<Operation *> parallelTiledOps; 964 auto innerYieldTiledValuesFn = 965 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 966 ValueRange regionIterArgs, SmallVector<Value> &tiledResult, 967 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 968 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 969 -> LogicalResult { 970 SmallVector<OpFoldResult> offsets, sizes; 971 { 972 int materializedLoopNum = 0; 973 for (auto [tileSize, loopRange] : 974 llvm::zip_equal(tileSizesVector, iterationDomain)) { 975 if (isConstantIntValue(tileSize, 0)) { 976 offsets.push_back(loopRange.offset); 977 sizes.push_back(loopRange.size); 978 continue; 979 } 980 Value iv = ivs[materializedLoopNum++]; 981 offsets.push_back(iv); 982 sizes.push_back( 983 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 984 } 985 } 986 987 // 4a. Clone the operation. 988 { 989 auto clonedOp = cast<PartialReductionOpInterface>( 990 cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); 991 992 // 4b. Tile the cloned operation. 993 FailureOr<TilingResult> partialTilingResult = 994 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, 995 sizes, reductionDims); 996 if (failed(partialTilingResult)) { 997 return failure(); 998 } 999 std::swap(parallelTiledOps, partialTilingResult->tiledOps); 1000 std::swap(tiledResult, partialTilingResult->tiledValues); 1001 1002 // 4c. Delete the cloned operation. 1003 b.eraseOp(clonedOp); 1004 } 1005 1006 // 4d. Compute the offsets and sizes needed to insert the result of the 1007 // tiled value back into destination before yielding the destination. 
1008 for (auto result : tiledResult) { 1009 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 1010 resultOffsets.emplace_back(std::move(outOffsets)); 1011 1012 SmallVector<OpFoldResult> outSizes; 1013 for (size_t i = 0; i < offsets.size(); i++) { 1014 outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); 1015 } 1016 resultSizes.emplace_back(std::move(outSizes)); 1017 } 1018 return success(); 1019 }; 1020 1021 // 5. Generate the tiled implementation using the destination tensors. 1022 SmallVector<LoopLikeOpInterface> loops; 1023 scf::SCFTilingOptions options; 1024 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); 1025 if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, 1026 /*numThreads=*/ArrayRef<OpFoldResult>{}, 1027 initTensors, innerYieldTiledValuesFn, loops))) 1028 return b.notifyMatchFailure(op, "failed to tile for parallel reduction"); 1029 1030 SmallVector<Value> replacements = llvm::map_to_vector( 1031 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 1032 1033 // 5. Apply the merge reduction to combine all the partial values. 1034 b.setInsertionPointAfter(*loops.begin()); 1035 FailureOr<MergeResult> mergeResult = 1036 op.mergeReductions(b, loc, replacements, reductionDims); 1037 if (failed(mergeResult)) { 1038 return failure(); 1039 } 1040 b.replaceOp(op, mergeResult->replacements); 1041 1042 SCFReductionTilingResult reductionTilingResult; 1043 std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); 1044 std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); 1045 std::swap(reductionTilingResult.initialValues, initTensors); 1046 std::swap(reductionTilingResult.loops, loops); 1047 std::swap(reductionTilingResult.replacements, mergeResult->replacements); 1048 1049 return reductionTilingResult; 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // tileConsumerAndFuseProducersUsingSCF implementation. 
1054 //===----------------------------------------------------------------------===// 1055 1056 /// Return the untiled producer whose slice is used in a tiled consumer. The 1057 /// method traverses the tile loop nest (`loops`) if needed, and returns the 1058 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 1059 /// indicates that this is a destination operand of the consumer. If there was 1060 /// no loop traversal needed, the second value of the returned tuple is empty. 1061 static std::tuple<OpResult, std::optional<OpOperand *>> 1062 getUntiledProducerFromSliceSource(OpOperand *source, 1063 ArrayRef<LoopLikeOpInterface> loops) { 1064 std::optional<OpOperand *> destinationIterArg; 1065 auto loopIt = loops.rbegin(); 1066 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 1067 auto loop = *loopIt; 1068 if (iterArg.getOwner()->getParentOp() != loop) 1069 break; 1070 source = loop.getTiedLoopInit(iterArg); 1071 loopIt++; 1072 } 1073 if (loopIt == loops.rend()) 1074 destinationIterArg = source; 1075 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 1076 } 1077 1078 /// Implementation of fusing producer of a single slice by computing the 1079 /// slice of the producer in-place. 1080 std::optional<scf::SCFFuseProducerOfSliceResult> 1081 mlir::scf::tileAndFuseProducerOfSlice( 1082 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 1083 MutableArrayRef<LoopLikeOpInterface> loops) { 1084 // 1. Get the producer of the source (potentially walking through 1085 // `iter_args` of nested `scf.for`) 1086 auto [fusableProducer, destinationInitArg] = 1087 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1088 loops); 1089 if (!fusableProducer) 1090 return std::nullopt; 1091 unsigned resultNumber = fusableProducer.getResultNumber(); 1092 1093 OpBuilder::InsertionGuard g(rewriter); 1094 rewriter.setInsertionPoint(candidateSliceOp); 1095 1096 // 2. Clone the fused producer 1097 // 2a. 
Compute the destination operands to use for the cloned operation. 1098 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 1099 Operation *fusableProducerOp = fusableProducer.getOwner(); 1100 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 1101 failed(tensor::getOrCreateDestinations( 1102 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 1103 origDestinationTensors))) 1104 return std::nullopt; 1105 1106 clonedOpDestinationTensors = origDestinationTensors; 1107 if (destinationInitArg && 1108 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 1109 // 2b. If the producer is also destination style, then to maintain the 1110 // destination passing style, update the destination of the producer to be 1111 // the source of the slice. 1112 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 1113 } 1114 // 2c. Clone the fused producer. 1115 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 1116 rewriter, fusableProducerOp, clonedOpDestinationTensors); 1117 // 2d. Update the source of the candidateSlice to be the cloned producer. 1118 // Easier to just clone the slice with different source since replacements 1119 // and DCE of cloned ops becomes easier 1120 SmallVector<Value> candidateSliceOpOperands = 1121 llvm::to_vector(candidateSliceOp->getOperands()); 1122 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 1123 tensor::ExtractSliceOp clonedCandidateSliceOp = 1124 mlir::clone(rewriter, candidateSliceOp, 1125 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 1126 1127 // 3. 
Generate the tiled implementation of the producer of the source 1128 FailureOr<TilingResult> tileAndFuseResult = 1129 tensor::replaceExtractSliceWithTiledProducer( 1130 rewriter, clonedCandidateSliceOp, 1131 clonedProducerOp->getResult(resultNumber)); 1132 if (failed(tileAndFuseResult)) 1133 return std::nullopt; 1134 // Note: Do not delete the candidateSliceOp, since its passed in from the 1135 // caller. 1136 rewriter.replaceAllUsesWith(candidateSliceOp, 1137 tileAndFuseResult->tiledValues[0]); 1138 rewriter.eraseOp(clonedCandidateSliceOp); 1139 rewriter.eraseOp(clonedProducerOp); 1140 1141 // 3. If the slice is for a destination operand, for example, 1142 // 1143 // ```mlir 1144 // %0 = linalg.init 1145 // %1 = linalg.fill .. outs(%0 : ) 1146 // %2 = scf.for .. iter_args(%arg0 = %1) { 1147 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1148 // %4 = tensor.extract_slice %arg1 [..] 1149 // .. = linalg.matmul .. outs(%4 : ) 1150 // } 1151 // } 1152 // ``` 1153 // 1154 // the IR is currently 1155 // 1156 // ``` 1157 // %0 = linalg.init 1158 // %1 = linalg.fill 1159 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 1160 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1161 // %4 = tensor.extract_slice %arg1[..] 1162 // %5 = linalg.fill .. outs(%4 : ) 1163 // .. = linalg.matmul .. outs(%5 : ) 1164 // } 1165 // } 1166 // ``` 1167 // 1168 // The untiled `linalg.fill` is still used as the `init_value` since it 1169 // was originally a destination operand of the untiled `linalg.matmul`. 1170 // When fusing an operand that is a destination operand, the iter_arg of 1171 // the outer most loop should be changed to use the destination of the 1172 // fused operation. With this the IR will be. 1173 // 1174 // ``` 1175 // %0 = linalg.init 1176 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 1177 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 1178 // %3 = tensor.extract_slice %arg1[..] 1179 // %4 = linalg.fill .. outs(%3 : ) 1180 // .. 
= linalg.matmul .. outs(%4 : ) 1181 // } 1182 // } 1183 // ``` 1184 if (destinationInitArg && 1185 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 1186 loops.front() 1187 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 1188 .set(origDestinationTensors[resultNumber]); 1189 } 1190 return scf::SCFFuseProducerOfSliceResult{ 1191 fusableProducer, tileAndFuseResult->tiledValues[0], 1192 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; 1193 } 1194 1195 /// Reconstruct the fused producer from within the tiled-and-fused code. 1196 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( 1197 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 1198 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 1199 MutableArrayRef<LoopLikeOpInterface> loops, 1200 ArrayRef<unsigned> yieldResultNumber) { 1201 if (loops.empty()) 1202 return success(); 1203 1204 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), 1205 *tiledOwner = fusedProducerInfo.tiledOps[0]; 1206 1207 Location loc = originalOwner->getLoc(); 1208 // a. collect all init Value to be appended 1209 SmallVector<unsigned> initNumberList = 1210 yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq<unsigned>( 1211 0, originalOwner->getNumResults())) 1212 : llvm::to_vector(yieldResultNumber); 1213 SmallVector<Value> initValueList; 1214 for (const auto &resultNumber : initNumberList) { 1215 FailureOr<Value> initValue = tensor::getOrCreateDestination( 1216 rewriter, loc, originalOwner->getResult(resultNumber)); 1217 if (succeeded(initValue)) { 1218 initValueList.push_back(initValue.value()); 1219 } else { 1220 return failure(); 1221 } 1222 } 1223 1224 SmallVector<Operation *> generatedSlices; 1225 YieldTiledValuesFn newYieldValuesFn = 1226 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 1227 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 1228 SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 1229 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 1230 OpBuilder::InsertionGuard g(innerRewriter); 1231 1232 // get sliceOp tile information 1233 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), 1234 sliceSizes = sliceOp.getMixedSizes(); 1235 1236 // expect all strides of sliceOp being 1 1237 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 1238 return !isConstantIntValue(ofr, 1); 1239 })) 1240 return failure(); 1241 1242 unsigned sliceResultNumber = 1243 fusedProducerInfo.origProducer.getResultNumber(); 1244 1245 auto tilableOp = cast<TilingInterface>(originalOwner); 1246 // b. get iterDomain Offset and Sizes based on sliceOp tile 1247 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; 1248 // skip tensor.pack/unpack/pad, which expects single opResult 1249 if (tilableOp->getNumResults() > 1 && 1250 failed(tilableOp.getIterationDomainTileFromResultTile( 1251 rewriter, sliceResultNumber, sliceOffset, sliceSizes, 1252 iterDomainOffset, iterDomainSizes))) { 1253 // In theory, it is unnecessary to raise an error here. Actually although 1254 // it fails to reconstruct the result tensor, it should not broke current 1255 // fusion anyway. 
The reason why we must return failure currently is that 1256 // the callback function `newYieldValuesFn` will be called after new init 1257 // operand(s) has already been appended. It will take more refactoring to 1258 // make sure the init operands are added consistently in the future. For 1259 // more details, please refer to: 1260 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 1261 return failure(); 1262 } 1263 1264 // c. calculate offsets and sizes info of all OpResults respectively based 1265 // on iteration Domain Tile 1266 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; 1267 for (const auto &resultNumber : initNumberList) { 1268 if (resultNumber == sliceResultNumber) { 1269 offsetList.push_back(sliceOffset); 1270 sizesList.push_back(sliceSizes); 1271 } else { 1272 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); 1273 // infer result tile according to the iteration domain tile 1274 SmallVector<OpFoldResult> offset, sizes; 1275 if (failed(tilableOp.getResultTilePosition( 1276 rewriter, resultNumber, iterDomainOffset, iterDomainSizes, 1277 offset, sizes))) { 1278 return failure(); 1279 } 1280 offsetList.push_back(offset); 1281 sizesList.push_back(sizes); 1282 } 1283 } 1284 1285 // d. 
create `extract_slice` for `iter_args` for DPS operation if necessary 1286 if (auto tiledDestStyleOp = 1287 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { 1288 rewriter.setInsertionPoint(tiledDestStyleOp); 1289 for (const auto &&[index, newRegionArg] : 1290 llvm::enumerate(newRegionIterArgs)) { 1291 auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 1292 loc, newRegionArg, offsetList[index], sizesList[index], 1293 SmallVector<OpFoldResult>(offsetList[index].size(), 1294 rewriter.getIndexAttr(1))); 1295 generatedSlices.push_back(destSlice); 1296 unsigned resultNumber = initNumberList[index]; 1297 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 1298 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 1299 }); 1300 } 1301 } 1302 1303 // e. prepare tiled offset and sizes for later `insert_slice` creation by 1304 // caller 1305 Block *block = rewriter.getInsertionPoint()->getBlock(); 1306 rewriter.setInsertionPoint(block->getTerminator()); 1307 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { 1308 tiledResult.push_back(tiledOwner->getResult(resultNumber)); 1309 tiledOffset.emplace_back(offsetList[index]); 1310 tiledSizes.emplace_back(sizesList[index]); 1311 } 1312 return success(); 1313 }; 1314 1315 if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, 1316 newYieldValuesFn))) { 1317 return failure(); 1318 } 1319 return generatedSlices; 1320 } 1321 1322 namespace { 1323 1324 //===----------------------------------------------------------------------===// 1325 // SliceTrackingListener 1326 //===----------------------------------------------------------------------===// 1327 1328 /// This class is a listener for tracking the insertion and removal of 1329 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy 1330 /// fusion algorithm to apply cleanup patterns in between fusion steps. 
1331 class SliceTrackingListener : public RewriterBase::Listener { 1332 public: 1333 explicit SliceTrackingListener( 1334 std::optional<FrozenRewritePatternSet> patterns); 1335 SliceTrackingListener() = default; 1336 1337 /// Adds the given list of operations to the worklist, and if present, applies 1338 /// the list of `patterns` to the newly added operations. This only processes 1339 /// the given operations and any newly inserted ones by the pattern set. 1340 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); 1341 1342 /// Add to the new operation worklist if it is an extract_slice. 1343 void notifyOperationInserted(Operation *op, 1344 OpBuilder::InsertPoint previous) override; 1345 1346 /// Shared helper for operation removal from the worklist. 1347 void removeOp(Operation *op); 1348 1349 /// Remove the operation from the worklist. 1350 void notifyOperationErased(Operation *op) override; 1351 1352 /// Remove the operation from the worklist. 1353 void notifyOperationReplaced(Operation *op, ValueRange replacement) override; 1354 1355 /// The worklist for this transformation keeps track of the slices to visit 1356 /// next for fusion. 1357 std::deque<tensor::ExtractSliceOp> worklist; 1358 1359 private: 1360 /// Optional pattern set to apply when adding new operations to the worklist. 
1361 std::optional<FrozenRewritePatternSet> patterns = std::nullopt; 1362 }; 1363 1364 SliceTrackingListener::SliceTrackingListener( 1365 std::optional<FrozenRewritePatternSet> p) { 1366 patterns = std::move(p); 1367 } 1368 1369 LogicalResult 1370 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { 1371 for (Operation *op : ops) { 1372 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) 1373 worklist.push_back(slice); 1374 } 1375 1376 if (!patterns) 1377 return success(); 1378 1379 GreedyRewriteConfig config; 1380 config.listener = this; 1381 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 1382 return applyOpPatternsAndFold(ops, patterns.value(), config); 1383 } 1384 1385 void SliceTrackingListener::notifyOperationInserted( 1386 Operation *op, OpBuilder::InsertPoint previous) { 1387 auto slice = dyn_cast<tensor::ExtractSliceOp>(op); 1388 if (!slice) 1389 return; 1390 worklist.push_back(slice); 1391 } 1392 1393 // Scan the worklist for the given op and remove it if present. The expectation 1394 // is for the worklist to be small and for removal to be relatively rare. 1395 void SliceTrackingListener::removeOp(Operation *op) { 1396 if (!isa<tensor::ExtractSliceOp>(op)) 1397 return; 1398 auto iter = worklist.begin(); 1399 while (iter != worklist.end()) { 1400 if (*iter == op) 1401 break; 1402 iter++; 1403 } 1404 if (iter == worklist.end()) 1405 return; 1406 1407 worklist.erase(iter); 1408 } 1409 1410 void SliceTrackingListener::notifyOperationErased(Operation *op) { 1411 removeOp(op); 1412 } 1413 1414 void SliceTrackingListener::notifyOperationReplaced(Operation *op, 1415 ValueRange replacement) { 1416 removeOp(op); 1417 } 1418 } // namespace 1419 1420 /// Implementation of tile consumer and fuse producer greedily. 
1421 FailureOr<scf::SCFTileAndFuseResult> 1422 mlir::scf::tileConsumerAndFuseProducersUsingSCF( 1423 RewriterBase &rewriter, TilingInterface consumer, 1424 const scf::SCFTileAndFuseOptions &options) { 1425 // This transformation is only valid for ops that return values (i.e. not 1426 // valid to use with operations that have memref operands). 1427 if (!consumer->getNumResults()) { 1428 return rewriter.notifyMatchFailure( 1429 consumer, "invalid pattern for op with no results"); 1430 } 1431 1432 // 1. First tile the consumer. 1433 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 1434 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 1435 1436 FailureOr<scf::SCFTilingResult> tilingResult = 1437 tileUsingSCF(rewriter, consumer, options.tilingOptions); 1438 1439 if (failed(tilingResult)) 1440 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1441 for (auto *tiledOp : tilingResult->tiledOps) 1442 tiledAndFusedOps.insert(tiledOp); 1443 1444 // If there are no loops generated, fusion is immaterial. 1445 auto &loops = tilingResult->loops; 1446 if (loops.empty()) { 1447 DenseMap<Value, Value> replacements; 1448 for (auto [origVal, replacement] : 1449 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 1450 replacements[origVal] = replacement; 1451 } 1452 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1453 replacements}; 1454 } 1455 1456 // To keep track of replacements for now just record the map from the original 1457 // untiled value to the result number of the for loop. Since the loop gets 1458 // potentially replaced during fusion, keeping the value directly wont work. 1459 DenseMap<Value, size_t> origValToResultNumber; 1460 for (auto [index, result] : llvm::enumerate(consumer->getResults())) { 1461 origValToResultNumber[result] = index; 1462 } 1463 1464 // 2. Typically, the operands of the tiled operation are slices of the 1465 // operands of the untiled operation. 
These are expressed in IR using 1466 // `tensor.extract_slice` operations with source being the operands of the 1467 // untiled operation. Create a worklist of these `tensor.extract_slice` 1468 // operations. If the producers of the source of the `tensor.extract_slice` 1469 // can be tiled such that the tiled value is generated in-place, that 1470 // effectively tiles + fuses the operations. 1471 struct WorklistItem { 1472 tensor::ExtractSliceOp candidateSlice; 1473 SCFTileAndFuseOptions::ControlFnResult controlFnResult; 1474 }; 1475 1476 SliceTrackingListener sliceTracker = 1477 SliceTrackingListener(options.cleanupPatterns); 1478 1479 if (failed( 1480 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { 1481 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 1482 } 1483 OpBuilder::InsertionGuard g(rewriter); 1484 while (!sliceTracker.worklist.empty()) { 1485 auto candidateSlice = sliceTracker.worklist.front(); 1486 sliceTracker.worklist.pop_front(); 1487 1488 auto [fusableProducer, destinationInitArg] = 1489 getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), 1490 loops); 1491 if (!fusableProducer) 1492 continue; 1493 1494 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = 1495 options.fusionControlFn(candidateSlice, fusableProducer, 1496 destinationInitArg.has_value()); 1497 if (!controlFnResult) 1498 continue; 1499 1500 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; 1501 1502 // The operands of the fused producer might themselved be slices of 1503 // values produced by operations that implement the `TilingInterface`. 1504 // Add these operations to the worklist. 
1505 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 1506 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, 1507 loops); 1508 if (!fusedResult) 1509 continue; 1510 1511 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; 1512 1513 if (worklistItem.controlFnResult.yieldProducerReplacement) { 1514 // Reconstruct and yield all opResult of fusableProducerOp by default. The 1515 // caller can specific which one to yield by designating optional argument 1516 // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. 1517 Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); 1518 FailureOr<SmallVector<Operation *>> newSlices = 1519 yieldReplacementForFusedProducer(rewriter, 1520 worklistItem.candidateSlice, 1521 fusedResult.value(), loops); 1522 if (failed(newSlices)) { 1523 return rewriter.notifyMatchFailure( 1524 fusableProducerOp, "failed to replacement value for this " 1525 "operation from within the tiled loop"); 1526 } 1527 worklistCandidates.append(newSlices.value()); 1528 for (auto [index, result] : 1529 llvm::enumerate(fusableProducerOp->getResults())) { 1530 origValToResultNumber[result] = loops.front()->getNumResults() - 1531 fusableProducerOp->getNumResults() + 1532 index; 1533 } 1534 } 1535 if (Operation *tiledAndFusedOp = 1536 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 1537 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 1538 tiledAndFusedOps.insert(tiledAndFusedOp); 1539 } 1540 1541 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { 1542 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 1543 } 1544 } 1545 1546 DenseMap<Value, Value> replacements; 1547 for (auto [origVal, resultNumber] : origValToResultNumber) { 1548 replacements[origVal] = loops.front()->getResult(resultNumber); 1549 } 1550 1551 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1552 replacements}; 1553 } 1554 1555 
//===----------------------------------------------------------------------===// 1556 // tileAndFuseConsumerUsingSCF implementation. 1557 //===----------------------------------------------------------------------===// 1558 1559 /// A utility function that checks whether the only use of the result of a 1560 /// tensor.insert_slice op is in a scf.yield op. 1561 static LogicalResult 1562 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { 1563 Value result = candidateSliceOp.getResult(); 1564 Value::use_range uses = result.getUses(); 1565 if (!llvm::hasSingleElement(uses)) { 1566 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); 1567 return failure(); 1568 } 1569 OpOperand &operandUse = (*uses.begin()); 1570 Operation *userOp = operandUse.getOwner(); 1571 if (!isa<scf::YieldOp>(userOp)) { 1572 LLVM_DEBUG(llvm::dbgs() 1573 << "Expected scf.yield to be the only user, but got -> " 1574 << (*userOp)); 1575 return failure(); 1576 } 1577 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { 1578 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " 1579 "be in the same block\n"); 1580 return failure(); 1581 } 1582 return success(); 1583 } 1584 1585 /// An utility to get the first user of the given loopOp. If any of user stay in 1586 /// different block of loopOp, return failure. 1587 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) { 1588 if (!isa<LoopLikeOpInterface>(loopOp)) 1589 return failure(); 1590 Operation *firstUserOfLoop = nullptr; 1591 for (Operation *userOp : loopOp->getUsers()) { 1592 // `ParallelInsertSlice` located inside `InParallelOp` has no same parent 1593 // block with any other types of operation. Thus, just redirecting to its 1594 // parent `InParallelOp`. E.g. 1595 // 1596 // ``` 1597 // %1 = scf.for { 1598 // ... 1599 // } 1600 // %2 = consumerOp ins(%1, ...) 
1601 // scf.forall.in_parallel { 1602 // tensor.parallel_insert_slice %1 1603 // } 1604 // ``` 1605 // where `InParallelOp` but not `ParallelInsertSlice` stays in the same 1606 // same block with `consumerOp`. 1607 if (isa<tensor::ParallelInsertSliceOp>(userOp)) 1608 userOp = userOp->getParentOfType<scf::InParallelOp>(); 1609 1610 if (loopOp->getBlock() != userOp->getBlock()) 1611 return failure(); 1612 1613 if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop)) 1614 firstUserOfLoop = userOp; 1615 } 1616 return firstUserOfLoop; 1617 } 1618 1619 /// This utility currently checks whether the first userOp of loop is NOT before 1620 /// the last defineOp of consumer operand. Because that we need to move the 1621 /// whole loop structure right before the `firstUserOfLoop`. This utility thus 1622 /// helps ensuring that no invalid IR is formed, i.e. no backward slice of 1623 /// consumerOp is dominated by the `firstUserOfLoop`. Saying that: 1624 /// 1625 /// ``` 1626 /// %0 = scf.for() { 1627 /// ... 1628 /// } 1629 /// ... 1630 /// %1 = firstUserOfLoop(%0) 1631 /// ... 1632 /// %2 = lastDefOfConsumerOperand 1633 /// ... 1634 /// %3 = consumerOp(%2) 1635 /// ``` 1636 /// 1637 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would 1638 /// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a. 1639 /// use-def chain violation: 1640 /// 1641 /// ``` 1642 /// %0:2 = scf.for() { 1643 /// // use before define error 1644 /// %3 = tiledConsumerOp(%2) 1645 /// } 1646 /// %1 = firstUserOfLoop(%0) 1647 /// ... 1648 /// %2 = lastDefOfConsumerOperand 1649 /// ``` 1650 /// 1651 /// @param loopOp: loop operation 1652 /// @param consumerOp: consumer operation 1653 /// @param reorderOperations: the flag controls whether to reorder the backward 1654 /// slice w.r.t. the defineOp of `consumerOp` operands. 1655 /// @return: computed backward slice of consumerOp, but excluding those already 1656 /// dominates `firstUserOfLoop`. 
1657 static FailureOr<llvm::SetVector<Operation *>> 1658 checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, 1659 bool reorderOperations) { 1660 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); 1661 if (failed(firstUserOfLoop)) 1662 return failure(); 1663 1664 BackwardSliceOptions options; 1665 DominanceInfo dominanceInfo; 1666 options.inclusive = true; 1667 options.omitBlockArguments = true; 1668 bool includeLoopOp = false; 1669 options.filter = [&](Operation *op) { 1670 if (op == loopOp) { 1671 includeLoopOp = true; 1672 return false; 1673 } 1674 // Cut off the slice to not include any operation that already dominates 1675 // firstUserOfLoop. 1676 return !dominanceInfo.properlyDominates(op, *firstUserOfLoop); 1677 }; 1678 llvm::SetVector<Operation *> slice; 1679 for (auto operand : consumerOp->getOperands()) { 1680 getBackwardSlice(operand, &slice, options); 1681 } 1682 1683 if (!slice.empty()) { 1684 // If consumerOp has one producer, which is also the user of loopOp. 1685 // E.g. 1686 // ``` 1687 // %0 = %loopOp 1688 // %1 = consumerOp1 ins(%0) 1689 // %2 = consumerOp2 ins(%0, %1) 1690 // ``` 1691 // We can not fuse consumerOp2 into loopOp due to UD chain, unless 1692 // consumerOp1 has already been fused into loopOp before. 1693 if (includeLoopOp || !reorderOperations) 1694 return failure(); 1695 } 1696 1697 return slice; 1698 } 1699 1700 /// Fetches the OpOperand of the first valid user (and use) of the value `val` 1701 /// which implements `TilingInterface` and `DestinationStyleOpInterface`. 1702 /// Returns failure otherwise. 
1703 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, 1704 Operation *loopOp, 1705 unsigned resultNumber) { 1706 if (!isa<LoopLikeOpInterface>(loopOp)) 1707 return failure(); 1708 Value val = loopOp->getResult(resultNumber); 1709 Block *loopBlock = loopOp->getBlock(); 1710 for (OpOperand &opOperand : val.getUses()) { 1711 Operation *consumerOp = opOperand.getOwner(); 1712 // Step 1. Check if the user is tilable. 1713 if (!isa<TilingInterface>(consumerOp) || 1714 !isa<DestinationStyleOpInterface>(consumerOp)) { 1715 // TODO: We have to init result of consumer before scf.for, use 1716 // DestinationStyleOpInterface to get result shape from init for now. Add 1717 // support for other op such as op has InferTypeOpInterface. 1718 continue; 1719 } 1720 // Step 2. Check if user stay in the same block. 1721 if (loopBlock != consumerOp->getBlock()) 1722 continue; 1723 // Step 3. Check if user has succeeding user. Otherwise, it usually 1724 // represents already tiled. 1725 if (consumerOp->use_empty()) 1726 continue; 1727 // Step 4. Check assumption for loop with `reorderOperations` enabled. 1728 FailureOr<llvm::SetVector<Operation *>> slice = 1729 checkAssumptionForLoop(loopOp, consumerOp, true); 1730 if (failed(slice)) 1731 continue; 1732 // Step 5. If backward sice is not empty, move them before firstUserOfLoop. 1733 if (!slice->empty()) { 1734 mlir::topologicalSort(*slice); 1735 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); 1736 assert(succeeded(firstUserOfLoop) && "First user of loop is not found"); 1737 for (auto op : *slice) { 1738 rewriter.moveOpBefore(op, *firstUserOfLoop); 1739 } 1740 } 1741 return &opOperand; 1742 } 1743 return failure(); 1744 } 1745 1746 /// Find the perfectly nested loops outside of given loop(included) sorted from 1747 /// outer to inner. 1748 /// 1749 /// E.g. 1750 /// 1751 /// ``` 1752 /// %0 = scf.for() 1753 /// %1 = scf.for() 1754 /// %2 = scf.for() 1755 /// %3 = ... 
1756 /// yield %3 1757 /// yield %2 1758 /// yield %1 1759 /// ``` 1760 /// 1761 /// This function will return three perfectly nested loops: %0 + %1 + %2, when 1762 /// target inner loop is %2. 1763 static SmallVector<scf::ForOp> 1764 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { 1765 SmallVector<scf::ForOp> nestLoops = {loop}; 1766 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp()); 1767 1768 // Check if it is the ForOp that yield the result of inner loop. 1769 auto isForOpYieldResultOfInnerLoop = 1770 [](scf::ForOp outerLoop) -> LogicalResult { 1771 Block *body = outerLoop.getBody(); 1772 if (!llvm::hasSingleElement(body->without_terminator())) 1773 return failure(); 1774 auto yieldOp = cast<scf::YieldOp>(body->getTerminator()); 1775 auto innerForOp = dyn_cast<scf::ForOp>(body->front()); 1776 if (!innerForOp) 1777 return failure(); 1778 // All of innerForOp results should be yielded. 1779 return success(innerForOp->getNumResults() == yieldOp->getNumOperands()); 1780 }; 1781 1782 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) { 1783 nestLoops.push_back(outerLoop); 1784 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp()); 1785 } 1786 // sorted from outer to inner 1787 return {nestLoops.rbegin(), nestLoops.rend()}; 1788 } 1789 1790 /// Fetch the untiled consumer of a scf.for's result which is yielded by a 1791 /// tensor.insert_slice. This function makes the following assumptions : 1792 /// 1. tensor.insert_slice has scf.yield as its only user. 1793 /// 2. scf.for's corresponding result has only one use. 1794 static FailureOr<OpOperand *> 1795 getUntiledConsumerFromSlice(RewriterBase &rewriter, 1796 tensor::InsertSliceOp candidateSliceOp) { 1797 if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) 1798 return failure(); 1799 Value sliceResult = candidateSliceOp.getResult(); 1800 // Step 1. Fetch the corresponding output. 
1801 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); 1802 unsigned resultNumber = yieldOpOperand.getOperandNumber(); 1803 // Step 2. Check containing op is scf.for. 1804 Operation *containingOp = candidateSliceOp->getParentOp(); 1805 auto forOp = dyn_cast<scf::ForOp>(containingOp); 1806 if (!forOp) 1807 return failure(); 1808 scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); 1809 1810 return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); 1811 } 1812 1813 /// Fetch the first untiled consumer of a scf.forall's result which is yielded 1814 /// by a tensor.parallel_insert_slice. 1815 static FailureOr<OpOperand *> 1816 getUntiledConsumerFromSlice(RewriterBase &rewriter, 1817 tensor::ParallelInsertSliceOp candidateSliceOp) { 1818 // Step 1. Fetch the corresponding output 1819 Value sliceDest = candidateSliceOp.getDest(); 1820 auto iterArg = dyn_cast<BlockArgument>(sliceDest); 1821 if (!iterArg) 1822 return failure(); 1823 Operation *containingOp = iterArg.getOwner()->getParentOp(); 1824 if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) 1825 return failure(); 1826 // Step 2. Check that the containing op is scf.forall. 1827 auto forallOp = dyn_cast<scf::ForallOp>(containingOp); 1828 if (!forallOp) 1829 return failure(); 1830 unsigned resultNumber = 1831 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) 1832 .getResultNumber(); 1833 1834 return getConsumerFromLoopUses(rewriter, containingOp, resultNumber); 1835 } 1836 1837 /// A utility to fetch an untiled consumer of 1838 /// tensor.insert_slice/tensor.parallel_insert_slice. 
/// Dispatches on the concrete type of `sliceOp`; any operation other than
/// `tensor.insert_slice` / `tensor.parallel_insert_slice` yields failure.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(rewriter, insertSlice);
  } else if (auto parallelInsertSlice =
                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
  } else {
    return failure();
  }
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
///
/// `candidateSliceOp` must be a `tensor.insert_slice` or a
/// `tensor.parallel_insert_slice`; anything else returns failure immediately.
/// On success the returned struct holds the operand of the original consumer
/// that was fused, the matching operand of the first tiled op, and the list of
/// tiled ops.
///
/// Note that the IR is mutated (the loop is moved, ops are cloned) before
/// several of the later failure returns.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  // The value consumed must be an op result; its result number identifies
  // which loop result feeds the consumer.
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  // There are two possible cases regarding the enclosing loop here:
  // 1. single `scf.forall` or `scf.for`.
  // 2. inner-most `scf.for` inside a nest of `scf.for` loops, where the
  // top-level loop is the outer-most one of these nested loops.
  LoopLikeOpInterface innerMostLoop =
      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
  SmallVector<LoopLikeOpInterface> nestedLoops;
  if (isInsertSliceOp) {
    // `tensor.insert_slice` case: collect the perfect `scf.for` nest, sorted
    // from outer to inner.
    nestedLoops = llvm::map_to_vector(
        getPerfectlyNestedLoopsOutsideOf(
            cast<scf::ForOp>(innerMostLoop.getOperation())),
        [](scf::ForOp forOp) {
          return cast<LoopLikeOpInterface>(forOp.getOperation());
        });
  } else {
    // `tensor.parallel_insert_slice` case: only the enclosing loop is used.
    nestedLoops = {innerMostLoop};
  }

  LoopLikeOpInterface outerMostLoop = nestedLoops.front();

  // Check assumption for loop with `reorderOperations` disabled.
  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "the first user of loop should not dominate any define "
                       "of consumer operand(s)");
  }

  // Restores the rewriter's insertion point when this function returns.
  OpBuilder::InsertionGuard g(rewriter);

  // 2. Check consumer is not using scf loop's output as init.
  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
  if (!dstOp)
    return rewriter.notifyMatchFailure(consumerOp,
                                       "consumer op is not DPS operation");
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  // The consumer's inits become the additional inits of the loop (nest).
  SmallVector<Value> newInits = dpsInits;

  Location loc = outerMostLoop->getLoc();

  // 3. Move the whole loop structure right before firstUserOfLoop, the
  // dominance should be already ensured by `checkAssumptionForLoop`.
  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
  if (failed(firstUserOfLoop)) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "could not find the first user of outer most loop");
  }
  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);

  // 4. Set insertion point before terminator op of the loop and create a new
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    // In the scf.for case the clone is placed right before the original slice
    // op.
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 5.a. Clone consumer op. The clone is created at the current insertion
  // point, i.e. inside the loop body.
  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

  // 5.b. Replace the fused operand of the cloned consumer (the use of the loop
  // result) with the result of the cloned tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 6. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  // Only the first tiled op is used from here on.
  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
                              clonedInsertSliceOp.getSource());

  // 7. Reconstruct [nested] loop with new inits. The callback below fills
  // `tiledResult`, `tiledOffset` and `tiledSizes` with one entry per result of
  // the tiled consumer.
  // NOTE(review): apart from setting the initial insertion point, the callback
  // uses the captured outer `rewriter` rather than `innerRewriter`; presumably
  // both refer to the same rewriter - confirm against
  // `addInitOperandsToLoopNest`.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);
    // 8. Set inner insertPoint right before tiled consumer op.
    innerRewriter.setInsertionPoint(tiledConsumerOp);

    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

    // 9. Check all insert stride is 1.
    if (llvm::any_of(strides, [](OpFoldResult stride) {
          return !isConstantIntValue(stride, 1);
        })) {
      return rewriter.notifyMatchFailure(
          candidateSliceOp, "containingOp's result yield with stride");
    }

    // 10. Try to get iter domain position from input position. Use
    // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
    // may require index computation based on the result size. The sizes and
    // offsets should be the same either way, but using tiledConsumerOp could
    // lead to some chained unnecessary extra index computation.
    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
            iterDomainSizes))) {
      return rewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get iter domain position from input position");
    }

    // 11. Try to fetch the offset and size for all results of the cloned
    // consumer. This would then be used to form the corresponding
    // tensor.insert_slice/parallel_insert_slice later.
    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
        totalNumResultsOfConsumer);
    SmallVector<SmallVector<OpFoldResult>> resultSizes(
        totalNumResultsOfConsumer);
    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
      if (failed(tiledConsumerOp.getResultTilePosition(
              rewriter, idx, iterDomainOffsets, iterDomainSizes,
              resultOffsets[idx], resultSizes[idx]))) {
        return rewriter.notifyMatchFailure(
            tiledConsumerOp,
            "can't get result domain position from iter domain position");
      }
    }

    // 12. Create `extract_slice` for `iter_args` for DPS operation if
    // necessary. Each new region iter arg is sliced with unit strides and set
    // as the init at the same index of the tiled consumer.
    // NOTE(review): `resultOffsets`/`resultSizes` are sized by the number of
    // consumer results but indexed here by the position in
    // `newRegionIterArgs`; this assumes both counts match - confirm.
    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
            tiledConsumerOp.getOperation())) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, resultOffsets[index], resultSizes[index],
            SmallVector<OpFoldResult>(resultOffsets[index].size(),
                                      rewriter.getIndexAttr(1)));
        // Make a copy of index to avoid a capturing structured binding, which
        // is a C++20 extension.
        auto dstNumber = index;
        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
        });
      }
    }

    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
    // caller. The insertion point is left right before the terminator of the
    // block currently being inserted into.
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, result] :
         llvm::enumerate(tiledConsumerOp->getResults())) {
      tiledResult.push_back(result);
      tiledOffset.emplace_back(resultOffsets[index]);
      tiledSizes.emplace_back(resultSizes[index]);
    }
    return success();
  };
  // 14. Add new inits to [nested] loops.
  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
                                       newYieldValuesFn))) {
    return rewriter.notifyMatchFailure(tiledConsumerOp,
                                       "unable to add new inits to nest loop");
  }

  // 15. Replace the results of the consumer op with the trailing
  // `newInits.size()` results of the outer-most loop.
  // NOTE(review): `nestedLoops.front()` is read after
  // `addInitOperandsToLoopNest`; presumably that helper updates `nestedLoops`
  // in place to refer to the rebuilt loops - confirm.

  for (auto &&[oldResult, newResult] : llvm::zip(
           consumerOp->getResults(),
           nestedLoops.front()->getResults().take_back(newInits.size()))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 16. Erase the cloned consumer op. No loop is erased here.
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Lowers `op` to a nest of `scf.for` loops, one loop per range of its
/// iteration domain, and asks the op to generate its scalar implementation
/// with the collected induction variables. Returns the created loops in
/// creation order (outer-most first).
///
/// Fails when `op` has results, or when `generateScalarImplementation` fails.
/// No insertion guard is set up here, so on return the rewriter's insertion
/// point is wherever the loop construction and scalar generation left it.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    // Materialize offset/size/stride as values; they are passed to the
    // `scf.for` builder in that order.
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    // Subsequent loops (and finally the scalar body) are created inside this
    // loop, right before its terminator.
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}