1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/SCF/Utils/Utils.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/Utils/IndexingUtils.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 27 #include "mlir/Interfaces/TilingInterface.h" 28 #include "llvm/ADT/TypeSwitch.h" 29 #include "llvm/Support/Debug.h" 30 #include <optional> 31 32 #define DEBUG_TYPE "tile-using-interface" 33 34 using namespace mlir; 35 36 scf::SCFTilingOptions & 37 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 38 assert(!tileSizeComputationFunction && "tile sizes already set"); 39 auto tileSizes = llvm::to_vector(ts); 40 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 41 return tileSizes; 42 }; 43 return *this; 44 } 45 46 /// Helper method to adjust the interchange vector to match the iteration 47 /// domain. 48 static SmallVector<int64_t> 49 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 50 size_t iterationDomainSize) { 51 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 52 if (filledVector.size() < iterationDomainSize) { 53 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 54 filledVector.append(range.begin(), range.end()); 55 } 56 if (filledVector.size() > iterationDomainSize) 57 filledVector.resize(iterationDomainSize); 58 return filledVector; 59 } 60 61 //===----------------------------------------------------------------------===// 62 // tileUsingSCF implementation. 63 //===----------------------------------------------------------------------===// 64 65 // Check if `stride` evenly divides the trip count `size - offset`. 66 static bool tileDividesIterationDomain(Range loopRange) { 67 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 68 if (!offsetAsInt) 69 return false; 70 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 71 if (!sizeAsInt) 72 return false; 73 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 74 if (!strideAsInt) 75 return false; 76 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 77 } 78 79 /// Returns the bounded tile size given the current `iv`, `loopRange` and 80 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 
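/// As a small worked example (illustrative only): for a loop range [0, 10)
/// with `tileSize = 4`, the generated bound is `affine.min(4, 10 - iv)`,
/// which evaluates to 4, 4 and 2 for iv = 0, 4 and 8, so the last, partial
/// tile never reads past the upper bound. When the tile size is 1, or evenly
/// divides the trip count, the `affine.min` is skipped and `tileSize` is
/// returned directly.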
81 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 82 Range loopRange, Value iv, 83 OpFoldResult tileSize) { 84 std::optional<int64_t> ts = getConstantIntValue(tileSize); 85 if (ts && ts.value() == 1) 86 return tileSize; 87 88 if (tileDividesIterationDomain( 89 Range{loopRange.offset, loopRange.size, tileSize})) 90 return tileSize; 91 92 // The tile size to use (to avoid out of bounds access) is the minimum of 93 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 94 // loop. 95 AffineExpr s0, s1, d0; 96 bindDims(b.getContext(), d0); 97 bindSymbols(b.getContext(), s0, s1); 98 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 99 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 100 return affine::makeComposedFoldedAffineMin( 101 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 102 } 103 104 /// A function that allows returning additional yielded values during 105 /// `yieldTiledValuesAndReplaceLoop`. 106 /// - `ivs` induction variables for the loops. 107 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 108 /// - `tiledValues` the tiled values to return. Must be of the same size as 109 /// `newBbArgs`, each element of this array is inserted into the corresponding 110 /// element in `newBbArgs`. 111 /// - `resultOffsets` is of the same size as `tiledValues` and represents 112 /// the offsets to use when inserting corresponding element from `tiledValues` 113 /// into the element from `newBbArgs`. 114 /// - `resultSizes` is of the same size as `tiledValues` and represents 115 /// the size of the corresponding element from `tiledValues` inserted into 116 /// the element from `newBbArgs`. 117 /// In case the method needs to return `failure()` the method is expected 118 /// to clean up any inserted operations. 119 using YieldTiledValuesFn = std::function<LogicalResult( 120 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, 121 SmallVector<Value> &tiledValues, 122 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 123 SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; 124 125 /// Clones the operation and updates the destination if the operation 126 /// implements the `DestinationStyleOpInterface`. 127 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 128 Operation *op, 129 ValueRange newDestArgs) { 130 Operation *clonedOp = rewriter.clone(*op); 131 if (newDestArgs.empty()) 132 return clonedOp; 133 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) 134 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 135 return clonedOp; 136 } 137 138 /// Generate the tile-loop nest using `scf.for` operation. 139 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 140 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops. 141 /// - `destinationTensors` are the init values to use for the outermost loop. 142 /// - `yieldTiledValuesFn` is called to generate the loop body of the inner 143 /// most 144 /// loop. 145 /// - `loops` is an in-out parameter into which the generated loops are 146 /// populated.
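/// As a rough sketch (illustrative only, not verbatim output; names such as
/// %lb0, %ts0 and %init are placeholders), tiling two dimensions over a
/// tensor destination produces a nest of the form:
///
/// ```mlir
///   %0 = scf.for %i = %lb0 to %ub0 step %ts0 iter_args(%arg0 = %init) {
///     %1 = scf.for %j = %lb1 to %ub1 step %ts1 iter_args(%arg1 = %arg0) {
///       // ... body created by `yieldTiledValuesFn` ...
///       %2 = tensor.insert_slice %tiled into %arg1[...] [...] [1, 1]
///       scf.yield %2
///     }
///     scf.yield %1
///   }
/// ```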
147 static LogicalResult generateLoopNestUsingForOp( 148 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 149 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, 150 YieldTiledValuesFn yieldTiledValuesFn, 151 SmallVector<LoopLikeOpInterface> &loops) { 152 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 153 assert(loopRanges.size() == tileSizes.size() && 154 "expected as many tile sizes as loop ranges"); 155 OpBuilder::InsertionGuard guard(rewriter); 156 SmallVector<Value> ivs; 157 158 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { 159 // No loops if tile size is zero. Set offset and size to the loop 160 // offset and size. 161 if (isConstantIntValue(tileSize, 0)) 162 continue; 163 164 Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 165 Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 166 Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize); 167 auto loop = 168 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, 169 [](OpBuilder &bodyBuilder, Location bodyLoc, 170 Value iv, ValueRange /*iterArgs*/) {}); 171 loops.push_back(loop); 172 ivs.push_back(loop.getInductionVar()); 173 rewriter.setInsertionPointToEnd(loop.getBody()); 174 destinationTensors = loop.getRegionIterArgs(); 175 } 176 177 SmallVector<Value> tiledResults; 178 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 179 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, 180 tiledResults, resultOffsets, resultSizes))) { 181 return rewriter.notifyMatchFailure( 182 loc, "failed to generate inner tile loop body"); 183 } 184 if (loops.empty()) 185 return success(); 186 187 assert(tiledResults.size() == destinationTensors.size() && 188 "Number of results of body should be equal to number of iter args"); 189 190 // Yield all the results of the tiled operation. 191 SmallVector<Value> yieldedValues; 192 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 193 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 194 resultSizes)) { 195 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 196 rewriter.getIndexAttr(1)); 197 auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 198 loc, tiledValue, destinationTensor, resultOffset, resultSize, 199 resultStride); 200 yieldedValues.push_back(insertSlice); 201 } 202 rewriter.create<scf::YieldOp>(loc, yieldedValues); 203 204 // Add the scf.yield operations for all the outer loops. 205 for (auto [outerLoop, innerLoop] : 206 llvm::zip_equal(MutableArrayRef(loops).drop_back(), 207 MutableArrayRef(loops).drop_front())) { 208 rewriter.setInsertionPointToEnd( 209 cast<scf::ForOp>(outerLoop.getOperation()).getBody()); 210 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); 211 } 212 return success(); 213 } 214 215 /// Generate the tile-loop nest using `scf.forall` operation. 216 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 217 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops. 218 /// - `destinationTensors` are the init values to use for the outermost loop. 219 /// - `mappingVector` is the mapping attributes to use for loop construction. 220 /// Can be empty. 221 /// - `yieldTiledValuesFn` is called to generate the loop body of the inner 222 /// most 223 /// loop. 224 /// - `loops` is an in-out parameter into which the generated loops are 225 /// populated.
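/// As a rough sketch (illustrative only, not verbatim output; names such as
/// %ts0 and %init are placeholders), the generated loop takes the form:
///
/// ```mlir
///   %0 = scf.forall (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1)
///          step (%ts0, %ts1) shared_outs(%arg0 = %init) {
///     // ... body created by the tiled-body callback ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %tiled into %arg0[...] [...] [1, 1]
///     }
///   }
/// ```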
226 static LogicalResult generateLoopNestUsingForallOp( 227 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 228 ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector, 229 ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, 230 SmallVector<LoopLikeOpInterface> &loops) { 231 SmallVector<OpFoldResult> lbs, ubs, steps; 232 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 233 assert(loopRanges.size() == tileSizes.size() && 234 "expected as many tile sizes as loop ranges"); 235 OpBuilder::InsertionGuard guard(rewriter); 236 SmallVector<OpFoldResult> offsets(loopRanges.size()), 237 sizes(loopRanges.size()); 238 239 for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) { 240 if (isConstantIntValue(tileSize, 0)) 241 continue; 242 lbs.push_back(loopRange.offset); 243 ubs.push_back(loopRange.size); 244 steps.push_back(tileSize); 245 } 246 assert(!lbs.empty() && "Expected at least one loop range"); 247 248 std::optional<ArrayAttr> mappingAttr; 249 if (!mappingVector.empty()) 250 mappingAttr = rewriter.getArrayAttr(mappingVector); 251 252 auto forallOp = rewriter.create<scf::ForallOp>( 253 loc, lbs, ubs, steps, destinationTensors, mappingAttr); 254 loops.push_back(forallOp); 255 256 rewriter.setInsertionPoint(forallOp.getTerminator()); 257 destinationTensors = forallOp.getRegionOutArgs(); 258 259 SmallVector<Value> tiledResults; 260 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 261 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 262 destinationTensors, tiledResults, resultOffsets, 263 resultSizes))) 264 return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 265 266 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 267 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 268 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 269 resultSizes)) { 270 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 271 rewriter.getIndexAttr(1)); 272 273 rewriter.create<tensor::ParallelInsertSliceOp>( 274 loc, tiledValue, destinationTensor, resultOffset, resultSize, 275 resultStride); 276 } 277 return success(); 278 } 279 280 /// Generate the tile-loop nest using the loop construct specified in `options`. 281 /// - `options`: Tiling options specified. 282 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 283 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops. 284 /// - `destinationTensors` are the init values to use for the outermost loop. 285 /// - `yieldTiledValuesFn` is called to generate the loop body of the inner 286 /// most 287 /// loop. 288 /// - `loops` is an in-out parameter into which the generated loops are 289 /// populated. 290 static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, 291 const scf::SCFTilingOptions &options, 292 ArrayRef<Range> loopRanges, 293 ArrayRef<OpFoldResult> tileSizes, 294 ValueRange destinationTensors, 295 YieldTiledValuesFn tiledBodyFn, 296 SmallVector<LoopLikeOpInterface> &loops) { 297 // If the tile sizes are all zero, no loops are generated. Just call the 298 // callback function to handle the untiled case.
299 if (llvm::all_of(tileSizes, isZeroIndex)) { 300 SmallVector<Value> tiledResults; 301 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 302 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 303 tiledResults, resultOffsets, resultSizes); 304 } 305 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 306 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 307 destinationTensors, tiledBodyFn, loops); 308 } 309 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 310 return generateLoopNestUsingForallOp( 311 rewriter, loc, loopRanges, tileSizes, options.mappingVector, 312 destinationTensors, tiledBodyFn, loops); 313 } 314 return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 315 } 316 317 /// Append the specified additional `newInitOperands` operands to the 318 /// loops existing `init` operands (or similar), and replace `loopOp` with 319 /// the new loop that has the additional init operands. The loop body of 320 /// this loop is moved over to the new loop. `yieldTiledValuesFn` 321 /// is called to get the new tiled values returned, and the offset 322 /// and sizes at which the tiled value is inserted into the 323 /// new region iter_args that correspond to the newly added init operands. 324 template <typename LoopType> 325 FailureOr<LoopLikeOpInterface> 326 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 327 ValueRange newInitOperands, 328 YieldTiledValuesFn yieldTiledValuesFn) { 329 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 330 } 331 332 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 333 template <> 334 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( 335 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 336 YieldTiledValuesFn yieldTiledValuesFn) { 337 OpBuilder::InsertionGuard g(rewriter); 338 Location loc = loopOp.getLoc(); 339 rewriter.setInsertionPoint(loopOp); 340 341 auto inits = llvm::to_vector(loopOp.getInitArgs()); 342 inits.append(newInitOperands.begin(), newInitOperands.end()); 343 auto newLoop = rewriter.create<scf::ForOp>( 344 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), 345 inits, [](OpBuilder &, Location, Value, ValueRange) {}); 346 347 // Move the loop body to the new op. 
348 Block *loopBody = loopOp.getBody(); 349 Block *newLoopBody = newLoop.getBody(); 350 rewriter.mergeBlocks( 351 loopBody, newLoopBody, 352 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 353 354 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); 355 rewriter.setInsertionPoint(yieldOp); 356 357 SmallVector<Value> tiledValues; 358 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 359 ValueRange newRegionIterArgs = 360 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 361 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), 362 newRegionIterArgs, tiledValues, resultOffsets, 363 resultSizes))) { 364 rewriter.eraseOp(newLoop); 365 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); 366 } 367 368 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); 369 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : 370 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, 371 resultSizes)) { 372 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 373 rewriter.getIndexAttr(1)); 374 Value insert = rewriter.create<tensor::InsertSliceOp>( 375 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, 376 resultStride); 377 newYieldValues.push_back(insert); 378 } 379 380 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 381 rewriter.replaceOp(loopOp, 382 newLoop->getResults().take_front(loopOp.getNumResults())); 383 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 384 } 385 386 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` 387 template <> 388 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( 389 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 390 YieldTiledValuesFn yieldTiledValuesFn) { 391 OpBuilder::InsertionGuard g(rewriter); 392 Location loc = loopOp.getLoc(); 393 rewriter.setInsertionPoint(loopOp); 394 auto inits = llvm::to_vector(loopOp.getOutputs()); 395 inits.append(newInitOperands.begin(), newInitOperands.end()); 396 auto newLoop = rewriter.create<scf::ForallOp>( 397 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), 398 loopOp.getMixedStep(), inits, loopOp.getMapping(), 399 [](OpBuilder &, Location, ValueRange) {}); 400 401 // Move the region of the current block to the newly created op. 402 Block *loopBody = loopOp.getBody(); 403 Block *newLoopBody = newLoop.getBody(); 404 rewriter.mergeBlocks( 405 loopBody, newLoopBody, 406 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 407 408 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); 409 rewriter.setInsertionPoint(terminator); 410 SmallVector<Value> tiledValues; 411 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 412 ValueRange regionIterArgs = 413 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 414 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), 415 regionIterArgs, tiledValues, resultOffsets, 416 resultSizes))) { 417 rewriter.eraseOp(newLoop); 418 return rewriter.notifyMatchFailure(loopOp, 419 "failed to get yielded tiled values"); 420 } 421 422 // Update the terminator. 
423 rewriter.setInsertionPointToEnd(terminator.getBody()); 424 425 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( 426 tiledValues, regionIterArgs, resultOffsets, resultSizes)) { 427 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 428 rewriter.getIndexAttr(1)); 429 rewriter.create<tensor::ParallelInsertSliceOp>( 430 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, 431 resultStride); 432 } 433 434 rewriter.replaceOp(loopOp, 435 newLoop->getResults().take_front(loopOp.getNumResults())); 436 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 437 } 438 439 /// Implementation of `yieldTiledValuesAndReplaceLoop` for 440 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each 441 /// supported loop type. 442 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( 443 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, 444 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { 445 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( 446 loopLikeOp.getOperation()) 447 .Case<scf::ForOp, scf::ForallOp>( 448 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 449 return yieldTiledValuesAndReplaceLoop( 450 loopOp, rewriter, newInitOperands, yieldTiledValuesFn); 451 }) 452 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 453 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 454 }); 455 } 456 457 /// Method to add new init values to a loop nest. Updates `loops` in-place with 458 /// new loops that use the `newInitValues`. 459 /// The outer loops are updated to yield the new result values of the inner 460 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked 461 /// to get the additional values to yield from the innermost loop. 462 static LogicalResult addInitOperandsToLoopNest( 463 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 464 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 465 SmallVector<scf::ForOp> newLoops; 466 if (loops.empty()) 467 return success(); 468 OpBuilder::InsertionGuard g(rewriter); 469 rewriter.setInsertionPoint(loops.front()); 470 471 SmallVector<Value> ivs; 472 for (auto &loop : loops.drop_back()) { 473 rewriter.setInsertionPoint(loop); 474 475 // If loops.size() > 1, we assume that scf.for is used for the loops. 476 auto forLoop = cast<scf::ForOp>(loop.getOperation()); 477 478 // Create a new loop with the new init values for this loop. 479 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 480 newInits.append(newInitValues.begin(), newInitValues.end()); 481 auto newLoop = rewriter.create<scf::ForOp>( 482 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 483 forLoop.getStep(), newInits, 484 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 485 486 // Merge the body of the new loop with the body of the old loops.
487 SmallVector<Value> sourceBlockArgs; 488 sourceBlockArgs.push_back(newLoop.getInductionVar()); 489 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 490 sourceBlockArgs.append( 491 newRegionIterArgs.begin(), 492 std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 493 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 494 rewriter.replaceOp( 495 forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 496 loop = newLoop; 497 ivs.push_back(newLoop.getInductionVar()); 498 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 499 } 500 501 // Update the loop body of the innermost loop to get new yield values. 502 LoopLikeOpInterface innerMostLoop = loops.back(); 503 FailureOr<LoopLikeOpInterface> newInnerMostLoop = 504 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 505 getNewTiledYieldsFn); 506 507 if (failed(newInnerMostLoop)) 508 return innerMostLoop.emitOpError("failed to return additional yields"); 509 loops.back() = newInnerMostLoop.value(); 510 511 // Make all other loops except the innermost loops yield the values returned 512 // by the inner loop. 513 for (auto [outerLoop, innerLoop] : 514 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 515 // Again assume that all the outer loops are scf.for operations. 516 auto outerForLoop = cast<scf::ForOp>(outerLoop); 517 auto outerLoopYield = 518 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 519 SmallVector<Value> newYields = 520 llvm::to_vector(outerLoopYield.getOperands()); 521 ValueRange additionalYields = 522 innerLoop->getResults().take_back(newInitValues.size()); 523 newYields.append(additionalYields.begin(), additionalYields.end()); 524 rewriter.setInsertionPoint(outerLoopYield); 525 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 526 } 527 return success(); 528 } 529 530 /// Implementation of tiling transformation of `op` that implements the 531 /// `TilingInterface` using `scf.for` to iterate over the tiles. 532 FailureOr<scf::SCFTilingResult> 533 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 534 const scf::SCFTilingOptions &options) { 535 OpBuilder::InsertionGuard guard(rewriter); 536 rewriter.setInsertionPointAfter(op); 537 538 if (!options.tileSizeComputationFunction) { 539 return rewriter.notifyMatchFailure( 540 op, "missing tile size computation function"); 541 } 542 543 // 1. Get the range of the loops that are represented by the operation. 544 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 545 size_t numLoops = iterationDomain.size(); 546 547 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 548 // skips tiling a particular dimension. This convention is significantly 549 // simpler to handle instead of adjusting affine maps to account for missing 550 // dimensions. 551 SmallVector<OpFoldResult> tileSizes = 552 options.tileSizeComputationFunction(rewriter, op); 553 if (tileSizes.size() < iterationDomain.size()) { 554 auto zero = rewriter.getIndexAttr(0); 555 tileSizes.append(numLoops - tileSizes.size(), zero); 556 } 557 558 // 3. If there is an interchange specified, permute the iteration domain and 559 // the tile sizes. 
560 SmallVector<int64_t> interchangeVector; 561 if (!options.interchangeVector.empty()) { 562 interchangeVector = fillInterchangeVector(options.interchangeVector, 563 iterationDomain.size()); 564 } 565 if (!interchangeVector.empty()) { 566 if (!isPermutationVector(interchangeVector)) { 567 return rewriter.notifyMatchFailure( 568 op, "invalid interchange vector, not a permutation of the entire " 569 "iteration space"); 570 } 571 572 applyPermutationToVector(iterationDomain, interchangeVector); 573 applyPermutationToVector(tileSizes, interchangeVector); 574 } 575 576 FailureOr<TilingResult> tilingResult; 577 // 4. Define the lambda function used later to generate the body of the 578 // innermost tiled loop. 579 YieldTiledValuesFn innerYieldTiledValuesFn = 580 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 581 ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 582 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 583 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 584 -> LogicalResult { 585 // 4a. Compute the `offsets` and `sizes` to use for tiling. 586 SmallVector<OpFoldResult> offsets, sizes; 587 { 588 int materializedLoopNum = 0; 589 for (auto [tileSize, loopRange] : 590 llvm::zip_equal(tileSizes, iterationDomain)) { 591 if (isConstantIntValue(tileSize, 0)) { 592 offsets.push_back(loopRange.offset); 593 sizes.push_back(loopRange.size); 594 continue; 595 } 596 Value iv = ivs[materializedLoopNum++]; 597 offsets.push_back(iv); 598 sizes.push_back( 599 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 600 } 601 } 602 603 // 4b. If interchange was provided, apply the inverse of the interchange 604 // to get back the offsets/sizes in the original order. 605 if (!interchangeVector.empty()) { 606 auto inversePermutation = invertPermutationVector(interchangeVector); 607 applyPermutationToVector(offsets, inversePermutation); 608 applyPermutationToVector(sizes, inversePermutation); 609 } 610 611 // 5. Generate the tiled implementation within the innermost loop. 612 613 // 5a. Clone the operation within the loop body. 614 auto clonedOp = cast<TilingInterface>( 615 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 616 617 // 5b. Early return the cloned op if tiling is not happening. We cannot return 618 // the original op because it could lead to 619 // `rewriter.replaceOp(op, op->getResults())` and a crash for users. 620 if (llvm::all_of(tileSizes, isZeroIndex)) { 621 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 622 tilingResult = 623 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()}; 624 return success(); 625 } 626 627 // 5c. Tile the cloned operation. 628 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); 629 if (failed(tilingResult)) { 630 rewriter.eraseOp(clonedOp); 631 return op.emitOpError("failed to tile operation"); 632 } 633 634 // 5d. Delete the cloned operation. 635 rewriter.eraseOp(clonedOp); 636 637 // 5e. Compute the offsets at which the result values are to be inserted 638 // back into their destinations.
639 for (auto [index, tiledValue] : 640 llvm::enumerate(tilingResult->tiledValues)) { 641 tiledResults.push_back(tiledValue); 642 SmallVector<OpFoldResult> resultOffset, resultSize; 643 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, 644 resultOffset, resultSize))) { 645 for (auto op : tilingResult->tiledOps) { 646 rewriter.eraseOp(op); 647 } 648 return rewriter.notifyMatchFailure( 649 op, "failed to get slice of result produced"); 650 } 651 resultOffsets.emplace_back(std::move(resultOffset)); 652 resultSizes.emplace_back(std::move(resultSize)); 653 } 654 655 return success(); 656 }; 657 658 // 6. Find the destination tensors to use for the operation. 659 SmallVector<Value> destinationTensors; 660 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 661 destinationTensors))) { 662 return rewriter.notifyMatchFailure(op, 663 "unable to create destination tensors"); 664 } 665 666 // 7. Generate the tiled loop nest using the callback defined above. 667 SmallVector<LoopLikeOpInterface> loops; 668 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 669 tileSizes, destinationTensors, 670 innerYieldTiledValuesFn, loops))) 671 return op.emitOpError("failed to generate tiling loops"); 672 assert(succeeded(tilingResult) && 673 "expected tiling result to be computed after loop generation"); 674 675 // If loops are empty, the tiled op is used as the replacement for the untiled 676 // op. 677 if (loops.empty()) { 678 return scf::SCFTilingResult{tilingResult->tiledOps, loops, 679 tilingResult->tiledValues}; 680 } 681 682 SmallVector<Value> replacements = llvm::map_to_vector( 683 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 684 return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements}; 685 } 686 687 FailureOr<scf::SCFReductionTilingResult> 688 mlir::scf::tileReductionUsingScf(RewriterBase &b, 689 PartialReductionOpInterface op, 690 ArrayRef<OpFoldResult> tileSizes) { 691 Location loc = op.getLoc(); 692 // Ops implementing PartialReductionOpInterface are expected to implement 693 // TilingInterface. 694 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 695 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 696 auto tileSizesVector = llvm::to_vector(tileSizes); 697 if (tileSizesVector.size() < iterationDomain.size()) { 698 auto zero = b.getIndexAttr(0); 699 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 700 zero); 701 } 702 SmallVector<utils::IteratorType> iterators = 703 tilingInterfaceOp.getLoopIteratorTypes(); 704 705 SmallVector<int> reductionDims; 706 for (auto [idx, iteratorType] : 707 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 708 if (iteratorType == utils::IteratorType::reduction) 709 reductionDims.push_back(idx); 710 } 711 712 // 2. Create the initial tensor values. 713 FailureOr<SmallVector<Value>> maybeInitTensors = 714 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 715 reductionDims); 716 if (failed(maybeInitTensors)) { 717 return b.notifyMatchFailure(op, "Failed to create initial tensors."); 718 } 719 SmallVector<Value> &initTensors = maybeInitTensors.value(); 720 721 // 3. Define the callback to use for generating the innermost tile loop body.
722 Operation *parallelOp = nullptr; 723 auto innerYieldTiledValuesFn = 724 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 725 ValueRange regionIterArgs, SmallVector<Value> &tiledResult, 726 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 727 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 728 -> LogicalResult { 729 SmallVector<OpFoldResult> offsets, sizes; 730 { 731 int materializedLoopNum = 0; 732 for (auto [tileSize, loopRange] : 733 llvm::zip_equal(tileSizesVector, iterationDomain)) { 734 if (isConstantIntValue(tileSize, 0)) { 735 offsets.push_back(loopRange.offset); 736 sizes.push_back(loopRange.size); 737 continue; 738 } 739 Value iv = ivs[materializedLoopNum++]; 740 offsets.push_back(iv); 741 sizes.push_back( 742 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 743 } 744 } 745 746 // 4a. Clone the operation. 747 auto clonedOp = cast<PartialReductionOpInterface>( 748 cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); 749 750 // 4b. Tile the cloned operation. 751 parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs, 752 offsets, sizes, reductionDims); 753 // 4c. Delete the cloned operation. 754 b.eraseOp(clonedOp); 755 756 tiledResult.append(parallelOp->result_begin(), parallelOp->result_end()); 757 // 4d. Compute the offsets and sizes needed to insert the result of the 758 // tiled value back into destination before yielding the destination. 759 for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) { 760 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 761 resultOffsets.emplace_back(std::move(outOffsets)); 762 763 SmallVector<OpFoldResult> outSizes; 764 for (size_t i = 0; i < offsets.size(); i++) { 765 outSizes.push_back( 766 tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i)); 767 } 768 resultSizes.emplace_back(std::move(outSizes)); 769 } 770 return success(); 771 }; 772 773 // 5. Generate the tiled implementation using the destination tensors. 774 SmallVector<LoopLikeOpInterface> loops; 775 scf::SCFTilingOptions options; 776 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); 777 if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, 778 initTensors, innerYieldTiledValuesFn, loops))) 779 return b.notifyMatchFailure(op, "failed to tile for parallel reduction"); 780 781 SmallVector<Value> replacements = llvm::map_to_vector( 782 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 783 784 // 5. Apply the merge reduction to combine all the partial values. 785 b.setInsertionPointAfter(*loops.begin()); 786 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); 787 b.replaceOp(op, mergeOp->getResults()); 788 789 SCFReductionTilingResult results; 790 results.initialValues = initTensors; 791 results.loops = loops; 792 results.parallelTiledOp = parallelOp; 793 results.mergeOp = mergeOp; 794 return results; 795 } 796 797 //===----------------------------------------------------------------------===// 798 // tileConsumerAndFuseProducersUsingSCF implementation. 799 //===----------------------------------------------------------------------===// 800 801 /// Return the untiled producer whose slice is used in a tiled consumer. The 802 /// method traverses the tile loop nest (`loops`) if needed, and returns the 803 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 804 /// indicates that this is a destination operand of the consumer. 
If there was 805 /// no loop traversal needed, the second value of the returned tuple is empty. 806 static std::tuple<OpResult, std::optional<OpOperand *>> 807 getUntiledProducerFromSliceSource(OpOperand *source, 808 ArrayRef<LoopLikeOpInterface> loops) { 809 std::optional<OpOperand *> destinationIterArg; 810 auto loopIt = loops.rbegin(); 811 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 812 auto loop = *loopIt; 813 if (iterArg.getOwner()->getParentOp() != loop) 814 break; 815 source = loop.getTiedLoopInit(iterArg); 816 loopIt++; 817 } 818 if (loopIt == loops.rend()) 819 destinationIterArg = source; 820 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 821 } 822 823 /// Implementation of fusing producer of a single slice by computing the 824 /// slice of the producer in-place. 825 std::optional<scf::SCFFuseProducerOfSliceResult> 826 mlir::scf::tileAndFuseProducerOfSlice( 827 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 828 MutableArrayRef<LoopLikeOpInterface> loops) { 829 // 1. Get the producer of the source (potentially walking through 830 // `iter_args` of nested `scf.for`) 831 auto [fusableProducer, destinationInitArg] = 832 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 833 loops); 834 if (!fusableProducer) 835 return std::nullopt; 836 unsigned resultNumber = fusableProducer.getResultNumber(); 837 838 OpBuilder::InsertionGuard g(rewriter); 839 rewriter.setInsertionPoint(candidateSliceOp); 840 841 // 2. Clone the fused producer 842 // 2a. Compute the destination operands to use for the cloned operation. 843 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 844 Operation *fusableProducerOp = fusableProducer.getOwner(); 845 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 846 failed(tensor::getOrCreateDestinations( 847 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 848 origDestinationTensors))) 849 return std::nullopt; 850 851 clonedOpDestinationTensors = origDestinationTensors; 852 if (destinationInitArg && 853 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 854 // 2b. If the producer is also destination style, then to maintain the 855 // destination passing style, update the destination of the producer to be 856 // the source of the slice. 857 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 858 } 859 // 2c. Clone the fused producer. 860 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 861 rewriter, fusableProducerOp, clonedOpDestinationTensors); 862 // 2d. Update the source of the candidateSlice to be the cloned producer. 863 // Easier to just clone the slice with different source since replacements 864 // and DCE of cloned ops becomes easier 865 SmallVector<Value> candidateSliceOpOperands = 866 llvm::to_vector(candidateSliceOp->getOperands()); 867 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 868 tensor::ExtractSliceOp clonedCandidateSliceOp = 869 mlir::clone(rewriter, candidateSliceOp, 870 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 871 872 // 3. Generate the tiled implementation of the producer of the source 873 FailureOr<TilingResult> tileAndFuseResult = 874 tensor::replaceExtractSliceWithTiledProducer( 875 rewriter, clonedCandidateSliceOp, 876 clonedProducerOp->getResult(resultNumber)); 877 if (failed(tileAndFuseResult)) 878 return std::nullopt; 879 // Note: Do not delete the candidateSliceOp, since its passed in from the 880 // caller. 
881 rewriter.replaceAllUsesWith(candidateSliceOp, 882 tileAndFuseResult->tiledValues[0]); 883 rewriter.eraseOp(clonedCandidateSliceOp); 884 rewriter.eraseOp(clonedProducerOp); 885 886 // 3. If the slice is for a destination operand, for example, 887 // 888 // ```mlir 889 // %0 = linalg.init 890 // %1 = linalg.fill .. outs(%0 : ) 891 // %2 = scf.for .. iter_args(%arg0 = %1) { 892 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 893 // %4 = tensor.extract_slice %arg1 [..] 894 // .. = linalg.matmul .. outs(%4 : ) 895 // } 896 // } 897 // ``` 898 // 899 // the IR is currently 900 // 901 // ``` 902 // %0 = linalg.init 903 // %1 = linalg.fill 904 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 905 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 906 // %4 = tensor.extract_slice %arg1[..] 907 // %5 = linalg.fill .. outs(%4 : ) 908 // .. = linalg.matmul .. outs(%5 : ) 909 // } 910 // } 911 // ``` 912 // 913 // The untiled `linalg.fill` is still used as the `init_value` since it 914 // was originally a destination operand of the untiled `linalg.matmul`. 915 // When fusing an operand that is a destination operand, the iter_arg of 916 // the outer most loop should be changed to use the destination of the 917 // fused operation. With this the IR will be. 918 // 919 // ``` 920 // %0 = linalg.init 921 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 922 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 923 // %3 = tensor.extract_slice %arg1[..] 924 // %4 = linalg.fill .. outs(%3 : ) 925 // .. = linalg.matmul .. outs(%4 : ) 926 // } 927 // } 928 // ``` 929 if (destinationInitArg && 930 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 931 loops.front() 932 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 933 .set(origDestinationTensors[resultNumber]); 934 } 935 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 936 tileAndFuseResult->tiledValues[0], 937 tileAndFuseResult->tiledOps}; 938 } 939 940 /// Reconstruct the fused producer from within the tiled-and-fused code. 
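/// When the original producer also has uses outside the loop nest, the loop
/// must additionally yield the fused producer's value. A minimal sketch of
/// the resulting innermost-loop IR (illustrative only; names are
/// placeholders, and `linalg.fill` stands in for an arbitrary fused
/// producer):
///
/// ```mlir
///   %dest = tensor.extract_slice %producer_iter_arg[...] [...] [1, 1]
///   %tile = linalg.fill ... outs(%dest : ...)   // the tiled producer
///   ...
///   %ins  = tensor.insert_slice %tile into %producer_iter_arg[...] [...] [1, 1]
///   scf.yield ..., %ins
/// ```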
941 LogicalResult mlir::scf::yieldReplacementForFusedProducer( 942 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 943 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 944 MutableArrayRef<LoopLikeOpInterface> loops) { 945 if (loops.empty()) 946 return success(); 947 948 OpResult fusableProducer = fusedProducerInfo.origProducer; 949 Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer; 950 FailureOr<Value> initValue = tensor::getOrCreateDestination( 951 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); 952 if (succeeded(initValue)) { 953 954 YieldTiledValuesFn newYieldValuesFn = 955 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 956 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 957 SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 958 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) 959 -> LogicalResult { 960 OpBuilder::InsertionGuard g(innerRewriter); 961 if (auto tiledDestStyleOp = 962 tiledAndFusedProducer 963 .getDefiningOp<DestinationStyleOpInterface>()) { 964 rewriter.setInsertionPoint(tiledDestStyleOp); 965 Value newRegionArg = newRegionIterArgs.back(); 966 auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 967 sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), 968 sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); 969 unsigned resultNumber = fusableProducer.getResultNumber(); 970 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 971 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 972 }); 973 } 974 Block *block = rewriter.getInsertionPoint()->getBlock(); 975 rewriter.setInsertionPoint(block->getTerminator()); 976 tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer); 977 tiledOffset.emplace_back(sliceOp.getMixedOffsets()); 978 tiledSizes.emplace_back(sliceOp.getMixedSizes()); 979 return success(); 980 }; 981 982 return addInitOperandsToLoopNest(rewriter, loops, 983 SmallVector<Value>{initValue.value()}, 984 newYieldValuesFn); 985 } 986 return success(); 987 } 988 989 /// Implementation of tile consumer and fuse producer greedily. 990 FailureOr<scf::SCFTileAndFuseResult> 991 mlir::scf::tileConsumerAndFuseProducersUsingSCF( 992 RewriterBase &rewriter, TilingInterface consumer, 993 const scf::SCFTileAndFuseOptions &options) { 994 // This transformation is only valid for ops that return values (i.e. not 995 // valid to use with operations that have memref operands). 996 if (!consumer->getNumResults()) { 997 return rewriter.notifyMatchFailure( 998 consumer, "invalid pattern for op with no results"); 999 } 1000 1001 // 1. First tile the consumer. 1002 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 1003 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 1004 1005 FailureOr<scf::SCFTilingResult> tilingResult = 1006 tileUsingSCF(rewriter, consumer, options.tilingOptions); 1007 1008 if (failed(tilingResult)) 1009 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1010 for (auto *tiledOp : tilingResult->tiledOps) 1011 tiledAndFusedOps.insert(tiledOp); 1012 1013 // If there are no loops generated, fusion is immaterial. 
1014 auto &loops = tilingResult->loops; 1015 if (loops.empty()) { 1016 DenseMap<Value, Value> replacements; 1017 for (auto [origVal, replacement] : 1018 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 1019 replacements[origVal] = replacement; 1020 } 1021 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1022 replacements}; 1023 } 1024 1025 // To keep track of replacements, for now just record the map from the original 1026 // untiled value to the result number of the for loop. Since the loop 1027 // potentially gets replaced during fusion, keeping the value directly won't work. 1028 DenseMap<Value, size_t> origValToResultNumber; 1029 for (auto [index, result] : llvm::enumerate(consumer->getResults())) { 1030 origValToResultNumber[result] = index; 1031 } 1032 1033 // 2. Typically, the operands of the tiled operation are slices of the 1034 // operands of the untiled operation. These are expressed in IR using 1035 // `tensor.extract_slice` operations with source being the operands of the 1036 // untiled operation. Create a worklist of these `tensor.extract_slice` 1037 // operations. If the producers of the source of the `tensor.extract_slice` 1038 // can be tiled such that the tiled value is generated in-place, that 1039 // effectively tiles + fuses the operations. 1040 auto addCandidateSlices = [](Operation *fusedOp, 1041 std::deque<tensor::ExtractSliceOp> &candidates) { 1042 for (Value operand : fusedOp->getOperands()) 1043 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 1044 candidates.push_back(sliceOp); 1045 }; 1046 1047 std::deque<tensor::ExtractSliceOp> candidates; 1048 addCandidateSlices(tiledAndFusedOps.back(), candidates); 1049 OpBuilder::InsertionGuard g(rewriter); 1050 while (!candidates.empty()) { 1051 // Traverse the slices in BFS fashion. 1052 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 1053 candidates.pop_front(); 1054 1055 // Find the original producer of the slice. 1056 auto [fusableProducer, destinationInitArg] = 1057 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1058 loops); 1059 if (!fusableProducer) 1060 continue; 1061 1062 auto [fuseSlice, yieldReplacement] = options.fusionControlFn( 1063 candidateSliceOp, fusableProducer, destinationInitArg.has_value()); 1064 if (!fuseSlice) 1065 continue; 1066 1067 // The operands of the fused producer might themselves be slices of 1068 // values produced by operations that implement the `TilingInterface`. 1069 // Add these operations to the worklist.
1070 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 1071 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops); 1072 if (!fusedResult) 1073 continue; 1074 1075 if (yieldReplacement) { 1076 if (failed(yieldReplacementForFusedProducer( 1077 rewriter, candidateSliceOp, fusedResult.value(), loops))) { 1078 return rewriter.notifyMatchFailure( 1079 fusableProducer.getOwner(), "failed to yield replacement value for this " 1080 "operation from within the tiled loop"); 1081 } 1082 origValToResultNumber[fusableProducer] = 1083 loops.front()->getNumResults() - 1; 1084 } 1085 1086 if (Operation *tiledAndFusedOp = 1087 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 1088 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 1089 tiledAndFusedOps.insert(tiledAndFusedOp); 1090 addCandidateSlices(tiledAndFusedOp, candidates); 1091 } 1092 } 1093 1094 DenseMap<Value, Value> replacements; 1095 for (auto [origVal, resultNumber] : origValToResultNumber) { 1096 replacements[origVal] = loops.front()->getResult(resultNumber); 1097 } 1098 1099 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1100 replacements}; 1101 } 1102 1103 //===----------------------------------------------------------------------===// 1104 // tileAndFuseConsumerUsingSCF implementation. 1105 //===----------------------------------------------------------------------===// 1106 1107 /// A utility function that checks whether the only use of the result of a 1108 /// tensor.insert_slice op is in a scf.yield op. 1109 static LogicalResult 1110 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { 1111 Value result = candidateSliceOp.getResult(); 1112 Value::use_range uses = result.getUses(); 1113 if (!llvm::hasSingleElement(uses)) { 1114 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); 1115 return failure(); 1116 } 1117 OpOperand &operandUse = (*uses.begin()); 1118 Operation *userOp = operandUse.getOwner(); 1119 if (!isa<scf::YieldOp>(userOp)) { 1120 LLVM_DEBUG(llvm::dbgs() 1121 << "Expected scf.yield to be the only user, but got -> " 1122 << (*userOp)); 1123 return failure(); 1124 } 1125 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { 1126 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " 1127 "be in the same block\n"); 1128 return failure(); 1129 } 1130 return success(); 1131 } 1132 1133 /// Fetches the OpOperand of the only user (and use) of the value `val` which 1134 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns 1135 /// failure otherwise. 1136 static FailureOr<OpOperand *> getConsumerFromUses(Value val, 1137 Block *containingOpBlock) { 1138 // Step 1. Check that the value has exactly one use. 1139 if (!llvm::hasSingleElement(val.getUses())) 1140 return failure(); 1141 // Step 2. Get uses. 1142 OpOperand &operand = (*val.getUses().begin()); 1143 Operation *consumerOp = operand.getOwner(); 1144 // TODO: We have to initialize the result of the consumer before the scf.for; 1145 // for now use DestinationStyleOpInterface to get the result shape from the init. 1146 // Add support for other ops, such as ones that implement InferTypeOpInterface. 1147 if (!isa<TilingInterface>(consumerOp) || 1148 !isa<DestinationStyleOpInterface>(consumerOp)) 1149 return failure(); 1150 if (containingOpBlock != consumerOp->getBlock()) 1151 return failure(); 1152 return &operand; 1153 } 1154 1155 /// Fetch the untiled consumer of a scf.for's result which is yielded by a 1156 /// tensor.insert_slice.
This function makes the following assumptions: 1157 /// 1. tensor.insert_slice has scf.yield as its only user. 1158 /// 2. scf.for's corresponding result has only one use. 1159 static FailureOr<OpOperand *> 1160 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { 1161 if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) 1162 return failure(); 1163 Value sliceResult = candidateSliceOp.getResult(); 1164 // Step 1. Fetch the corresponding output. 1165 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); 1166 unsigned resultNumber = yieldOpOperand.getOperandNumber(); 1167 // Step 2. Check that the containing op is scf.for. 1168 Operation *containingOp = candidateSliceOp->getParentOp(); 1169 auto forOp = dyn_cast<scf::ForOp>(containingOp); 1170 if (!forOp) 1171 return failure(); 1172 Value resultingValue = forOp->getResult(resultNumber); 1173 1174 return getConsumerFromUses(resultingValue, containingOp->getBlock()); 1175 } 1176 1177 /// Fetch the first untiled consumer of a scf.forall's result which is yielded 1178 /// by a tensor.parallel_insert_slice. 1179 static FailureOr<OpOperand *> 1180 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { 1181 // Step 1. Fetch the corresponding output. 1182 Value sliceDest = candidateSliceOp.getDest(); 1183 auto iterArg = dyn_cast<BlockArgument>(sliceDest); 1184 if (!iterArg) 1185 return failure(); 1186 Operation *containingOp = iterArg.getOwner()->getParentOp(); 1187 if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) 1188 return failure(); 1189 // Step 2. Check that the containing op is scf.forall. 1190 auto forallOp = dyn_cast<scf::ForallOp>(containingOp); 1191 if (!forallOp) 1192 return failure(); 1193 Value resultingValue = 1194 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); 1195 1196 return getConsumerFromUses(resultingValue, containingOp->getBlock()); 1197 } 1198 1199 /// This utility currently checks whether the loop either: 1200 /// 1. Yields exactly one result, or 1201 /// 2. Has the consumer op as its first user, with all other users in the same 1202 /// block as the consumer op. Currently we clone the loop op 1203 /// right before the consumer op in order to maintain a valid def-use chain. 1204 /// This check thus helps ensure that no invalid IR is formed as a 1205 /// result. 1206 static LogicalResult checkAssumptionForLoop(Operation *loopOp, 1207 Operation *consumerOp) { 1208 // Check if the loop op yields one result. 1209 if (loopOp->getNumResults() == 1) 1210 return success(); 1211 // Check if the consumerOp is the first user of the loopOp and if other users 1212 // are in the same block as the consumer op. 1213 Block *parentBlock = consumerOp->getBlock(); 1214 for (Operation *userOp : loopOp->getUsers()) { 1215 if (userOp == consumerOp) 1216 continue; 1217 if (parentBlock != userOp->getBlock() || 1218 !consumerOp->isBeforeInBlock(userOp)) 1219 return failure(); 1220 } 1221 return success(); 1222 } 1223 1224 /// A utility to fetch an untiled consumer of 1225 /// tensor.insert_slice/tensor.parallel_insert_slice.
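/// For the `scf.for` case, the pattern these utilities look for is roughly
/// the following (illustrative sketch only; `linalg.generic` stands in for
/// any consumer implementing TilingInterface and
/// DestinationStyleOpInterface):
///
/// ```mlir
///   %0 = scf.for ... iter_args(%arg0 = %init) -> (tensor<...>) {
///     ...
///     %1 = tensor.insert_slice %tiled into %arg0[...] [...] [1, 1]
///     scf.yield %1
///   }
///   // The single use of %0 is the consumer to fuse, in the same block.
///   %2 = linalg.generic ... ins(%0 : tensor<...>) outs(...) ...
/// ```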
1226 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) { 1227 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) { 1228 return getUntiledConsumerFromSlice(insertSlice); 1229 } else if (auto parallelInsertSlice = 1230 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) { 1231 return getUntiledConsumerFromSlice(parallelInsertSlice); 1232 } else { 1233 return failure(); 1234 } 1235 } 1236 1237 /// After fusing consumer into scf.for we want to modify the scf.yield operation 1238 /// to reflect the same by returning the values yielded by the tiled consumer. 1239 static void 1240 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp, 1241 TilingResult &tilingResult, 1242 ArrayRef<SmallVector<OpFoldResult>> &resultOffsets, 1243 ArrayRef<SmallVector<OpFoldResult>> &resultSizes, 1244 ArrayRef<BlockArgument> bbArgs) { 1245 scf::YieldOp oldTerminatorOp = 1246 cast<scf::YieldOp>(newForOp.getBody()->getTerminator()); 1247 unsigned totalOldResults = oldTerminatorOp->getNumResults(); 1248 unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults(); 1249 SmallVector<Value> newYieldOperands; 1250 newYieldOperands.reserve(totalOldResults + totalTiledResults); 1251 for (auto oldResult : oldTerminatorOp.getResults()) { 1252 newYieldOperands.push_back(oldResult); 1253 } 1254 rewriter.setInsertionPointAfter(oldTerminatorOp); 1255 Location loc = newForOp.getLoc(); 1256 for (auto [tiledResult, bbArg, resultOffset, resultSize] : 1257 llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs, 1258 resultOffsets, resultSizes)) { 1259 SmallVector<OpFoldResult> strides(resultOffset.size(), 1260 rewriter.getIndexAttr(1)); 1261 Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>( 1262 loc, tiledResult, bbArg, resultOffset, resultSize, strides); 1263 newYieldOperands.push_back(newInsertSliceOp); 1264 } 1265 rewriter.create<scf::YieldOp>(loc, newYieldOperands); 1266 rewriter.eraseOp(oldTerminatorOp); 1267 } 1268 1269 /// After fusing consumer into scf.forall we want to yield each of the resulting 1270 /// values by the tiled consumer within scf.forall.in_parallel region. 1271 static void 1272 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp, 1273 SmallVector<Value> tiledResults, 1274 ArrayRef<SmallVector<OpFoldResult>> &resultOffsets, 1275 ArrayRef<SmallVector<OpFoldResult>> &resultSizes, 1276 ArrayRef<BlockArgument> bbArgs) { 1277 scf::InParallelOp newTerminatorOp = newForallOp.getTerminator(); 1278 rewriter.setInsertionPointToStart(newTerminatorOp.getBody()); 1279 Location firstYieldOpLoc = 1280 (*(newTerminatorOp.getYieldingOps().begin())).getLoc(); 1281 for (auto [tiledResult, bbArg, resultOffset, resultSize] : 1282 llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) { 1283 SmallVector<OpFoldResult> strides(resultOffset.size(), 1284 rewriter.getIndexAttr(1)); 1285 rewriter.create<tensor::ParallelInsertSliceOp>( 1286 firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides); 1287 } 1288 } 1289 1290 /// Implementation of fusing consumer of a single slice by computing the 1291 /// slice of the consumer in-place for scf loop. 1292 FailureOr<scf::SCFFuseConsumerOfSliceResult> 1293 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, 1294 Operation *candidateSliceOp) { 1295 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( 1296 candidateSliceOp)) 1297 return failure(); 1298 1299 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp); 1300 1301 // 1. 
Get the consumer of scf.for for the result yielded by 1302 // tensor.insert_slice/parallel_insert_slice. 1303 FailureOr<OpOperand *> maybeConsumerOpOperand = 1304 getUntiledConsumerFromSlice(candidateSliceOp); 1305 if (failed(maybeConsumerOpOperand)) { 1306 return rewriter.notifyMatchFailure(candidateSliceOp, 1307 "could not fetch consumer to fuse"); 1308 } 1309 OpOperand *consumerOpOperand = *maybeConsumerOpOperand; 1310 Operation *consumerOp = consumerOpOperand->getOwner(); 1311 unsigned operandNumber = consumerOpOperand->getOperandNumber(); 1312 unsigned resultNumber = 0; 1313 if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) { 1314 resultNumber = producerResult.getResultNumber(); 1315 } else { 1316 return rewriter.notifyMatchFailure( 1317 consumerOp, "consumer op's operand doesn't seem to be an OpResult"); 1318 } 1319 1320 Operation *oldLoopOp = nullptr; 1321 SmallVector<Value> newOuts; 1322 Block *oldLoopBody = nullptr; 1323 unsigned initSize = 0; 1324 unsigned rank = 1; 1325 if (isInsertSliceOp) { 1326 auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>(); 1327 oldLoopOp = forOp; 1328 llvm::append_range(newOuts, forOp.getInits()); 1329 oldLoopBody = forOp.getBody(); 1330 initSize = forOp.getInits().size(); 1331 } else { 1332 auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>(); 1333 oldLoopOp = forallOp; 1334 llvm::append_range(newOuts, forallOp.getOutputs()); 1335 oldLoopBody = forallOp.getBody(); 1336 initSize = forallOp.getOutputs().size(); 1337 rank = forallOp.getRank(); 1338 } 1339 1340 if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) { 1341 return rewriter.notifyMatchFailure( 1342 oldLoopOp, "containing loop op should either yield just one value or " 1343 "have the consumer op as its first user"); 1344 } 1345 1346 OpBuilder::InsertionGuard g(rewriter); 1347 1348 // 2. Check consumer is not using scf loop's output as init. 1349 auto dstOp = cast<DestinationStyleOpInterface>(consumerOp); 1350 SmallVector<Value> dpsInits = 1351 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); 1352 if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) { 1353 return rewriter.notifyMatchFailure( 1354 consumerOp, 1355 "consumer op taking the result of scf.for as init is not supported"); 1356 } 1357 newOuts.append(dpsInits); 1358 1359 Location loc = oldLoopOp->getLoc(); 1360 1361 // 3. Create new scf loop op. 1362 rewriter.setInsertionPoint(consumerOp); 1363 Operation *newLoopOp = nullptr; 1364 Block *newLoopBody = nullptr; 1365 if (isInsertSliceOp) { 1366 auto forOp = cast<scf::ForOp>(oldLoopOp); 1367 auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(), 1368 forOp.getUpperBound(), 1369 forOp.getStep(), newOuts); 1370 newLoopOp = newForOp; 1371 newLoopBody = newForOp.getBody(); 1372 } else { 1373 auto forallOp = cast<scf::ForallOp>(oldLoopOp); 1374 auto newForallOp = rewriter.create<scf::ForallOp>( 1375 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), 1376 forallOp.getMixedStep(), newOuts, forallOp.getMapping()); 1377 newLoopOp = newForallOp; 1378 rewriter.eraseOp(newForallOp.getTerminator()); 1379 newLoopBody = newForallOp.getBody(); 1380 } 1381 1382 // 4. Move the loop body to the new op. 1383 unsigned oldNumArguments = oldLoopBody->getNumArguments(); 1384 rewriter.mergeBlocks(oldLoopBody, newLoopBody, 1385 newLoopBody->getArguments().take_front(oldNumArguments)); 1386 1387 // 5. Set insertion point before terminator op of the loop and create a new 1388 // tensor.insert_slice. 
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 6.a. Clone consumer op.
  auto newForOpBlockArgsForConsumerDest =
      newLoopBody->getArguments().drop_front(oldNumArguments);
  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));

  // 6.b. Replace the use of the loop result in the cloned consumer with the
  // result of the cloned tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 7. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  rewriter.replaceAllUsesWith(
      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
      clonedInsertSliceOp.getSource());

  // 8. Extract the offsets/sizes/strides required to create the
  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

  // 9. Check that all insert strides are 1.
  if (llvm::any_of(strides, [](OpFoldResult stride) {
        return !isConstantIntValue(stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "containingOp's result yield with stride");
  }

  // 10. Try to get the iteration domain position from the operand position.
  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
          iterDomainSizes))) {
    return rewriter.notifyMatchFailure(
        clonedConsumerOp, "can't get iter domain position from input position");
  }

  // 11. Try to fetch the offset and size for all results of the cloned
  // consumer. This would then be used to form the corresponding
  // tensor.insert_slice/parallel_insert_slice later.
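  // As an illustrative example (not tied to any particular consumer op): for a
  // purely elementwise consumer with identity indexing maps, the iteration
  // domain tile computed in step 10 coincides with the operand tile, and each
  // result tile fetched below ends up with those same offsets and sizes.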
  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
      totalNumResultsOfConsumer);
  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
    if (failed(clonedConsumerOp.getResultTilePosition(
            rewriter, idx, iterDomainOffsets, iterDomainSizes,
            resultOffsets[idx], resultSizes[idx]))) {
      return rewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get result domain position from iter domain position");
    }
  }

  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
  if (isInsertSliceOp) {
    auto newForOp = cast<scf::ForOp>(newLoopOp);
    fixTerminatorSCFYield(
        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
        newForOp.getBody()->getArguments().drop_front(1 + initSize));
  } else {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    fixTerminatorSCFInParallel(
        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
        arrayRefOffsets, arrayRefSizes,
        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
  }

  // 12. Replace the results of the old scf loop and of the consumer op with
  // the new loop's results.
  for (auto &&[oldResult, newResult] :
       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  for (auto &&[oldResult, newResult] :
       llvm::zip(consumerOp->getResults(),
                 newLoopOp->getResults().drop_front(initSize))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 13. Erase the old scf loop and the cloned consumer op.
  rewriter.eraseOp(oldLoopOp);
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
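  // Note: the scf.for loops created below carry no iter_args and yield
  // nothing, so only ops without results (i.e., ops with memref semantics)
  // can be lowered here; hence the check that follows.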
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}
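
// A minimal usage sketch (illustrative only, not referenced elsewhere in this
// file) of how lowerToLoopsUsingSCFForOp might be wrapped in a rewrite
// pattern. The pattern name is hypothetical, and the sketch assumes the
// matched op has memref semantics (no results), matching the restriction
// enforced above.
namespace {
struct HypotheticalLowerToLoopsPattern
    : public OpInterfaceRewritePattern<TilingInterface> {
  using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(TilingInterface op,
                                PatternRewriter &rewriter) const override {
    // Generate the scf.for nest and the scalar body for `op`.
    FailureOr<SmallVector<scf::ForOp>> loops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, op);
    if (failed(loops))
      return failure();
    // The op has no results, so it can simply be erased once the loop nest
    // has been emitted.
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace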