1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Dominance.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 26 #include "mlir/Interfaces/TilingInterface.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/Debug.h" 29 #include <optional> 30 31 #define DEBUG_TYPE "tile-using-interface" 32 33 using namespace mlir; 34 35 scf::SCFTilingOptions & 36 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 37 assert(!tileSizeComputationFunction && "tile sizes already set"); 38 auto tileSizes = llvm::to_vector(ts); 39 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 40 return tileSizes; 41 }; 42 return *this; 43 } 44 45 scf::SCFTilingOptions & 46 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { 47 assert(!numThreadsComputationFunction && "num tiles already set"); 48 auto numThreads = llvm::to_vector(nt); 49 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { 50 return numThreads; 51 }; 52 return *this; 53 } 54 55 /// Helper method to adjust the interchange vector to match the iteration 56 /// domain. 57 static SmallVector<int64_t> 58 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 59 size_t iterationDomainSize) { 60 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 61 if (filledVector.size() < iterationDomainSize) { 62 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 63 filledVector.append(range.begin(), range.end()); 64 } 65 if (filledVector.size() > iterationDomainSize) 66 filledVector.resize(iterationDomainSize); 67 return filledVector; 68 } 69 70 //===----------------------------------------------------------------------===// 71 // tileUsingSCF implementation. 72 //===----------------------------------------------------------------------===// 73 74 /// Verify the tile size options are set in a consistent manner. 75 static LogicalResult 76 verifyTileSizeOptions(RewriterBase &rewriter, Location loc, 77 const scf::SCFTilingOptions &options) { 78 // Specifying number of threads is only supported on `scf.forall` op. 79 if (options.numThreadsComputationFunction && 80 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { 81 return rewriter.notifyMatchFailure( 82 loc, "number of threads can only by specified when loop type is " 83 "set to use `scf.forall`"); 84 } 85 86 // If specified, check that the interchange vector is a permutation. 
87 if (!options.interchangeVector.empty()) { 88 if (!isPermutationVector(options.interchangeVector)) { 89 return rewriter.notifyMatchFailure( 90 loc, "invalid interchange vector, not a permutation of the entire " 91 "iteration space"); 92 } 93 } 94 return success(); 95 } 96 97 /// Method to instantiate the tile sizes and/or number of threads specified 98 /// by the user. 99 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 100 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, 101 ArrayRef<Range> iterationDomain, 102 const scf::SCFTilingOptions &options) { 103 OpFoldResult zero = rewriter.getIndexAttr(0); 104 SmallVector<OpFoldResult> tileSizes, numThreads; 105 size_t numLoops = iterationDomain.size(); 106 107 // Check whether the number of tiles to use is specified. 108 if (options.numThreadsComputationFunction) { 109 numThreads = options.numThreadsComputationFunction(rewriter, op); 110 numThreads.resize(numLoops, zero); 111 112 // If the number of tiles is also specified, use that. 113 if (options.tileSizeComputationFunction) { 114 tileSizes = options.tileSizeComputationFunction(rewriter, op); 115 tileSizes.resize(numLoops, zero); 116 return {tileSizes, numThreads}; 117 } 118 119 // Compute the tile sizes from the iteration domain and number 120 // of tiles as follows 121 // - niters = ceilDiv(ub - lb, step) 122 // - tileSize = ceilDiv(niters, numThreads) 123 AffineExpr s0, s1, s2; 124 bindSymbols(rewriter.getContext(), s0, s1, s2); 125 // TODO: The step here is assumed to be 1. 126 AffineExpr numItersExpr = (s1 - s0); 127 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); 128 tileSizes.resize(numLoops, zero); 129 for (auto [index, range, nt] : 130 llvm::enumerate(iterationDomain, numThreads)) { 131 if (isConstantIntValue(nt, 0)) 132 continue; 133 134 tileSizes[index] = affine::makeComposedFoldedAffineApply( 135 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); 136 } 137 tileSizes.resize(numLoops, zero); 138 return {tileSizes, numThreads}; 139 } 140 141 // Enforce the convention that "tiling by zero" 142 // skips tiling a particular dimension. This convention is significantly 143 // simpler to handle instead of adjusting affine maps to account for missing 144 // dimensions. 145 assert(options.tileSizeComputationFunction && 146 "expected tile sizes to be specified"); 147 tileSizes = options.tileSizeComputationFunction(rewriter, op); 148 tileSizes.resize(numLoops, zero); 149 150 return {tileSizes, numThreads}; 151 } 152 153 /// Checks if any of the tiled loops are not parallel. 154 static void checkSafeToTileToForall(TilingInterface op, 155 ArrayRef<OpFoldResult> tileSizes, 156 ArrayRef<OpFoldResult> numThreads) { 157 auto iterators = op.getLoopIteratorTypes(); 158 assert(iterators.size() == tileSizes.size() && 159 "expected as many tile size values as number of loops"); 160 assert((numThreads.empty() || (numThreads.size() == iterators.size())) && 161 "when specified, expected number of threads to use for each loop"); 162 163 for (auto [index, iterator, tileSize] : 164 llvm::enumerate(iterators, tileSizes)) { 165 // If num threads is specified, check that it is greater than one only for 166 // parallel dimensions. 
167 if (!numThreads.empty()) { 168 if (std::optional<int64_t> constNumThreads = 169 getConstantIntValue(numThreads[index])) { 170 if (constNumThreads.value() > 1 && 171 iterator != utils::IteratorType::parallel) { 172 op.emitWarning() << "tiling is not thread safe at axis #" << index; 173 } 174 } 175 continue; 176 } 177 178 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { 179 if (constTileSize.value() > 0 && 180 iterator != utils::IteratorType::parallel) { 181 op.emitWarning() << "tiling is not thread safe at axis #" << index; 182 } 183 } 184 } 185 } 186 187 /// Check if `stride` evenly divides the trip count `size - offset`. 188 static bool tileDividesIterationDomain(Range loopRange) { 189 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 190 if (!offsetAsInt) 191 return false; 192 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 193 if (!sizeAsInt) 194 return false; 195 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 196 if (!strideAsInt) 197 return false; 198 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 199 } 200 201 /// Returns the bounded tile size given the current `offset`, `loopRange` and 202 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 203 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 204 Range loopRange, OpFoldResult offset, 205 OpFoldResult tileSize) { 206 std::optional<int64_t> ts = getConstantIntValue(tileSize); 207 if (ts && ts.value() == 1) 208 return tileSize; 209 210 if (tileDividesIterationDomain( 211 Range{loopRange.offset, loopRange.size, tileSize})) 212 return tileSize; 213 214 // The tile size to use (to avoid out of bounds access) is minimum of 215 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 216 // loop. 217 AffineExpr s0, s1, d0; 218 bindDims(b.getContext(), d0); 219 bindSymbols(b.getContext(), s0, s1); 220 AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); 221 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 222 return affine::makeComposedFoldedAffineMin( 223 b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize}); 224 } 225 226 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less 227 /// than `iterationSize`. 228 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, 229 OpFoldResult numThreads, 230 OpFoldResult iterationSize) { 231 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize); 232 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads); 233 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize); 234 if (!tileSizeConst || !numThreadsConst || !iterSizeConst) 235 return false; 236 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; 237 } 238 239 /// Compute the `OpFoldResult`s that represents the multi-dimensional 240 /// `offset`s and `size`s of the tile of the iteration space that the 241 /// innermost loop body of the generated tiled loops corresponds to. 
242 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 243 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, 244 ArrayRef<Range> iterationDomain, 245 ArrayRef<OpFoldResult> tileSizes, 246 ArrayRef<OpFoldResult> numThreads) { 247 SmallVector<OpFoldResult> offsets, sizes; 248 int materializedLoopNum = 0; 249 250 if (!numThreads.empty()) { 251 AffineExpr d0, d1, s0, s1; 252 AffineExpr offsetExpr, residualTileSizeExpr; 253 bindDims(rewriter.getContext(), d0, d1); 254 bindSymbols(rewriter.getContext(), s0, s1); 255 offsetExpr = d0 + d1 * s0; 256 residualTileSizeExpr = s1 - (d0 + d1 * s0); 257 258 for (auto [nt, tileSize, loopRange] : 259 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { 260 261 // Non-tiled cases, set the offset and size to the 262 // `loopRange.offset/size`. 263 if (isConstantIntValue(nt, 0)) { 264 offsets.push_back(loopRange.offset); 265 sizes.push_back(loopRange.size); 266 continue; 267 } 268 269 Value iv = ivs[materializedLoopNum++]; 270 OpFoldResult offset = affine::makeComposedFoldedAffineApply( 271 rewriter, loc, offsetExpr, 272 ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); 273 OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( 274 rewriter, loc, residualTileSizeExpr, 275 {loopRange.offset, nt, tileSize, loopRange.size}); 276 277 OpFoldResult size = tileSize; 278 if (!isConstantIntValue(residualTileSize, 0)) { 279 OpFoldResult sizeMinusOffsetPerThread = 280 affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, 281 {offset, loopRange.size}); 282 size = affine::makeComposedFoldedAffineMin( 283 rewriter, loc, 284 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), 285 {sizeMinusOffsetPerThread, tileSize}); 286 } 287 288 // Consider the case where the original loop was `[0, 100)`. 289 // If number of threads are `7`, the tile size would be computed as 290 // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) 291 // - `offset = 0 + 6 * 15 = 105` 292 // - `tileSize = min(15, 100 - 105) = -5` 293 // To avoid negative tile sizes, we need to do a further 294 // `nonNegativeTileSize = affine.max(0, tileSize)`. 295 // This `max` can be avoided if 296 // `offset + tileSize * (numThreads - 1) < (ub - lb)` 297 if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { 298 AffineMap maxMap = 299 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); 300 size = affine::makeComposedFoldedAffineMax( 301 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); 302 } 303 304 offsets.push_back(offset); 305 sizes.push_back(size); 306 } 307 return {offsets, sizes}; 308 } else { 309 for (auto [tileSize, loopRange] : 310 llvm::zip_equal(tileSizes, iterationDomain)) { 311 312 // Non-tiled cases, set the offset and size to the 313 // `loopRange.offset/size`. 314 if (isConstantIntValue(tileSize, 0)) { 315 offsets.push_back(loopRange.offset); 316 sizes.push_back(loopRange.size); 317 continue; 318 } 319 320 Value iv = ivs[materializedLoopNum++]; 321 OpFoldResult offset = getAsOpFoldResult(iv); 322 offsets.push_back(offset); 323 OpFoldResult size = 324 getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); 325 sizes.push_back(size); 326 } 327 return {offsets, sizes}; 328 } 329 } 330 331 /// Function to return the bounds of the loops to be generated. 
332 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 333 SmallVector<OpFoldResult>> 334 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 335 ArrayRef<OpFoldResult> tileSizes) { 336 SmallVector<OpFoldResult> lbs, ubs, steps; 337 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { 338 // No loop if the tile size is 0. 339 if (isConstantIntValue(tileSize, 0)) 340 continue; 341 lbs.push_back(loopRange.offset); 342 ubs.push_back(loopRange.size); 343 steps.push_back(tileSize); 344 } 345 return {lbs, ubs, steps}; 346 } 347 348 /// A function that allows returning additional yielded values during 349 /// `yieldTiledValuesAndReplace`. 350 /// - `ivs` induction variable for the loop. 351 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 352 /// - `tiledValues` the tiled values to return. Must be of same size as 353 /// `newbbArgs`, each element of this array is inserted into the corresponding 354 /// element in `newbbArgs`. 355 /// - `resultOffsets` is of the same size as `tiledValues` and represents 356 /// the offsets to use when inserting corresponding element from `tiledValues` 357 /// into the element from `newBbArgs`. 358 /// - `resultSizes` is of the same size as `tiledValues` and represents 359 /// the size of the corresponding element from `tiledValues` inserted into 360 /// the element from `newBbArgs`. 361 /// In case the method needs to return `failure()` the method is expected 362 /// to clean up any inserted operations. 363 using YieldTiledValuesFn = std::function<LogicalResult( 364 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, 365 SmallVector<Value> &tiledValues, 366 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 367 SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; 368 369 /// Clones the operation and updates the destination if the operation 370 /// implements the `DestinationStyleOpInterface`. 371 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 372 Operation *op, 373 ValueRange newDestArgs) { 374 Operation *clonedOp = rewriter.clone(*op); 375 if (newDestArgs.empty()) 376 return clonedOp; 377 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) 378 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 379 return clonedOp; 380 } 381 382 /// Generate the tile-loop nest using `scf.for` operation. 383 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 384 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 385 /// - `destinationTensors` are the init values to use for the outer most loop. 386 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 387 /// most 388 /// loop. 389 /// - `loops` is an in-out parameter into which the generated loops are 390 /// populated. 
391 static LogicalResult generateLoopNestUsingForOp( 392 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 393 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, 394 YieldTiledValuesFn yieldTiledValuesFn, 395 SmallVector<LoopLikeOpInterface> &loops) { 396 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 397 assert(loopRanges.size() == tileSizes.size() && 398 "expected as many tile sizes as loop ranges"); 399 OpBuilder::InsertionGuard guard(rewriter); 400 401 SmallVector<OpFoldResult> lbs, ubs, steps; 402 std::tie(lbs, ubs, steps) = 403 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 404 SmallVector<Value> lbVals = 405 getValueOrCreateConstantIndexOp(rewriter, loc, lbs); 406 SmallVector<Value> ubVals = 407 getValueOrCreateConstantIndexOp(rewriter, loc, ubs); 408 SmallVector<Value> stepVals = 409 getValueOrCreateConstantIndexOp(rewriter, loc, steps); 410 411 SmallVector<Value> ivs; 412 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { 413 auto loop = 414 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, 415 [](OpBuilder &bodyBuilder, Location bodyLoc, 416 Value iv, ValueRange /*iterArgs*/) {}); 417 loops.push_back(loop); 418 ivs.push_back(loop.getInductionVar()); 419 rewriter.setInsertionPointToEnd(loop.getBody()); 420 destinationTensors = loop.getRegionIterArgs(); 421 } 422 423 SmallVector<Value> tiledResults; 424 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 425 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, 426 tiledResults, resultOffsets, resultSizes))) { 427 return rewriter.notifyMatchFailure( 428 loc, "failed to generate inner tile loop body"); 429 } 430 if (loops.empty()) 431 return success(); 432 433 assert(tiledResults.size() == destinationTensors.size() && 434 "Number of results of body should be equal to number of iter args"); 435 436 // 6. Yield all the results of the tiled operation. 437 SmallVector<Value> yieldedValues; 438 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 439 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 440 resultSizes)) { 441 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 442 rewriter.getIndexAttr(1)); 443 auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 444 loc, tiledValue, destinationTensor, resultOffset, resultSize, 445 resultStride); 446 yieldedValues.push_back(insertSlice); 447 } 448 rewriter.create<scf::YieldOp>(loc, yieldedValues); 449 450 // Add the scf.yield operations for all the outer loops. 451 for (auto [outerLoop, innerLoop] : 452 llvm::zip_equal(MutableArrayRef(loops).drop_back(), 453 MutableArrayRef(loops).drop_front())) { 454 rewriter.setInsertionPointToEnd( 455 cast<scf::ForOp>(outerLoop.getOperation()).getBody()); 456 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); 457 } 458 return success(); 459 } 460 461 /// Generate the tile-loop nest using `scf.forall` operation. 462 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 463 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 464 /// - `destinationTensors` are the init values to use for the outer most loop. 465 /// - `mappingVector` is the mapping attributes to use for loop construction. 466 /// Can be empty. 467 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 468 /// most 469 /// loop. 470 /// - `loops` is an in-out parameter into which the generated loops are 471 /// populated. 
472 static LogicalResult generateLoopNestUsingForallOp( 473 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 474 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, 475 ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, 476 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 477 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 478 assert(loopRanges.size() == tileSizes.size() && 479 "expected as many tile sizes as loop ranges"); 480 OpBuilder::InsertionGuard guard(rewriter); 481 SmallVector<OpFoldResult> offsets(loopRanges.size()), 482 sizes(loopRanges.size()); 483 484 std::optional<ArrayAttr> mappingAttr; 485 if (!mappingVector.empty()) 486 mappingAttr = rewriter.getArrayAttr(mappingVector); 487 488 scf::ForallOp forallOp; 489 bool useNumThreads = !numThreads.empty(); 490 491 if (useNumThreads) { 492 // Prune the zero numthreads. 493 SmallVector<OpFoldResult> nonZeroNumThreads; 494 for (auto nt : numThreads) { 495 if (isConstantIntValue(nt, 0)) 496 continue; 497 nonZeroNumThreads.push_back(nt); 498 } 499 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, 500 destinationTensors, mappingAttr); 501 } else { 502 SmallVector<OpFoldResult> lbs, ubs, steps; 503 std::tie(lbs, ubs, steps) = 504 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 505 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, 506 destinationTensors, mappingAttr); 507 } 508 loops.push_back(forallOp); 509 510 rewriter.setInsertionPoint(forallOp.getTerminator()); 511 destinationTensors = forallOp.getRegionOutArgs(); 512 513 SmallVector<Value> tiledResults; 514 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 515 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 516 destinationTensors, tiledResults, resultOffsets, 517 resultSizes))) 518 return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 519 520 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 521 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 522 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 523 resultSizes)) { 524 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 525 rewriter.getIndexAttr(1)); 526 527 rewriter.create<tensor::ParallelInsertSliceOp>( 528 loc, tiledValue, destinationTensor, resultOffset, resultSize, 529 resultStride); 530 } 531 return success(); 532 } 533 534 /// Generate the tile-loop nest using the loop construct specifed in `options`. 535 /// - `options`: Tiling options specified. 536 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 537 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 538 /// - `destinationTensors` are the init values to use for the outer most loop. 539 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 540 /// most 541 /// loop. 542 /// - `loops` is an in-out parameter into which the generated loops are 543 /// populated. 544 static LogicalResult generateLoopNest( 545 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, 546 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, 547 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, 548 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 549 // If the tile sizes are all zero, no loops are generated. Just call the 550 // callback function to handle untiled case. 
551 if (llvm::all_of(tileSizes, isZeroIndex)) { 552 SmallVector<Value> tiledResults; 553 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 554 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 555 tiledResults, resultOffsets, resultSizes); 556 } 557 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 558 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 559 destinationTensors, tiledBodyFn, loops); 560 } 561 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 562 return generateLoopNestUsingForallOp( 563 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, 564 destinationTensors, tiledBodyFn, loops); 565 } 566 return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 567 } 568 569 /// Append the specified additional `newInitOperands` operands to the 570 /// loops existing `init` operands (or similar), and replace `loopOp` with 571 /// the new loop that has the additional init operands. The loop body of 572 /// this loop is moved over to the new loop. `yieldTiledValuesFn` 573 /// is called to get the new tiled values returned, and the offset 574 /// and sizes at which the tiled value is inserted into the 575 /// new region iter_args that correspond to the newly added init operands. 576 template <typename LoopType> 577 FailureOr<LoopLikeOpInterface> 578 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 579 ValueRange newInitOperands, 580 YieldTiledValuesFn yieldTiledValuesFn) { 581 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 582 } 583 584 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 585 template <> 586 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( 587 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 588 YieldTiledValuesFn yieldTiledValuesFn) { 589 OpBuilder::InsertionGuard g(rewriter); 590 Location loc = loopOp.getLoc(); 591 rewriter.setInsertionPoint(loopOp); 592 593 auto inits = llvm::to_vector(loopOp.getInitArgs()); 594 inits.append(newInitOperands.begin(), newInitOperands.end()); 595 auto newLoop = rewriter.create<scf::ForOp>( 596 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), 597 inits, [](OpBuilder &, Location, Value, ValueRange) {}); 598 599 // Move the loop body to the new op. 
600 Block *loopBody = loopOp.getBody(); 601 Block *newLoopBody = newLoop.getBody(); 602 rewriter.mergeBlocks( 603 loopBody, newLoopBody, 604 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 605 606 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); 607 rewriter.setInsertionPoint(yieldOp); 608 609 SmallVector<Value> tiledValues; 610 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 611 ValueRange newRegionIterArgs = 612 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 613 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), 614 newRegionIterArgs, tiledValues, resultOffsets, 615 resultSizes))) { 616 rewriter.eraseOp(newLoop); 617 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); 618 } 619 620 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); 621 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : 622 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, 623 resultSizes)) { 624 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 625 rewriter.getIndexAttr(1)); 626 Value insert = rewriter.create<tensor::InsertSliceOp>( 627 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, 628 resultStride); 629 newYieldValues.push_back(insert); 630 } 631 632 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 633 rewriter.replaceOp(loopOp, 634 newLoop->getResults().take_front(loopOp.getNumResults())); 635 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 636 } 637 638 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` 639 template <> 640 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( 641 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 642 YieldTiledValuesFn yieldTiledValuesFn) { 643 OpBuilder::InsertionGuard g(rewriter); 644 Location loc = loopOp.getLoc(); 645 rewriter.setInsertionPoint(loopOp); 646 auto inits = llvm::to_vector(loopOp.getOutputs()); 647 inits.append(newInitOperands.begin(), newInitOperands.end()); 648 auto newLoop = rewriter.create<scf::ForallOp>( 649 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), 650 loopOp.getMixedStep(), inits, loopOp.getMapping(), 651 [](OpBuilder &, Location, ValueRange) {}); 652 653 // Move the region of the current block to the newly created op. 654 Block *loopBody = loopOp.getBody(); 655 Block *newLoopBody = newLoop.getBody(); 656 rewriter.mergeBlocks( 657 loopBody, newLoopBody, 658 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 659 660 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); 661 rewriter.setInsertionPoint(terminator); 662 SmallVector<Value> tiledValues; 663 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 664 ValueRange regionIterArgs = 665 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 666 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), 667 regionIterArgs, tiledValues, resultOffsets, 668 resultSizes))) { 669 rewriter.eraseOp(newLoop); 670 return rewriter.notifyMatchFailure(loopOp, 671 "failed to get yielded tiled values"); 672 } 673 674 // Update the terminator. 
675 rewriter.setInsertionPointToEnd(terminator.getBody()); 676 677 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( 678 tiledValues, regionIterArgs, resultOffsets, resultSizes)) { 679 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 680 rewriter.getIndexAttr(1)); 681 rewriter.create<tensor::ParallelInsertSliceOp>( 682 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, 683 resultStride); 684 } 685 686 rewriter.replaceOp(loopOp, 687 newLoop->getResults().take_front(loopOp.getNumResults())); 688 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 689 } 690 691 /// Implementation of `yieldTiledValuesAndReplaceLoop` for 692 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each 693 /// supported loop type. 694 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( 695 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, 696 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { 697 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( 698 loopLikeOp.getOperation()) 699 .Case<scf::ForOp, scf::ForallOp>( 700 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 701 return yieldTiledValuesAndReplaceLoop( 702 loopOp, rewriter, newInitOperands, yieldTiledValuesFn); 703 }) 704 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 705 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 706 }); 707 } 708 709 /// Method to add new init values to a loop nest. Updates `loops` in-place with 710 /// new loops that use the `newInitValues`. 711 /// The outer-loops are updated to yield the new result values of the inner 712 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get 713 /// the additional values to yield form the innermost loop. 714 static LogicalResult addInitOperandsToLoopNest( 715 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 716 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 717 SmallVector<scf::ForOp> newLoops; 718 if (loops.empty()) 719 return success(); 720 OpBuilder::InsertionGuard g(rewriter); 721 rewriter.setInsertionPoint(loops.front()); 722 723 SmallVector<Value> ivs; 724 for (auto &loop : loops.drop_back()) { 725 rewriter.setInsertionPoint(loop); 726 727 // if loops.size() > 1 we assume that scf.for is used for the loops. 728 auto forLoop = cast<scf::ForOp>(loop.getOperation()); 729 730 // Create a new loop with the new init values for this loop. 731 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 732 newInits.append(newInitValues.begin(), newInitValues.end()); 733 auto newLoop = rewriter.create<scf::ForOp>( 734 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 735 forLoop.getStep(), newInits, 736 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 737 738 // Merge the body of the new loop with the body of the old loops. 
739 SmallVector<Value> sourceBlockArgs; 740 sourceBlockArgs.push_back(newLoop.getInductionVar()); 741 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 742 sourceBlockArgs.append( 743 newRegionIterArgs.begin(), 744 std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 745 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 746 rewriter.replaceOp( 747 forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 748 loop = newLoop; 749 ivs.push_back(newLoop.getInductionVar()); 750 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 751 } 752 753 // Update the loop body of the innermost loop to get new yield values. 754 LoopLikeOpInterface innerMostLoop = loops.back(); 755 FailureOr<LoopLikeOpInterface> newInnerMostLoop = 756 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 757 getNewTiledYieldsFn); 758 759 if (failed(newInnerMostLoop)) 760 return innerMostLoop.emitOpError("failed to return additional yields"); 761 loops.back() = newInnerMostLoop.value(); 762 763 // Make all other loops except the innermost loops yield the values returned 764 // by the inner loop. 765 for (auto [outerLoop, innerLoop] : 766 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 767 // Again assume that all the outer loops are scf.for operations. 768 auto outerForLoop = cast<scf::ForOp>(outerLoop); 769 auto outerLoopYield = 770 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 771 SmallVector<Value> newYields = 772 llvm::to_vector(outerLoopYield.getOperands()); 773 ValueRange additionalYields = 774 innerLoop->getResults().take_back(newInitValues.size()); 775 newYields.append(additionalYields.begin(), additionalYields.end()); 776 rewriter.setInsertionPoint(outerLoopYield); 777 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 778 } 779 return success(); 780 } 781 782 /// Implementation of tiling transformation of `op` that implements the 783 /// `TilingInterface` using `scf.for` to iterate over the tiles. 784 FailureOr<scf::SCFTilingResult> 785 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 786 const scf::SCFTilingOptions &options) { 787 if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { 788 return failure(); 789 } 790 791 OpBuilder::InsertionGuard guard(rewriter); 792 rewriter.setInsertionPointAfter(op); 793 794 // 1. Get the range of the loops that are represented by the operation. 795 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 796 797 // 2. Materialize the tile sizes and/or number of threads; 798 SmallVector<OpFoldResult> tileSizes, numThreads; 799 std::tie(tileSizes, numThreads) = 800 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); 801 802 // Check if it is safe to tile. This is hold over from previous iterations 803 // of tile to for-all. Consider dropping it. 804 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 805 checkSafeToTileToForall(op, tileSizes, numThreads); 806 } 807 808 // 3. If there is an interchange specified, permute the iteration domain and 809 // the tile sizes. 
810 SmallVector<int64_t> interchangeVector; 811 if (!options.interchangeVector.empty()) { 812 interchangeVector = fillInterchangeVector(options.interchangeVector, 813 iterationDomain.size()); 814 assert(isPermutationVector(interchangeVector) && 815 "expected interchange vector to be a permutation"); 816 817 applyPermutationToVector(iterationDomain, interchangeVector); 818 applyPermutationToVector(tileSizes, interchangeVector); 819 if (!numThreads.empty()) 820 applyPermutationToVector(numThreads, interchangeVector); 821 } 822 823 FailureOr<TilingResult> tilingResult; 824 // 4. Define the lambda function used later to generate the body of the 825 // innermost tiled loop. 826 YieldTiledValuesFn innerYieldTiledValuesFn = 827 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 828 ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 829 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 830 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 831 -> LogicalResult { 832 // 4a. Compute the `offsets` and `sizes` to use for tiling. 833 SmallVector<OpFoldResult> offsets, sizes; 834 std::tie(offsets, sizes) = getTileOffsetAndSizes( 835 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); 836 837 // 4b. If interchange was provided, apply inverse of the interchange 838 // to get back the offsets/sizes in the order to be specified. 839 if (!interchangeVector.empty()) { 840 auto inversePermutation = invertPermutationVector(interchangeVector); 841 applyPermutationToVector(offsets, inversePermutation); 842 applyPermutationToVector(sizes, inversePermutation); 843 } 844 845 // 5. Generate the tiled implementation within the inner most loop. 846 847 // 5a. Clone the operation within the loop body. 848 auto clonedOp = cast<TilingInterface>( 849 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 850 851 // 5b. Early return cloned op if tiling is not happening. We can not return 852 // the original op because it could lead to 853 // `rewriter.replaceOp(op, op->getResults())` and users would get crash. 854 if (llvm::all_of(tileSizes, isZeroIndex)) { 855 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 856 tilingResult = 857 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()}; 858 return success(); 859 } 860 861 // 5c. Tile the cloned operation. 862 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); 863 if (failed(tilingResult)) { 864 rewriter.eraseOp(clonedOp); 865 return op.emitOpError("faild to tile operation"); 866 } 867 868 // 5d. Delete the cloned operation. 869 rewriter.eraseOp(clonedOp); 870 871 // 5e. Compute the offsets at which the result values are to be inserted 872 // back into its destinations. 873 for (auto [index, tiledValue] : 874 llvm::enumerate(tilingResult->tiledValues)) { 875 tiledResults.push_back(tiledValue); 876 SmallVector<OpFoldResult> resultOffset, resultSize; 877 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, 878 resultOffset, resultSize))) { 879 for (auto op : tilingResult->tiledOps) { 880 rewriter.eraseOp(op); 881 } 882 return rewriter.notifyMatchFailure( 883 op, "failed to get slice of result produced"); 884 } 885 resultOffsets.emplace_back(std::move(resultOffset)); 886 resultSizes.emplace_back(std::move(resultSize)); 887 } 888 889 return success(); 890 }; 891 892 // 6. Find the destination tensors to use for the operation. 
893 SmallVector<Value> destinationTensors; 894 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 895 destinationTensors))) { 896 return rewriter.notifyMatchFailure(op, 897 "unable to create destination tensors"); 898 } 899 900 // 7. Generate the tiled loops nest using the callback defined above. 901 SmallVector<LoopLikeOpInterface> loops; 902 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 903 tileSizes, numThreads, destinationTensors, 904 innerYieldTiledValuesFn, loops))) 905 return op.emitOpError("failed to generate tiling loops"); 906 assert(succeeded(tilingResult) && 907 "expected tiling result to be computed after loop generation"); 908 909 // If loops are empty, the tiled op is used as the replacement for the untiled 910 // op. 911 if (loops.empty()) { 912 return scf::SCFTilingResult{tilingResult->tiledOps, loops, 913 tilingResult->tiledValues}; 914 } 915 916 SmallVector<Value> replacements = llvm::map_to_vector( 917 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 918 return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements}; 919 } 920 921 FailureOr<scf::SCFReductionTilingResult> 922 mlir::scf::tileReductionUsingScf(RewriterBase &b, 923 PartialReductionOpInterface op, 924 ArrayRef<OpFoldResult> tileSizes) { 925 Location loc = op.getLoc(); 926 // Ops implementing PartialReductionOpInterface are expected to implement 927 // TilingInterface. 928 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 929 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 930 auto tileSizesVector = llvm::to_vector(tileSizes); 931 if (tileSizesVector.size() < iterationDomain.size()) { 932 auto zero = b.getIndexAttr(0); 933 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 934 zero); 935 } 936 SmallVector<utils::IteratorType> iterators = 937 tilingInterfaceOp.getLoopIteratorTypes(); 938 939 SmallVector<int> reductionDims; 940 for (auto [idx, iteratorType] : 941 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 942 if (iteratorType == utils::IteratorType::reduction) 943 reductionDims.push_back(idx); 944 } 945 946 // 2. create the inital tensor value. 947 FailureOr<SmallVector<Value>> maybeInitTensors = 948 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 949 reductionDims); 950 if (failed(maybeInitTensors)) { 951 return b.notifyMatchFailure(op, "Failed to create initial tensors."); 952 } 953 SmallVector<Value> &initTensors = maybeInitTensors.value(); 954 955 // 3. Define the callback to use for generating the inner most tile loop body. 956 SmallVector<Operation *> parallelTiledOps; 957 auto innerYieldTiledValuesFn = 958 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 959 ValueRange regionIterArgs, SmallVector<Value> &tiledResult, 960 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 961 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 962 -> LogicalResult { 963 SmallVector<OpFoldResult> offsets, sizes; 964 { 965 int materializedLoopNum = 0; 966 for (auto [tileSize, loopRange] : 967 llvm::zip_equal(tileSizesVector, iterationDomain)) { 968 if (isConstantIntValue(tileSize, 0)) { 969 offsets.push_back(loopRange.offset); 970 sizes.push_back(loopRange.size); 971 continue; 972 } 973 Value iv = ivs[materializedLoopNum++]; 974 offsets.push_back(iv); 975 sizes.push_back( 976 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 977 } 978 } 979 980 // 4a. Clone the operation. 
981 { 982 auto clonedOp = cast<PartialReductionOpInterface>( 983 cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); 984 985 // 4b. Tile the cloned operation. 986 FailureOr<TilingResult> partialTilingResult = 987 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, 988 sizes, reductionDims); 989 if (failed(partialTilingResult)) { 990 return failure(); 991 } 992 std::swap(parallelTiledOps, partialTilingResult->tiledOps); 993 std::swap(tiledResult, partialTilingResult->tiledValues); 994 995 // 4c. Delete the cloned operation. 996 b.eraseOp(clonedOp); 997 } 998 999 // 4d. Compute the offsets and sizes needed to insert the result of the 1000 // tiled value back into destination before yielding the destination. 1001 for (auto result : tiledResult) { 1002 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 1003 resultOffsets.emplace_back(std::move(outOffsets)); 1004 1005 SmallVector<OpFoldResult> outSizes; 1006 for (size_t i = 0; i < offsets.size(); i++) { 1007 outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); 1008 } 1009 resultSizes.emplace_back(std::move(outSizes)); 1010 } 1011 return success(); 1012 }; 1013 1014 // 5. Generate the tiled implementation using the destination tensors. 1015 SmallVector<LoopLikeOpInterface> loops; 1016 scf::SCFTilingOptions options; 1017 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); 1018 if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, 1019 /*numThreads=*/ArrayRef<OpFoldResult>{}, 1020 initTensors, innerYieldTiledValuesFn, loops))) 1021 return b.notifyMatchFailure(op, "failed to tile for parallel reduction"); 1022 1023 SmallVector<Value> replacements = llvm::map_to_vector( 1024 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 1025 1026 // 5. Apply the merge reduction to combine all the partial values. 1027 b.setInsertionPointAfter(*loops.begin()); 1028 FailureOr<MergeResult> mergeResult = 1029 op.mergeReductions(b, loc, replacements, reductionDims); 1030 if (failed(mergeResult)) { 1031 return failure(); 1032 } 1033 b.replaceOp(op, mergeResult->replacements); 1034 1035 SCFReductionTilingResult reductionTilingResult; 1036 std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); 1037 std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); 1038 std::swap(reductionTilingResult.initialValues, initTensors); 1039 std::swap(reductionTilingResult.loops, loops); 1040 std::swap(reductionTilingResult.replacements, mergeResult->replacements); 1041 1042 return reductionTilingResult; 1043 } 1044 1045 //===----------------------------------------------------------------------===// 1046 // tileConsumerAndFuseProducersUsingSCF implementation. 1047 //===----------------------------------------------------------------------===// 1048 1049 /// Return the untiled producer whose slice is used in a tiled consumer. The 1050 /// method traverses the tile loop nest (`loops`) if needed, and returns the 1051 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 1052 /// indicates that this is a destination operand of the consumer. If there was 1053 /// no loop traversal needed, the second value of the returned tuple is empty. 
1054 static std::tuple<OpResult, std::optional<OpOperand *>> 1055 getUntiledProducerFromSliceSource(OpOperand *source, 1056 ArrayRef<LoopLikeOpInterface> loops) { 1057 std::optional<OpOperand *> destinationIterArg; 1058 auto loopIt = loops.rbegin(); 1059 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 1060 auto loop = *loopIt; 1061 if (iterArg.getOwner()->getParentOp() != loop) 1062 break; 1063 source = loop.getTiedLoopInit(iterArg); 1064 loopIt++; 1065 } 1066 if (loopIt == loops.rend()) 1067 destinationIterArg = source; 1068 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 1069 } 1070 1071 /// Implementation of fusing producer of a single slice by computing the 1072 /// slice of the producer in-place. 1073 std::optional<scf::SCFFuseProducerOfSliceResult> 1074 mlir::scf::tileAndFuseProducerOfSlice( 1075 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 1076 MutableArrayRef<LoopLikeOpInterface> loops) { 1077 // 1. Get the producer of the source (potentially walking through 1078 // `iter_args` of nested `scf.for`) 1079 auto [fusableProducer, destinationInitArg] = 1080 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1081 loops); 1082 if (!fusableProducer) 1083 return std::nullopt; 1084 unsigned resultNumber = fusableProducer.getResultNumber(); 1085 1086 OpBuilder::InsertionGuard g(rewriter); 1087 rewriter.setInsertionPoint(candidateSliceOp); 1088 1089 // 2. Clone the fused producer 1090 // 2a. Compute the destination operands to use for the cloned operation. 1091 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 1092 Operation *fusableProducerOp = fusableProducer.getOwner(); 1093 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 1094 failed(tensor::getOrCreateDestinations( 1095 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 1096 origDestinationTensors))) 1097 return std::nullopt; 1098 1099 clonedOpDestinationTensors = origDestinationTensors; 1100 if (destinationInitArg && 1101 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 1102 // 2b. If the producer is also destination style, then to maintain the 1103 // destination passing style, update the destination of the producer to be 1104 // the source of the slice. 1105 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 1106 } 1107 // 2c. Clone the fused producer. 1108 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 1109 rewriter, fusableProducerOp, clonedOpDestinationTensors); 1110 // 2d. Update the source of the candidateSlice to be the cloned producer. 1111 // Easier to just clone the slice with different source since replacements 1112 // and DCE of cloned ops becomes easier 1113 SmallVector<Value> candidateSliceOpOperands = 1114 llvm::to_vector(candidateSliceOp->getOperands()); 1115 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 1116 tensor::ExtractSliceOp clonedCandidateSliceOp = 1117 mlir::clone(rewriter, candidateSliceOp, 1118 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 1119 1120 // 3. Generate the tiled implementation of the producer of the source 1121 FailureOr<TilingResult> tileAndFuseResult = 1122 tensor::replaceExtractSliceWithTiledProducer( 1123 rewriter, clonedCandidateSliceOp, 1124 clonedProducerOp->getResult(resultNumber)); 1125 if (failed(tileAndFuseResult)) 1126 return std::nullopt; 1127 // Note: Do not delete the candidateSliceOp, since its passed in from the 1128 // caller. 
1129 rewriter.replaceAllUsesWith(candidateSliceOp, 1130 tileAndFuseResult->tiledValues[0]); 1131 rewriter.eraseOp(clonedCandidateSliceOp); 1132 rewriter.eraseOp(clonedProducerOp); 1133 1134 // 3. If the slice is for a destination operand, for example, 1135 // 1136 // ```mlir 1137 // %0 = linalg.init 1138 // %1 = linalg.fill .. outs(%0 : ) 1139 // %2 = scf.for .. iter_args(%arg0 = %1) { 1140 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1141 // %4 = tensor.extract_slice %arg1 [..] 1142 // .. = linalg.matmul .. outs(%4 : ) 1143 // } 1144 // } 1145 // ``` 1146 // 1147 // the IR is currently 1148 // 1149 // ``` 1150 // %0 = linalg.init 1151 // %1 = linalg.fill 1152 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 1153 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1154 // %4 = tensor.extract_slice %arg1[..] 1155 // %5 = linalg.fill .. outs(%4 : ) 1156 // .. = linalg.matmul .. outs(%5 : ) 1157 // } 1158 // } 1159 // ``` 1160 // 1161 // The untiled `linalg.fill` is still used as the `init_value` since it 1162 // was originally a destination operand of the untiled `linalg.matmul`. 1163 // When fusing an operand that is a destination operand, the iter_arg of 1164 // the outer most loop should be changed to use the destination of the 1165 // fused operation. With this the IR will be. 1166 // 1167 // ``` 1168 // %0 = linalg.init 1169 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 1170 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 1171 // %3 = tensor.extract_slice %arg1[..] 1172 // %4 = linalg.fill .. outs(%3 : ) 1173 // .. = linalg.matmul .. outs(%4 : ) 1174 // } 1175 // } 1176 // ``` 1177 if (destinationInitArg && 1178 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 1179 loops.front() 1180 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 1181 .set(origDestinationTensors[resultNumber]); 1182 } 1183 return scf::SCFFuseProducerOfSliceResult{fusableProducer, 1184 tileAndFuseResult->tiledValues[0], 1185 tileAndFuseResult->tiledOps}; 1186 } 1187 1188 /// Reconstruct the fused producer from within the tiled-and-fused code. 1189 LogicalResult mlir::scf::yieldReplacementForFusedProducer( 1190 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 1191 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 1192 MutableArrayRef<LoopLikeOpInterface> loops, 1193 ArrayRef<unsigned> yieldResultNumber) { 1194 if (loops.empty()) 1195 return success(); 1196 1197 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), 1198 *tiledOwner = fusedProducerInfo.tiledOps[0]; 1199 1200 Location loc = originalOwner->getLoc(); 1201 // a. collect all init Value to be appended 1202 SmallVector<unsigned> initNumberList = 1203 yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq<unsigned>( 1204 0, originalOwner->getNumResults())) 1205 : llvm::to_vector(yieldResultNumber); 1206 SmallVector<Value> initValueList; 1207 for (const auto &resultNumber : initNumberList) { 1208 FailureOr<Value> initValue = tensor::getOrCreateDestination( 1209 rewriter, loc, originalOwner->getResult(resultNumber)); 1210 if (succeeded(initValue)) { 1211 initValueList.push_back(initValue.value()); 1212 } else { 1213 return failure(); 1214 } 1215 } 1216 1217 YieldTiledValuesFn newYieldValuesFn = 1218 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 1219 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 1220 SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 1221 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 1222 OpBuilder::InsertionGuard g(innerRewriter); 1223 1224 // get sliceOp tile information 1225 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), 1226 sliceSizes = sliceOp.getMixedSizes(); 1227 1228 // expect all strides of sliceOp being 1 1229 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 1230 return !isConstantIntValue(ofr, 1); 1231 })) 1232 return failure(); 1233 1234 unsigned sliceResultNumber = 1235 fusedProducerInfo.origProducer.getResultNumber(); 1236 1237 auto tilableOp = cast<TilingInterface>(originalOwner); 1238 // b. get iterDomain Offset and Sizes based on sliceOp tile 1239 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; 1240 // skip tensor.pack/unpack/pad, which expects single opResult 1241 if (tilableOp->getNumResults() > 1 && 1242 failed(tilableOp.getIterationDomainTileFromResultTile( 1243 rewriter, sliceResultNumber, sliceOffset, sliceSizes, 1244 iterDomainOffset, iterDomainSizes))) { 1245 // In theory, it is unnecessary to raise an error here. Actually although 1246 // it fails to reconstruct the result tensor, it should not broke current 1247 // fusion anyway. The reason why we must return failure currently is that 1248 // the callback function `newYieldValuesFn` will be called after new init 1249 // operand(s) has already been appended. It will take more refactoring to 1250 // make sure the init operands are added consistently in the future. For 1251 // more details, please refer to: 1252 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 1253 return failure(); 1254 } 1255 1256 // c. calculate offsets and sizes info of all OpResults respectively based 1257 // on iteration Domain Tile 1258 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; 1259 for (const auto &resultNumber : initNumberList) { 1260 if (resultNumber == sliceResultNumber) { 1261 offsetList.push_back(sliceOffset); 1262 sizesList.push_back(sliceSizes); 1263 } else { 1264 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); 1265 // infer result tile according to the iteration domain tile 1266 SmallVector<OpFoldResult> offset, sizes; 1267 if (failed(tilableOp.getResultTilePosition( 1268 rewriter, resultNumber, iterDomainOffset, iterDomainSizes, 1269 offset, sizes))) { 1270 return failure(); 1271 } 1272 offsetList.push_back(offset); 1273 sizesList.push_back(sizes); 1274 } 1275 } 1276 1277 // d. 
create `extract_slice` for `iter_args` for DPS operation if necessary 1278 if (auto tiledDestStyleOp = 1279 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { 1280 rewriter.setInsertionPoint(tiledDestStyleOp); 1281 for (const auto &&[index, newRegionArg] : 1282 llvm::enumerate(newRegionIterArgs)) { 1283 auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 1284 loc, newRegionArg, offsetList[index], sizesList[index], 1285 SmallVector<OpFoldResult>(offsetList[index].size(), 1286 rewriter.getIndexAttr(1))); 1287 unsigned resultNumber = initNumberList[index]; 1288 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 1289 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 1290 }); 1291 } 1292 } 1293 1294 // e. prepare tiled offset and sizes for later `insert_slice` creation by 1295 // caller 1296 Block *block = rewriter.getInsertionPoint()->getBlock(); 1297 rewriter.setInsertionPoint(block->getTerminator()); 1298 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { 1299 tiledResult.push_back(tiledOwner->getResult(resultNumber)); 1300 tiledOffset.emplace_back(offsetList[index]); 1301 tiledSizes.emplace_back(sizesList[index]); 1302 } 1303 return success(); 1304 }; 1305 1306 return addInitOperandsToLoopNest(rewriter, loops, initValueList, 1307 newYieldValuesFn); 1308 } 1309 1310 /// Implementation of tile consumer and fuse producer greedily. 1311 FailureOr<scf::SCFTileAndFuseResult> 1312 mlir::scf::tileConsumerAndFuseProducersUsingSCF( 1313 RewriterBase &rewriter, TilingInterface consumer, 1314 const scf::SCFTileAndFuseOptions &options) { 1315 // This transformation is only valid for ops that return values (i.e. not 1316 // valid to use with operations that have memref operands). 1317 if (!consumer->getNumResults()) { 1318 return rewriter.notifyMatchFailure( 1319 consumer, "invalid pattern for op with no results"); 1320 } 1321 1322 // 1. First tile the consumer. 1323 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 1324 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 1325 1326 FailureOr<scf::SCFTilingResult> tilingResult = 1327 tileUsingSCF(rewriter, consumer, options.tilingOptions); 1328 1329 if (failed(tilingResult)) 1330 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1331 for (auto *tiledOp : tilingResult->tiledOps) 1332 tiledAndFusedOps.insert(tiledOp); 1333 1334 // If there are no loops generated, fusion is immaterial. 1335 auto &loops = tilingResult->loops; 1336 if (loops.empty()) { 1337 DenseMap<Value, Value> replacements; 1338 for (auto [origVal, replacement] : 1339 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 1340 replacements[origVal] = replacement; 1341 } 1342 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1343 replacements}; 1344 } 1345 1346 // To keep track of replacements for now just record the map from the original 1347 // untiled value to the result number of the for loop. Since the loop gets 1348 // potentially replaced during fusion, keeping the value directly wont work. 1349 DenseMap<Value, size_t> origValToResultNumber; 1350 for (auto [index, result] : llvm::enumerate(consumer->getResults())) { 1351 origValToResultNumber[result] = index; 1352 } 1353 1354 // 2. Typically, the operands of the tiled operation are slices of the 1355 // operands of the untiled operation. These are expressed in IR using 1356 // `tensor.extract_slice` operations with source being the operands of the 1357 // untiled operation. 
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // Find the original producer of the slice.
    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                          loops);
    if (!fusableProducer)
      continue;

    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
    if (!fuseSlice)
      continue;

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
    if (!fusedResult)
      continue;

    if (yieldReplacement) {
      // Reconstruct and yield all OpResults of fusableProducerOp by default.
      // The caller can specify which ones to yield via the optional
      // `yieldResultNumber` argument of `yieldReplacementForFusedProducer`.
      Operation *fusableProducerOp = fusableProducer.getOwner();
      if (failed(yieldReplacementForFusedProducer(
              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
        return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to yield a replacement value for this "
                               "operation from within the tiled loop");
      }
      for (auto [index, result] :
           llvm::enumerate(fusableProducerOp->getResults())) {
        origValToResultNumber[result] = loops.front()->getNumResults() -
                                        fusableProducerOp->getNumResults() +
                                        index;
      }
    }

    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
      addCandidateSlices(tiledAndFusedOp, candidates);
    }
  }

  DenseMap<Value, Value> replacements;
  for (auto [origVal, resultNumber] : origValToResultNumber) {
    replacements[origVal] = loops.front()->getResult(resultNumber);
  }

  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                   replacements};
}

//===----------------------------------------------------------------------===//
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//

/// A utility function that checks whether the only use of the result of a
/// tensor.insert_slice op is in a scf.yield op.
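/// For illustration, the pattern being checked has the following rough shape
/// (a sketch only; operands are elided):
///
/// ```
/// scf.for ... iter_args(%arg0 = ...) {
///   ...
///   %inserted = tensor.insert_slice %tile into %arg0 [...] [...] [1, 1]
///   scf.yield %inserted
/// }
/// ```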
static LogicalResult
checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
  Value result = candidateSliceOp.getResult();
  Value::use_range uses = result.getUses();
  if (!llvm::hasSingleElement(uses)) {
    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
    return failure();
  }
  OpOperand &operandUse = (*uses.begin());
  Operation *userOp = operandUse.getOwner();
  if (!isa<scf::YieldOp>(userOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << "Expected scf.yield to be the only user, but got -> "
               << (*userOp));
    return failure();
  }
  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
                               "be in the same block\n");
    return failure();
  }
  return success();
}

/// Fetches the OpOperand of the only user (and use) of the value `val` which
/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
/// failure otherwise.
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                  Block *containingOpBlock) {
  // Step 1. Check that the value has exactly one use.
  if (!llvm::hasSingleElement(val.getUses()))
    return failure();
  // Step 2. Get the use.
  OpOperand &operand = (*val.getUses().begin());
  Operation *consumerOp = operand.getOwner();
  // TODO: The result of the consumer has to be initialized before the scf.for,
  // so use DestinationStyleOpInterface to get the result shape from the init
  // for now. Add support for other ops, such as ops that implement
  // InferTypeOpInterface.
  if (!isa<TilingInterface>(consumerOp) ||
      !isa<DestinationStyleOpInterface>(consumerOp))
    return failure();
  if (containingOpBlock != consumerOp->getBlock())
    return failure();
  return &operand;
}

/// Find the perfectly nested loops outside of the given loop (inclusive),
/// sorted from outer to inner.
///
/// E.g.
///
/// ```
/// %0 = scf.for()
///   %1 = scf.for()
///     %2 = scf.for()
///       %3 = ...
///       yield %3
///     yield %2
///   yield %1
/// ```
///
/// This function returns the three perfectly nested loops %0 + %1 + %2 when
/// the target inner loop is %2.
static SmallVector<scf::ForOp>
getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
  SmallVector<scf::ForOp> nestLoops = {loop};
  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());

  // Check if it is the ForOp that yields the result of the inner loop.
  auto isForOpYieldResultOfInnerLoop =
      [](scf::ForOp outerLoop) -> LogicalResult {
    Block *body = outerLoop.getBody();
    if (!llvm::hasSingleElement(body->without_terminator()))
      return failure();
    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
    if (!innerForOp)
      return failure();
    // All of innerForOp's results should be yielded.
    return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
  };

  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
    nestLoops.push_back(outerLoop);
    outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
  }
  // Sorted from outer to inner.
  return {nestLoops.rbegin(), nestLoops.rend()};
}
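
// For illustration, a hypothetical use of the utility above: given the
// innermost loop `%2` from the example, a call such as
//
//   SmallVector<scf::ForOp> nest = getPerfectlyNestedLoopsOutsideOf(innerFor);
//
// would return {%0, %1, %2} with the outermost loop first, so `nest.front()`
// is the top-level loop whose results are visible to consumers outside the
// nest.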

/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions:
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
    return failure();
  Value sliceResult = candidateSliceOp.getResult();
  // Step 1. Fetch the corresponding output.
  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
  unsigned resultNumber = yieldOpOperand.getOperandNumber();
  // Step 2. Check that the containing op is scf.for.
  Operation *containingOp = candidateSliceOp->getParentOp();
  auto forOp = dyn_cast<scf::ForOp>(containingOp);
  if (!forOp)
    return failure();
  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
  Value resultingValue = topLevelForOp->getResult(resultNumber);

  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
}

/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
  // Step 1. Fetch the corresponding output.
  Value sliceDest = candidateSliceOp.getDest();
  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
  if (!iterArg)
    return failure();
  Operation *containingOp = iterArg.getOwner()->getParentOp();
  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
    return failure();
  // Step 2. Check that the containing op is scf.forall.
  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
  if (!forallOp)
    return failure();
  Value resultingValue =
      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));

  return getConsumerFromUses(resultingValue, containingOp->getBlock());
}

/// This utility checks that the loop either:
/// 1. yields exactly one result, or
/// 2. has the consumer op as its first user, with all other users residing in
///    the same block as the consumer op. Currently the loop op is moved right
///    before the consumer op in order to maintain a valid def-use chain, and
///    this check ensures that no invalid IR is formed by that move.
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
                                            Operation *consumerOp) {
  // Check if the loop op yields one result.
  if (loopOp->getNumResults() == 1)
    return success();
  // Check if the consumerOp is the first user of the loopOp and if the other
  // users are in the same containing block as the consumer op.
  Block *parentBlock = consumerOp->getBlock();
  for (Operation *userOp : loopOp->getUsers()) {
    if (userOp == consumerOp)
      continue;
    if (parentBlock != userOp->getBlock() ||
        !consumerOp->isBeforeInBlock(userOp))
      return failure();
  }
  return success();
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
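/// For illustration, the two supported forms are roughly the following
/// (sketches only; operands are elided):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = ...) {
///   %i = tensor.insert_slice %tile into %arg0 [...] [...] [1]
///   scf.yield %i
/// }
///
/// %1 = scf.forall ... shared_outs(%arg0 = ...) {
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %arg0 [...] [...] [1]
///   }
/// }
/// ```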
static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(insertSlice);
  } else if (auto parallelInsertSlice =
                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(parallelInsertSlice);
  } else {
    return failure();
  }
}

/// Implementation of fusing a consumer of a single slice by computing the
/// slice of the consumer in-place within the scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Get the consumer of the scf.for result that is yielded by the
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  // There are two possible cases for the containing loop here:
  // 1. A single `scf.forall` or `scf.for`.
  // 2. The inner-most `scf.for` inside a nest of `scf.for` ops, where the
  //    top-level loop is the outer-most one of these nested loops.
  LoopLikeOpInterface innerMostLoop =
      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
  SmallVector<LoopLikeOpInterface> nestedLoops;
  if (isInsertSliceOp) {
    nestedLoops = llvm::map_to_vector(
        getPerfectlyNestedLoopsOutsideOf(
            cast<scf::ForOp>(innerMostLoop.getOperation())),
        [](scf::ForOp forOp) {
          return cast<LoopLikeOpInterface>(forOp.getOperation());
        });
  } else {
    nestedLoops = {innerMostLoop};
  }

  LoopLikeOpInterface outerMostLoop = nestedLoops.front();

  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
    return rewriter.notifyMatchFailure(
        outerMostLoop,
        "containing loop op should either yield just one value or "
        "have the consumer op as its first user");
  }

  OpBuilder::InsertionGuard g(rewriter);
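
  // At this point the IR being rewritten has, roughly, the following shape
  // (an illustrative sketch only; names are hypothetical):
  //
  //   %0 = scf.for ... iter_args(%arg0 = ...) {
  //     %i = tensor.insert_slice %tile into %arg0 [...] [...] [1]
  //     scf.yield %i
  //   }
  //   %1 = "consumer"(%0, ...)
  //
  // The steps below move the loop right before the consumer, clone the
  // consumer into the loop body, tile the clone along the slice, and yield
  // its tiled results through newly added iter_args of the loop nest.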

  // 2. Check that the consumer is not using the scf loop's output as init.
  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
  if (!dstOp)
    return rewriter.notifyMatchFailure(consumerOp,
                                       "consumer op is not a DPS operation");
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  SmallVector<Value> newInits = dpsInits;

  Location loc = outerMostLoop->getLoc();

  // 3. Move the whole loop structure right before the consumer op; dominance
  // is already ensured by `checkAssumptionForLoop`.
  rewriter.moveOpBefore(outerMostLoop, consumerOp);

  // 4. Set the insertion point before the terminator op of the loop and
  // create a new tensor.insert_slice. In the scf.for case this is a clone of
  // the candidateSliceOp, whereas in the scf.forall case this is created from
  // the operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 5.a. Clone the consumer op.
  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

  // 5.b. Replace the use of the loop result in the cloned consumer with the
  // result of the cloned tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 6. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
                              clonedInsertSliceOp.getSource());

  // 7. Reconstruct the [nested] loop with new inits.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);
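
    // The steps below (8-13) compute, for every result of the tiled consumer,
    // the tile that should be yielded through the newly added iter_args. As a
    // rough, hypothetical sketch, the loop body ends up containing:
    //
    //   %dest = tensor.extract_slice %new_iter_arg [%off] [%sz] [1]
    //   %tiled = "tiled_consumer"(..., %dest)
    //
    // and the caller then inserts %tiled back into %new_iter_arg via a
    // tensor.insert_slice/parallel_insert_slice at the same position.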
    // 8. Set the inner insertion point right before the tiled consumer op.
    innerRewriter.setInsertionPoint(tiledConsumerOp);

    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

    // 9. Check that all insert strides are 1.
    if (llvm::any_of(strides, [](OpFoldResult stride) {
          return !isConstantIntValue(stride, 1);
        })) {
      return rewriter.notifyMatchFailure(
          candidateSliceOp, "containing op's result is yielded with a stride");
    }

    // 10. Try to get the iteration domain position from the operand position.
    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
            iterDomainSizes))) {
      return rewriter.notifyMatchFailure(
          tiledConsumerOp,
          "can't get iter domain position from input position");
    }

    // 11. Try to fetch the offset and size for all results of the cloned
    // consumer. This is then used to form the corresponding
    // tensor.insert_slice/parallel_insert_slice later.
    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
        totalNumResultsOfConsumer);
    SmallVector<SmallVector<OpFoldResult>> resultSizes(
        totalNumResultsOfConsumer);
    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
      if (failed(tiledConsumerOp.getResultTilePosition(
              rewriter, idx, iterDomainOffsets, iterDomainSizes,
              resultOffsets[idx], resultSizes[idx]))) {
        return rewriter.notifyMatchFailure(
            tiledConsumerOp,
            "can't get result domain position from iter domain position");
      }
    }

    // 12. Create `extract_slice` for the `iter_args` of the DPS operation, if
    // necessary.
    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
            tiledConsumerOp.getOperation())) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, resultOffsets[index], resultSizes[index],
            SmallVector<OpFoldResult>(resultOffsets[index].size(),
                                      rewriter.getIndexAttr(1)));
        // Make a copy of `index` to avoid capturing a structured binding,
        // which is a C++20 extension.
        auto dstNumber = index;
        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
        });
      }
    }

    // 13. Prepare the tiled offsets and sizes for the later `insert_slice`
    // creation by the caller.
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, result] :
         llvm::enumerate(tiledConsumerOp->getResults())) {
      tiledResult.push_back(result);
      tiledOffset.emplace_back(resultOffsets[index]);
      tiledSizes.emplace_back(resultSizes[index]);
    }
    return success();
  };
  // 14. Add the new inits to the [nested] loops.
  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
                                       newYieldValuesFn))) {
    return rewriter.notifyMatchFailure(tiledConsumerOp,
                                       "unable to add new inits to nest loop");
  }
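
  // At this point the outermost loop carries `newInits.size()` additional
  // results appended after its original ones, e.g. (hypothetical sketch):
  //
  //   %0:3 = scf.for ...
  //       iter_args(%a = %init0, %b = %init1, %c = %consumer_init)
  //
  // where the appended results (`take_back(newInits.size())` below) hold the
  // values of the fused consumer.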

  // 15. Replace the results of the scf loop and the consumer op with the new
  // loop's results.
  for (auto &&[oldResult, newResult] : llvm::zip(
           consumerOp->getResults(),
           nestedLoops.front()->getResults().take_back(newInits.size()))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 16. Erase the old scf loop and the cloned consumer op.
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results, if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower operations with return values to loops");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}
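
// For illustration, given an op with a 2-d iteration domain, the lowering
// above produces a loop nest of roughly the following shape (a sketch only;
// the bounds come directly from the `Range`s returned by
// `getIterationDomain`):
//
//   scf.for %i = %offset0 to %size0 step %stride0 {
//     scf.for %j = %offset1 to %size1 step %stride1 {
//       <scalar implementation of the op at (%i, %j)>
//     }
//   }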