1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/Dominance.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 26 #include "mlir/Interfaces/TilingInterface.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/Debug.h" 29 #include <optional> 30 31 #define DEBUG_TYPE "tile-using-interface" 32 33 using namespace mlir; 34 35 scf::SCFTilingOptions & 36 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 37 assert(!tileSizeComputationFunction && "tile sizes already set"); 38 auto tileSizes = llvm::to_vector(ts); 39 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 40 return tileSizes; 41 }; 42 return *this; 43 } 44 45 scf::SCFTilingOptions & 46 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { 47 assert(!numThreadsComputationFunction && "num tiles already set"); 48 auto numThreads = llvm::to_vector(nt); 49 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { 50 return numThreads; 51 }; 52 return *this; 53 } 54 55 /// 
Helper method to adjust the interchange vector to match the iteration 56 /// domain. 57 static SmallVector<int64_t> 58 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 59 size_t iterationDomainSize) { 60 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 61 if (filledVector.size() < iterationDomainSize) { 62 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 63 filledVector.append(range.begin(), range.end()); 64 } 65 if (filledVector.size() > iterationDomainSize) 66 filledVector.resize(iterationDomainSize); 67 return filledVector; 68 } 69 70 //===----------------------------------------------------------------------===// 71 // tileUsingSCF implementation. 72 //===----------------------------------------------------------------------===// 73 74 /// Verify the tile size options are set in a consistent manner. 75 static LogicalResult 76 verifyTileSizeOptions(RewriterBase &rewriter, Location loc, 77 const scf::SCFTilingOptions &options) { 78 // Specifying number of threads is only supported on `scf.forall` op. 79 if (options.numThreadsComputationFunction && 80 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { 81 return rewriter.notifyMatchFailure( 82 loc, "number of threads can only by specified when loop type is " 83 "set to use `scf.forall`"); 84 } 85 86 // If specified, check that the interchange vector is a permutation. 87 if (!options.interchangeVector.empty()) { 88 if (!isPermutationVector(options.interchangeVector)) { 89 return rewriter.notifyMatchFailure( 90 loc, "invalid interchange vector, not a permutation of the entire " 91 "iteration space"); 92 } 93 } 94 return success(); 95 } 96 97 /// Method to instantiate the tile sizes and/or number of threads specified 98 /// by the user. 
99 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 100 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, 101 ArrayRef<Range> iterationDomain, 102 const scf::SCFTilingOptions &options) { 103 OpFoldResult zero = rewriter.getIndexAttr(0); 104 SmallVector<OpFoldResult> tileSizes, numThreads; 105 size_t numLoops = iterationDomain.size(); 106 107 // Check whether the number of tiles to use is specified. 108 if (options.numThreadsComputationFunction) { 109 numThreads = options.numThreadsComputationFunction(rewriter, op); 110 numThreads.resize(numLoops, zero); 111 112 // If the number of tiles is also specified, use that. 113 if (options.tileSizeComputationFunction) { 114 tileSizes = options.tileSizeComputationFunction(rewriter, op); 115 tileSizes.resize(numLoops, zero); 116 return {tileSizes, numThreads}; 117 } 118 119 // Compute the tile sizes from the iteration domain and number 120 // of tiles as follows 121 // - niters = ceilDiv(ub - lb, step) 122 // - tileSize = ceilDiv(niters, numThreads) 123 AffineExpr s0, s1, s2; 124 bindSymbols(rewriter.getContext(), s0, s1, s2); 125 // TODO: The step here is assumed to be 1. 126 AffineExpr numItersExpr = (s1 - s0); 127 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); 128 tileSizes.resize(numLoops, zero); 129 for (auto [index, range, nt] : 130 llvm::enumerate(iterationDomain, numThreads)) { 131 if (isConstantIntValue(nt, 0)) 132 continue; 133 134 tileSizes[index] = affine::makeComposedFoldedAffineApply( 135 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); 136 } 137 tileSizes.resize(numLoops, zero); 138 return {tileSizes, numThreads}; 139 } 140 141 // Enforce the convention that "tiling by zero" 142 // skips tiling a particular dimension. This convention is significantly 143 // simpler to handle instead of adjusting affine maps to account for missing 144 // dimensions. 
145 assert(options.tileSizeComputationFunction && 146 "expected tile sizes to be specified"); 147 tileSizes = options.tileSizeComputationFunction(rewriter, op); 148 tileSizes.resize(numLoops, zero); 149 150 return {tileSizes, numThreads}; 151 } 152 153 /// Checks if any of the tiled loops are not parallel. 154 static void checkSafeToTileToForall(TilingInterface op, 155 ArrayRef<OpFoldResult> tileSizes, 156 ArrayRef<OpFoldResult> numThreads) { 157 auto iterators = op.getLoopIteratorTypes(); 158 assert(iterators.size() == tileSizes.size() && 159 "expected as many tile size values as number of loops"); 160 assert((numThreads.empty() || (numThreads.size() == iterators.size())) && 161 "when specified, expected number of threads to use for each loop"); 162 163 for (auto [index, iterator, tileSize] : 164 llvm::enumerate(iterators, tileSizes)) { 165 // If num threads is specified, check that it is greater than one only for 166 // parallel dimensions. 167 if (!numThreads.empty()) { 168 if (std::optional<int64_t> constNumThreads = 169 getConstantIntValue(numThreads[index])) { 170 if (constNumThreads.value() > 1 && 171 iterator != utils::IteratorType::parallel) { 172 op.emitWarning() << "tiling is not thread safe at axis #" << index; 173 } 174 } 175 continue; 176 } 177 178 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { 179 if (constTileSize.value() > 0 && 180 iterator != utils::IteratorType::parallel) { 181 op.emitWarning() << "tiling is not thread safe at axis #" << index; 182 } 183 } 184 } 185 } 186 187 /// Check if `stride` evenly divides the trip count `size - offset`. 
188 static bool tileDividesIterationDomain(Range loopRange) { 189 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 190 if (!offsetAsInt) 191 return false; 192 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 193 if (!sizeAsInt) 194 return false; 195 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 196 if (!strideAsInt) 197 return false; 198 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 199 } 200 201 /// Returns the bounded tile size given the current `offset`, `loopRange` and 202 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 203 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 204 Range loopRange, OpFoldResult offset, 205 OpFoldResult tileSize) { 206 std::optional<int64_t> ts = getConstantIntValue(tileSize); 207 if (ts && ts.value() == 1) 208 return tileSize; 209 210 if (tileDividesIterationDomain( 211 Range{loopRange.offset, loopRange.size, tileSize})) 212 return tileSize; 213 214 // The tile size to use (to avoid out of bounds access) is minimum of 215 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 216 // loop. 217 AffineExpr s0, s1, d0; 218 bindDims(b.getContext(), d0); 219 bindSymbols(b.getContext(), s0, s1); 220 AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); 221 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 222 return affine::makeComposedFoldedAffineMin( 223 b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize}); 224 } 225 226 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less 227 /// than `iterationSize`. 
228 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, 229 OpFoldResult numThreads, 230 OpFoldResult iterationSize) { 231 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize); 232 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads); 233 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize); 234 if (!tileSizeConst || !numThreadsConst || !iterSizeConst) 235 return false; 236 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; 237 } 238 239 /// Compute the `OpFoldResult`s that represents the multi-dimensional 240 /// `offset`s and `size`s of the tile of the iteration space that the 241 /// innermost loop body of the generated tiled loops corresponds to. 242 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 243 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, 244 ArrayRef<Range> iterationDomain, 245 ArrayRef<OpFoldResult> tileSizes, 246 ArrayRef<OpFoldResult> numThreads) { 247 SmallVector<OpFoldResult> offsets, sizes; 248 int materializedLoopNum = 0; 249 250 if (!numThreads.empty()) { 251 AffineExpr d0, d1, s0, s1; 252 AffineExpr offsetExpr, residualTileSizeExpr; 253 bindDims(rewriter.getContext(), d0, d1); 254 bindSymbols(rewriter.getContext(), s0, s1); 255 offsetExpr = d0 + d1 * s0; 256 residualTileSizeExpr = s1 - (d0 + d1 * s0); 257 258 for (auto [nt, tileSize, loopRange] : 259 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { 260 261 // Non-tiled cases, set the offset and size to the 262 // `loopRange.offset/size`. 
263 if (isConstantIntValue(nt, 0)) { 264 offsets.push_back(loopRange.offset); 265 sizes.push_back(loopRange.size); 266 continue; 267 } 268 269 Value iv = ivs[materializedLoopNum++]; 270 OpFoldResult offset = affine::makeComposedFoldedAffineApply( 271 rewriter, loc, offsetExpr, 272 ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); 273 OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( 274 rewriter, loc, residualTileSizeExpr, 275 {loopRange.offset, nt, tileSize, loopRange.size}); 276 277 OpFoldResult size = tileSize; 278 if (!isConstantIntValue(residualTileSize, 0)) { 279 OpFoldResult sizeMinusOffsetPerThread = 280 affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, 281 {offset, loopRange.size}); 282 size = affine::makeComposedFoldedAffineMin( 283 rewriter, loc, 284 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), 285 {sizeMinusOffsetPerThread, tileSize}); 286 } 287 288 // Consider the case where the original loop was `[0, 100)`. 289 // If number of threads are `7`, the tile size would be computed as 290 // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) 291 // - `offset = 0 + 6 * 15 = 105` 292 // - `tileSize = min(15, 100 - 105) = -5` 293 // To avoid negative tile sizes, we need to do a further 294 // `nonNegativeTileSize = affine.max(0, tileSize)`. 295 // This `max` can be avoided if 296 // `offset + tileSize * (numThreads - 1) < (ub - lb)` 297 if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { 298 AffineMap maxMap = 299 AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); 300 size = affine::makeComposedFoldedAffineMax( 301 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); 302 } 303 304 offsets.push_back(offset); 305 sizes.push_back(size); 306 } 307 return {offsets, sizes}; 308 } else { 309 for (auto [tileSize, loopRange] : 310 llvm::zip_equal(tileSizes, iterationDomain)) { 311 312 // Non-tiled cases, set the offset and size to the 313 // `loopRange.offset/size`. 
314 if (isConstantIntValue(tileSize, 0)) { 315 offsets.push_back(loopRange.offset); 316 sizes.push_back(loopRange.size); 317 continue; 318 } 319 320 Value iv = ivs[materializedLoopNum++]; 321 OpFoldResult offset = getAsOpFoldResult(iv); 322 offsets.push_back(offset); 323 OpFoldResult size = 324 getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); 325 sizes.push_back(size); 326 } 327 return {offsets, sizes}; 328 } 329 } 330 331 /// Function to return the bounds of the loops to be generated. 332 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 333 SmallVector<OpFoldResult>> 334 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 335 ArrayRef<OpFoldResult> tileSizes) { 336 SmallVector<OpFoldResult> lbs, ubs, steps; 337 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { 338 // No loop if the tile size is 0. 339 if (isConstantIntValue(tileSize, 0)) 340 continue; 341 lbs.push_back(loopRange.offset); 342 ubs.push_back(loopRange.size); 343 steps.push_back(tileSize); 344 } 345 return {lbs, ubs, steps}; 346 } 347 348 /// A function that allows returning additional yielded values during 349 /// `yieldTiledValuesAndReplace`. 350 /// - `ivs` induction variable for the loop. 351 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 352 /// - `tiledValues` the tiled values to return. Must be of same size as 353 /// `newbbArgs`, each element of this array is inserted into the corresponding 354 /// element in `newbbArgs`. 355 /// - `resultOffsets` is of the same size as `tiledValues` and represents 356 /// the offsets to use when inserting corresponding element from `tiledValues` 357 /// into the element from `newBbArgs`. 358 /// - `resultSizes` is of the same size as `tiledValues` and represents 359 /// the size of the corresponding element from `tiledValues` inserted into 360 /// the element from `newBbArgs`. 
361 /// In case the method needs to return `failure()` the method is expected 362 /// to clean up any inserted operations. 363 using YieldTiledValuesFn = std::function<LogicalResult( 364 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, 365 SmallVector<Value> &tiledValues, 366 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 367 SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; 368 369 /// Clones the operation and updates the destination if the operation 370 /// implements the `DestinationStyleOpInterface`. 371 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 372 Operation *op, 373 ValueRange newDestArgs) { 374 Operation *clonedOp = rewriter.clone(*op); 375 if (newDestArgs.empty()) 376 return clonedOp; 377 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) 378 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 379 return clonedOp; 380 } 381 382 /// Generate the tile-loop nest using `scf.for` operation. 383 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 384 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 385 /// - `destinationTensors` are the init values to use for the outer most loop. 386 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 387 /// most 388 /// loop. 389 /// - `loops` is an in-out parameter into which the generated loops are 390 /// populated. 
391 static LogicalResult generateLoopNestUsingForOp( 392 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 393 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, 394 YieldTiledValuesFn yieldTiledValuesFn, 395 SmallVector<LoopLikeOpInterface> &loops) { 396 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 397 assert(loopRanges.size() == tileSizes.size() && 398 "expected as many tile sizes as loop ranges"); 399 OpBuilder::InsertionGuard guard(rewriter); 400 401 SmallVector<OpFoldResult> lbs, ubs, steps; 402 std::tie(lbs, ubs, steps) = 403 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 404 SmallVector<Value> lbVals = 405 getValueOrCreateConstantIndexOp(rewriter, loc, lbs); 406 SmallVector<Value> ubVals = 407 getValueOrCreateConstantIndexOp(rewriter, loc, ubs); 408 SmallVector<Value> stepVals = 409 getValueOrCreateConstantIndexOp(rewriter, loc, steps); 410 411 SmallVector<Value> ivs; 412 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { 413 auto loop = 414 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, 415 [](OpBuilder &bodyBuilder, Location bodyLoc, 416 Value iv, ValueRange /*iterArgs*/) {}); 417 loops.push_back(loop); 418 ivs.push_back(loop.getInductionVar()); 419 rewriter.setInsertionPointToEnd(loop.getBody()); 420 destinationTensors = loop.getRegionIterArgs(); 421 } 422 423 SmallVector<Value> tiledResults; 424 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 425 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, 426 tiledResults, resultOffsets, resultSizes))) { 427 return rewriter.notifyMatchFailure( 428 loc, "failed to generate inner tile loop body"); 429 } 430 if (loops.empty()) 431 return success(); 432 433 assert(tiledResults.size() == destinationTensors.size() && 434 "Number of results of body should be equal to number of iter args"); 435 436 // 6. Yield all the results of the tiled operation. 
437 SmallVector<Value> yieldedValues; 438 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 439 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 440 resultSizes)) { 441 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 442 rewriter.getIndexAttr(1)); 443 auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 444 loc, tiledValue, destinationTensor, resultOffset, resultSize, 445 resultStride); 446 yieldedValues.push_back(insertSlice); 447 } 448 rewriter.create<scf::YieldOp>(loc, yieldedValues); 449 450 // Add the scf.yield operations for all the outer loops. 451 for (auto [outerLoop, innerLoop] : 452 llvm::zip_equal(MutableArrayRef(loops).drop_back(), 453 MutableArrayRef(loops).drop_front())) { 454 rewriter.setInsertionPointToEnd( 455 cast<scf::ForOp>(outerLoop.getOperation()).getBody()); 456 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); 457 } 458 return success(); 459 } 460 461 /// Generate the tile-loop nest using `scf.forall` operation. 462 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 463 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 464 /// - `destinationTensors` are the init values to use for the outer most loop. 465 /// - `mappingVector` is the mapping attributes to use for loop construction. 466 /// Can be empty. 467 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 468 /// most 469 /// loop. 470 /// - `loops` is an in-out parameter into which the generated loops are 471 /// populated. 
472 static LogicalResult generateLoopNestUsingForallOp( 473 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 474 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, 475 ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, 476 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 477 assert(!loopRanges.empty() && "unexpected empty loop ranges"); 478 assert(loopRanges.size() == tileSizes.size() && 479 "expected as many tile sizes as loop ranges"); 480 OpBuilder::InsertionGuard guard(rewriter); 481 SmallVector<OpFoldResult> offsets(loopRanges.size()), 482 sizes(loopRanges.size()); 483 484 std::optional<ArrayAttr> mappingAttr; 485 if (!mappingVector.empty()) 486 mappingAttr = rewriter.getArrayAttr(mappingVector); 487 488 scf::ForallOp forallOp; 489 bool useNumThreads = !numThreads.empty(); 490 491 if (useNumThreads) { 492 // Prune the zero numthreads. 493 SmallVector<OpFoldResult> nonZeroNumThreads; 494 for (auto nt : numThreads) { 495 if (isConstantIntValue(nt, 0)) 496 continue; 497 nonZeroNumThreads.push_back(nt); 498 } 499 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, 500 destinationTensors, mappingAttr); 501 } else { 502 SmallVector<OpFoldResult> lbs, ubs, steps; 503 std::tie(lbs, ubs, steps) = 504 getLoopBounds(rewriter, loc, loopRanges, tileSizes); 505 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, 506 destinationTensors, mappingAttr); 507 } 508 loops.push_back(forallOp); 509 510 rewriter.setInsertionPoint(forallOp.getTerminator()); 511 destinationTensors = forallOp.getRegionOutArgs(); 512 513 SmallVector<Value> tiledResults; 514 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 515 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 516 destinationTensors, tiledResults, resultOffsets, 517 resultSizes))) 518 return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 519 520 
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 521 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 522 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 523 resultSizes)) { 524 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 525 rewriter.getIndexAttr(1)); 526 527 rewriter.create<tensor::ParallelInsertSliceOp>( 528 loc, tiledValue, destinationTensor, resultOffset, resultSize, 529 resultStride); 530 } 531 return success(); 532 } 533 534 /// Generate the tile-loop nest using the loop construct specifed in `options`. 535 /// - `options`: Tiling options specified. 536 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 537 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 538 /// - `destinationTensors` are the init values to use for the outer most loop. 539 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 540 /// most 541 /// loop. 542 /// - `loops` is an in-out parameter into which the generated loops are 543 /// populated. 544 static LogicalResult generateLoopNest( 545 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, 546 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, 547 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, 548 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 549 // If the tile sizes are all zero, no loops are generated. Just call the 550 // callback function to handle untiled case. 
551 if (llvm::all_of(tileSizes, isZeroIndex)) { 552 SmallVector<Value> tiledResults; 553 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 554 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 555 tiledResults, resultOffsets, resultSizes); 556 } 557 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 558 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 559 destinationTensors, tiledBodyFn, loops); 560 } 561 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 562 return generateLoopNestUsingForallOp( 563 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, 564 destinationTensors, tiledBodyFn, loops); 565 } 566 return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 567 } 568 569 /// Append the specified additional `newInitOperands` operands to the 570 /// loops existing `init` operands (or similar), and replace `loopOp` with 571 /// the new loop that has the additional init operands. The loop body of 572 /// this loop is moved over to the new loop. `yieldTiledValuesFn` 573 /// is called to get the new tiled values returned, and the offset 574 /// and sizes at which the tiled value is inserted into the 575 /// new region iter_args that correspond to the newly added init operands. 576 template <typename LoopType> 577 FailureOr<LoopLikeOpInterface> 578 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 579 ValueRange newInitOperands, 580 YieldTiledValuesFn yieldTiledValuesFn) { 581 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 582 } 583 584 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 
585 template <> 586 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( 587 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 588 YieldTiledValuesFn yieldTiledValuesFn) { 589 OpBuilder::InsertionGuard g(rewriter); 590 Location loc = loopOp.getLoc(); 591 rewriter.setInsertionPoint(loopOp); 592 593 auto inits = llvm::to_vector(loopOp.getInitArgs()); 594 inits.append(newInitOperands.begin(), newInitOperands.end()); 595 auto newLoop = rewriter.create<scf::ForOp>( 596 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), 597 inits, [](OpBuilder &, Location, Value, ValueRange) {}); 598 599 // Move the loop body to the new op. 600 Block *loopBody = loopOp.getBody(); 601 Block *newLoopBody = newLoop.getBody(); 602 rewriter.mergeBlocks( 603 loopBody, newLoopBody, 604 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 605 606 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); 607 rewriter.setInsertionPoint(yieldOp); 608 609 SmallVector<Value> tiledValues; 610 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 611 ValueRange newRegionIterArgs = 612 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 613 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), 614 newRegionIterArgs, tiledValues, resultOffsets, 615 resultSizes))) { 616 rewriter.eraseOp(newLoop); 617 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); 618 } 619 620 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); 621 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : 622 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, 623 resultSizes)) { 624 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 625 rewriter.getIndexAttr(1)); 626 Value insert = rewriter.create<tensor::InsertSliceOp>( 627 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, 628 resultStride); 629 
newYieldValues.push_back(insert); 630 } 631 632 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 633 rewriter.replaceOp(loopOp, 634 newLoop->getResults().take_front(loopOp.getNumResults())); 635 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 636 } 637 638 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` 639 template <> 640 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( 641 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 642 YieldTiledValuesFn yieldTiledValuesFn) { 643 OpBuilder::InsertionGuard g(rewriter); 644 Location loc = loopOp.getLoc(); 645 rewriter.setInsertionPoint(loopOp); 646 auto inits = llvm::to_vector(loopOp.getOutputs()); 647 inits.append(newInitOperands.begin(), newInitOperands.end()); 648 auto newLoop = rewriter.create<scf::ForallOp>( 649 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), 650 loopOp.getMixedStep(), inits, loopOp.getMapping(), 651 [](OpBuilder &, Location, ValueRange) {}); 652 653 // Move the region of the current block to the newly created op. 654 Block *loopBody = loopOp.getBody(); 655 Block *newLoopBody = newLoop.getBody(); 656 rewriter.mergeBlocks( 657 loopBody, newLoopBody, 658 newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 659 660 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); 661 rewriter.setInsertionPoint(terminator); 662 SmallVector<Value> tiledValues; 663 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 664 ValueRange regionIterArgs = 665 newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 666 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), 667 regionIterArgs, tiledValues, resultOffsets, 668 resultSizes))) { 669 rewriter.eraseOp(newLoop); 670 return rewriter.notifyMatchFailure(loopOp, 671 "failed to get yielded tiled values"); 672 } 673 674 // Update the terminator. 
675 rewriter.setInsertionPointToEnd(terminator.getBody()); 676 677 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( 678 tiledValues, regionIterArgs, resultOffsets, resultSizes)) { 679 SmallVector<OpFoldResult> resultStride(resultOffset.size(), 680 rewriter.getIndexAttr(1)); 681 rewriter.create<tensor::ParallelInsertSliceOp>( 682 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, 683 resultStride); 684 } 685 686 rewriter.replaceOp(loopOp, 687 newLoop->getResults().take_front(loopOp.getNumResults())); 688 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 689 } 690 691 /// Implementation of `yieldTiledValuesAndReplaceLoop` for 692 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each 693 /// supported loop type. 694 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( 695 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, 696 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { 697 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( 698 loopLikeOp.getOperation()) 699 .Case<scf::ForOp, scf::ForallOp>( 700 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 701 return yieldTiledValuesAndReplaceLoop( 702 loopOp, rewriter, newInitOperands, yieldTiledValuesFn); 703 }) 704 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 705 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 706 }); 707 } 708 709 /// Method to add new init values to a loop nest. Updates `loops` in-place with 710 /// new loops that use the `newInitValues`. 711 /// The outer-loops are updated to yield the new result values of the inner 712 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get 713 /// the additional values to yield form the innermost loop. 
714 static LogicalResult addInitOperandsToLoopNest( 715 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 716 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 717 SmallVector<scf::ForOp> newLoops; 718 if (loops.empty()) 719 return success(); 720 OpBuilder::InsertionGuard g(rewriter); 721 rewriter.setInsertionPoint(loops.front()); 722 723 SmallVector<Value> ivs; 724 for (auto &loop : loops.drop_back()) { 725 rewriter.setInsertionPoint(loop); 726 727 // if loops.size() > 1 we assume that scf.for is used for the loops. 728 auto forLoop = cast<scf::ForOp>(loop.getOperation()); 729 730 // Create a new loop with the new init values for this loop. 731 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 732 newInits.append(newInitValues.begin(), newInitValues.end()); 733 auto newLoop = rewriter.create<scf::ForOp>( 734 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 735 forLoop.getStep(), newInits, 736 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 737 738 // Merge the body of the new loop with the body of the old loops. 739 SmallVector<Value> sourceBlockArgs; 740 sourceBlockArgs.push_back(newLoop.getInductionVar()); 741 auto newRegionIterArgs = newLoop.getRegionIterArgs(); 742 sourceBlockArgs.append( 743 newRegionIterArgs.begin(), 744 std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 745 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 746 rewriter.replaceOp( 747 forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 748 loop = newLoop; 749 ivs.push_back(newLoop.getInductionVar()); 750 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 751 } 752 753 // Update the loop body of the innermost loop to get new yield values. 
754 LoopLikeOpInterface innerMostLoop = loops.back(); 755 FailureOr<LoopLikeOpInterface> newInnerMostLoop = 756 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 757 getNewTiledYieldsFn); 758 759 if (failed(newInnerMostLoop)) 760 return innerMostLoop.emitOpError("failed to return additional yields"); 761 loops.back() = newInnerMostLoop.value(); 762 763 // Make all other loops except the innermost loops yield the values returned 764 // by the inner loop. 765 for (auto [outerLoop, innerLoop] : 766 llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 767 // Again assume that all the outer loops are scf.for operations. 768 auto outerForLoop = cast<scf::ForOp>(outerLoop); 769 auto outerLoopYield = 770 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 771 SmallVector<Value> newYields = 772 llvm::to_vector(outerLoopYield.getOperands()); 773 ValueRange additionalYields = 774 innerLoop->getResults().take_back(newInitValues.size()); 775 newYields.append(additionalYields.begin(), additionalYields.end()); 776 rewriter.setInsertionPoint(outerLoopYield); 777 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 778 } 779 return success(); 780 } 781 782 /// Implementation of tiling transformation of `op` that implements the 783 /// `TilingInterface` using `scf.for` to iterate over the tiles. 784 FailureOr<scf::SCFTilingResult> 785 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 786 const scf::SCFTilingOptions &options) { 787 if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { 788 return failure(); 789 } 790 791 OpBuilder::InsertionGuard guard(rewriter); 792 rewriter.setInsertionPointAfter(op); 793 794 // 1. Get the range of the loops that are represented by the operation. 795 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 796 797 // 2. 
Materialize the tile sizes and/or number of threads; 798 SmallVector<OpFoldResult> tileSizes, numThreads; 799 std::tie(tileSizes, numThreads) = 800 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); 801 802 // Check if it is safe to tile. This is hold over from previous iterations 803 // of tile to for-all. Consider dropping it. 804 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 805 checkSafeToTileToForall(op, tileSizes, numThreads); 806 } 807 808 // 3. If there is an interchange specified, permute the iteration domain and 809 // the tile sizes. 810 SmallVector<int64_t> interchangeVector; 811 if (!options.interchangeVector.empty()) { 812 interchangeVector = fillInterchangeVector(options.interchangeVector, 813 iterationDomain.size()); 814 assert(isPermutationVector(interchangeVector) && 815 "expected interchange vector to be a permutation"); 816 817 applyPermutationToVector(iterationDomain, interchangeVector); 818 applyPermutationToVector(tileSizes, interchangeVector); 819 if (!numThreads.empty()) 820 applyPermutationToVector(numThreads, interchangeVector); 821 } 822 823 FailureOr<TilingResult> tilingResult; 824 // 4. Define the lambda function used later to generate the body of the 825 // innermost tiled loop. 826 YieldTiledValuesFn innerYieldTiledValuesFn = 827 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 828 ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 829 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 830 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 831 -> LogicalResult { 832 // 4a. Compute the `offsets` and `sizes` to use for tiling. 833 SmallVector<OpFoldResult> offsets, sizes; 834 std::tie(offsets, sizes) = getTileOffsetAndSizes( 835 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); 836 837 // 4b. If interchange was provided, apply inverse of the interchange 838 // to get back the offsets/sizes in the order to be specified. 
839 if (!interchangeVector.empty()) { 840 auto inversePermutation = invertPermutationVector(interchangeVector); 841 applyPermutationToVector(offsets, inversePermutation); 842 applyPermutationToVector(sizes, inversePermutation); 843 } 844 845 // 5. Generate the tiled implementation within the inner most loop. 846 847 // 5a. Clone the operation within the loop body. 848 auto clonedOp = cast<TilingInterface>( 849 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 850 851 // 5b. Early return cloned op if tiling is not happening. We can not return 852 // the original op because it could lead to 853 // `rewriter.replaceOp(op, op->getResults())` and users would get crash. 854 if (llvm::all_of(tileSizes, isZeroIndex)) { 855 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 856 tilingResult = 857 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), 858 /*generatedSlices=*/{}}; 859 return success(); 860 } 861 862 // 5c. Tile the cloned operation. 863 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); 864 if (failed(tilingResult)) { 865 rewriter.eraseOp(clonedOp); 866 return op.emitOpError("faild to tile operation"); 867 } 868 869 // 5d. Delete the cloned operation. 870 rewriter.eraseOp(clonedOp); 871 872 // 5e. Compute the offsets at which the result values are to be inserted 873 // back into its destinations. 
874 for (auto [index, tiledValue] : 875 llvm::enumerate(tilingResult->tiledValues)) { 876 tiledResults.push_back(tiledValue); 877 SmallVector<OpFoldResult> resultOffset, resultSize; 878 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, 879 resultOffset, resultSize))) { 880 for (auto op : tilingResult->tiledOps) { 881 rewriter.eraseOp(op); 882 } 883 return rewriter.notifyMatchFailure( 884 op, "failed to get slice of result produced"); 885 } 886 resultOffsets.emplace_back(std::move(resultOffset)); 887 resultSizes.emplace_back(std::move(resultSize)); 888 } 889 890 return success(); 891 }; 892 893 // 6. Find the destination tensors to use for the operation. 894 SmallVector<Value> destinationTensors; 895 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 896 destinationTensors))) { 897 return rewriter.notifyMatchFailure(op, 898 "unable to create destination tensors"); 899 } 900 901 // 7. Generate the tiled loops nest using the callback defined above. 902 SmallVector<LoopLikeOpInterface> loops; 903 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 904 tileSizes, numThreads, destinationTensors, 905 innerYieldTiledValuesFn, loops))) 906 return op.emitOpError("failed to generate tiling loops"); 907 assert(succeeded(tilingResult) && 908 "expected tiling result to be computed after loop generation"); 909 910 // If loops are empty, the tiled op is used as the replacement for the untiled 911 // op. 
912 if (loops.empty()) { 913 return scf::SCFTilingResult{tilingResult->tiledOps, loops, 914 tilingResult->tiledValues, 915 tilingResult->generatedSlices}; 916 } 917 918 SmallVector<Value> replacements = llvm::map_to_vector( 919 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 920 return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements, 921 tilingResult->generatedSlices}; 922 } 923 924 FailureOr<scf::SCFReductionTilingResult> 925 mlir::scf::tileReductionUsingScf(RewriterBase &b, 926 PartialReductionOpInterface op, 927 ArrayRef<OpFoldResult> tileSizes) { 928 Location loc = op.getLoc(); 929 // Ops implementing PartialReductionOpInterface are expected to implement 930 // TilingInterface. 931 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 932 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 933 auto tileSizesVector = llvm::to_vector(tileSizes); 934 if (tileSizesVector.size() < iterationDomain.size()) { 935 auto zero = b.getIndexAttr(0); 936 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), 937 zero); 938 } 939 SmallVector<utils::IteratorType> iterators = 940 tilingInterfaceOp.getLoopIteratorTypes(); 941 942 SmallVector<int> reductionDims; 943 for (auto [idx, iteratorType] : 944 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 945 if (iteratorType == utils::IteratorType::reduction) 946 reductionDims.push_back(idx); 947 } 948 949 // 2. create the inital tensor value. 950 FailureOr<SmallVector<Value>> maybeInitTensors = 951 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, 952 reductionDims); 953 if (failed(maybeInitTensors)) { 954 return b.notifyMatchFailure(op, "Failed to create initial tensors."); 955 } 956 SmallVector<Value> &initTensors = maybeInitTensors.value(); 957 958 // 3. Define the callback to use for generating the inner most tile loop body. 
959 SmallVector<Operation *> parallelTiledOps; 960 auto innerYieldTiledValuesFn = 961 [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 962 ValueRange regionIterArgs, SmallVector<Value> &tiledResult, 963 SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 964 SmallVector<SmallVector<OpFoldResult>> &resultSizes) 965 -> LogicalResult { 966 SmallVector<OpFoldResult> offsets, sizes; 967 { 968 int materializedLoopNum = 0; 969 for (auto [tileSize, loopRange] : 970 llvm::zip_equal(tileSizesVector, iterationDomain)) { 971 if (isConstantIntValue(tileSize, 0)) { 972 offsets.push_back(loopRange.offset); 973 sizes.push_back(loopRange.size); 974 continue; 975 } 976 Value iv = ivs[materializedLoopNum++]; 977 offsets.push_back(iv); 978 sizes.push_back( 979 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); 980 } 981 } 982 983 // 4a. Clone the operation. 984 { 985 auto clonedOp = cast<PartialReductionOpInterface>( 986 cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); 987 988 // 4b. Tile the cloned operation. 989 FailureOr<TilingResult> partialTilingResult = 990 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, 991 sizes, reductionDims); 992 if (failed(partialTilingResult)) { 993 return failure(); 994 } 995 std::swap(parallelTiledOps, partialTilingResult->tiledOps); 996 std::swap(tiledResult, partialTilingResult->tiledValues); 997 998 // 4c. Delete the cloned operation. 999 b.eraseOp(clonedOp); 1000 } 1001 1002 // 4d. Compute the offsets and sizes needed to insert the result of the 1003 // tiled value back into destination before yielding the destination. 
1004 for (auto result : tiledResult) { 1005 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 1006 resultOffsets.emplace_back(std::move(outOffsets)); 1007 1008 SmallVector<OpFoldResult> outSizes; 1009 for (size_t i = 0; i < offsets.size(); i++) { 1010 outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); 1011 } 1012 resultSizes.emplace_back(std::move(outSizes)); 1013 } 1014 return success(); 1015 }; 1016 1017 // 5. Generate the tiled implementation using the destination tensors. 1018 SmallVector<LoopLikeOpInterface> loops; 1019 scf::SCFTilingOptions options; 1020 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); 1021 if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, 1022 /*numThreads=*/ArrayRef<OpFoldResult>{}, 1023 initTensors, innerYieldTiledValuesFn, loops))) 1024 return b.notifyMatchFailure(op, "failed to tile for parallel reduction"); 1025 1026 SmallVector<Value> replacements = llvm::map_to_vector( 1027 loops.front()->getResults(), [](OpResult r) -> Value { return r; }); 1028 1029 // 5. Apply the merge reduction to combine all the partial values. 1030 b.setInsertionPointAfter(*loops.begin()); 1031 FailureOr<MergeResult> mergeResult = 1032 op.mergeReductions(b, loc, replacements, reductionDims); 1033 if (failed(mergeResult)) { 1034 return failure(); 1035 } 1036 b.replaceOp(op, mergeResult->replacements); 1037 1038 SCFReductionTilingResult reductionTilingResult; 1039 std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); 1040 std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); 1041 std::swap(reductionTilingResult.initialValues, initTensors); 1042 std::swap(reductionTilingResult.loops, loops); 1043 std::swap(reductionTilingResult.replacements, mergeResult->replacements); 1044 1045 return reductionTilingResult; 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // tileConsumerAndFuseProducersUsingSCF implementation. 
1050 //===----------------------------------------------------------------------===// 1051 1052 /// Return the untiled producer whose slice is used in a tiled consumer. The 1053 /// method traverses the tile loop nest (`loops`) if needed, and returns the 1054 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 1055 /// indicates that this is a destination operand of the consumer. If there was 1056 /// no loop traversal needed, the second value of the returned tuple is empty. 1057 static std::tuple<OpResult, std::optional<OpOperand *>> 1058 getUntiledProducerFromSliceSource(OpOperand *source, 1059 ArrayRef<LoopLikeOpInterface> loops) { 1060 std::optional<OpOperand *> destinationIterArg; 1061 auto loopIt = loops.rbegin(); 1062 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 1063 auto loop = *loopIt; 1064 if (iterArg.getOwner()->getParentOp() != loop) 1065 break; 1066 source = loop.getTiedLoopInit(iterArg); 1067 loopIt++; 1068 } 1069 if (loopIt == loops.rend()) 1070 destinationIterArg = source; 1071 return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 1072 } 1073 1074 /// Implementation of fusing producer of a single slice by computing the 1075 /// slice of the producer in-place. 1076 std::optional<scf::SCFFuseProducerOfSliceResult> 1077 mlir::scf::tileAndFuseProducerOfSlice( 1078 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 1079 MutableArrayRef<LoopLikeOpInterface> loops) { 1080 // 1. Get the producer of the source (potentially walking through 1081 // `iter_args` of nested `scf.for`) 1082 auto [fusableProducer, destinationInitArg] = 1083 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1084 loops); 1085 if (!fusableProducer) 1086 return std::nullopt; 1087 unsigned resultNumber = fusableProducer.getResultNumber(); 1088 1089 OpBuilder::InsertionGuard g(rewriter); 1090 rewriter.setInsertionPoint(candidateSliceOp); 1091 1092 // 2. Clone the fused producer 1093 // 2a. 
Compute the destination operands to use for the cloned operation. 1094 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 1095 Operation *fusableProducerOp = fusableProducer.getOwner(); 1096 if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 1097 failed(tensor::getOrCreateDestinations( 1098 rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 1099 origDestinationTensors))) 1100 return std::nullopt; 1101 1102 clonedOpDestinationTensors = origDestinationTensors; 1103 if (destinationInitArg && 1104 isa<DestinationStyleOpInterface>(fusableProducerOp)) { 1105 // 2b. If the producer is also destination style, then to maintain the 1106 // destination passing style, update the destination of the producer to be 1107 // the source of the slice. 1108 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 1109 } 1110 // 2c. Clone the fused producer. 1111 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 1112 rewriter, fusableProducerOp, clonedOpDestinationTensors); 1113 // 2d. Update the source of the candidateSlice to be the cloned producer. 1114 // Easier to just clone the slice with different source since replacements 1115 // and DCE of cloned ops becomes easier 1116 SmallVector<Value> candidateSliceOpOperands = 1117 llvm::to_vector(candidateSliceOp->getOperands()); 1118 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 1119 tensor::ExtractSliceOp clonedCandidateSliceOp = 1120 mlir::clone(rewriter, candidateSliceOp, 1121 candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 1122 1123 // 3. 
Generate the tiled implementation of the producer of the source 1124 FailureOr<TilingResult> tileAndFuseResult = 1125 tensor::replaceExtractSliceWithTiledProducer( 1126 rewriter, clonedCandidateSliceOp, 1127 clonedProducerOp->getResult(resultNumber)); 1128 if (failed(tileAndFuseResult)) 1129 return std::nullopt; 1130 // Note: Do not delete the candidateSliceOp, since its passed in from the 1131 // caller. 1132 rewriter.replaceAllUsesWith(candidateSliceOp, 1133 tileAndFuseResult->tiledValues[0]); 1134 rewriter.eraseOp(clonedCandidateSliceOp); 1135 rewriter.eraseOp(clonedProducerOp); 1136 1137 // 3. If the slice is for a destination operand, for example, 1138 // 1139 // ```mlir 1140 // %0 = linalg.init 1141 // %1 = linalg.fill .. outs(%0 : ) 1142 // %2 = scf.for .. iter_args(%arg0 = %1) { 1143 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1144 // %4 = tensor.extract_slice %arg1 [..] 1145 // .. = linalg.matmul .. outs(%4 : ) 1146 // } 1147 // } 1148 // ``` 1149 // 1150 // the IR is currently 1151 // 1152 // ``` 1153 // %0 = linalg.init 1154 // %1 = linalg.fill 1155 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 1156 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1157 // %4 = tensor.extract_slice %arg1[..] 1158 // %5 = linalg.fill .. outs(%4 : ) 1159 // .. = linalg.matmul .. outs(%5 : ) 1160 // } 1161 // } 1162 // ``` 1163 // 1164 // The untiled `linalg.fill` is still used as the `init_value` since it 1165 // was originally a destination operand of the untiled `linalg.matmul`. 1166 // When fusing an operand that is a destination operand, the iter_arg of 1167 // the outer most loop should be changed to use the destination of the 1168 // fused operation. With this the IR will be. 1169 // 1170 // ``` 1171 // %0 = linalg.init 1172 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 1173 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 1174 // %3 = tensor.extract_slice %arg1[..] 1175 // %4 = linalg.fill .. outs(%3 : ) 1176 // .. 
= linalg.matmul .. outs(%4 : ) 1177 // } 1178 // } 1179 // ``` 1180 if (destinationInitArg && 1181 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 1182 loops.front() 1183 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 1184 .set(origDestinationTensors[resultNumber]); 1185 } 1186 return scf::SCFFuseProducerOfSliceResult{ 1187 fusableProducer, tileAndFuseResult->tiledValues[0], 1188 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; 1189 } 1190 1191 /// Reconstruct the fused producer from within the tiled-and-fused code. 1192 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( 1193 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 1194 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 1195 MutableArrayRef<LoopLikeOpInterface> loops, 1196 ArrayRef<unsigned> yieldResultNumber) { 1197 if (loops.empty()) 1198 return success(); 1199 1200 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), 1201 *tiledOwner = fusedProducerInfo.tiledOps[0]; 1202 1203 Location loc = originalOwner->getLoc(); 1204 // a. collect all init Value to be appended 1205 SmallVector<unsigned> initNumberList = 1206 yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq<unsigned>( 1207 0, originalOwner->getNumResults())) 1208 : llvm::to_vector(yieldResultNumber); 1209 SmallVector<Value> initValueList; 1210 for (const auto &resultNumber : initNumberList) { 1211 FailureOr<Value> initValue = tensor::getOrCreateDestination( 1212 rewriter, loc, originalOwner->getResult(resultNumber)); 1213 if (succeeded(initValue)) { 1214 initValueList.push_back(initValue.value()); 1215 } else { 1216 return failure(); 1217 } 1218 } 1219 1220 SmallVector<Operation *> generatedSlices; 1221 YieldTiledValuesFn newYieldValuesFn = 1222 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 1223 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 1224 SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 1225 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 1226 OpBuilder::InsertionGuard g(innerRewriter); 1227 1228 // get sliceOp tile information 1229 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), 1230 sliceSizes = sliceOp.getMixedSizes(); 1231 1232 // expect all strides of sliceOp being 1 1233 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 1234 return !isConstantIntValue(ofr, 1); 1235 })) 1236 return failure(); 1237 1238 unsigned sliceResultNumber = 1239 fusedProducerInfo.origProducer.getResultNumber(); 1240 1241 auto tilableOp = cast<TilingInterface>(originalOwner); 1242 // b. get iterDomain Offset and Sizes based on sliceOp tile 1243 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; 1244 // skip tensor.pack/unpack/pad, which expects single opResult 1245 if (tilableOp->getNumResults() > 1 && 1246 failed(tilableOp.getIterationDomainTileFromResultTile( 1247 rewriter, sliceResultNumber, sliceOffset, sliceSizes, 1248 iterDomainOffset, iterDomainSizes))) { 1249 // In theory, it is unnecessary to raise an error here. Actually although 1250 // it fails to reconstruct the result tensor, it should not broke current 1251 // fusion anyway. 
The reason why we must return failure currently is that 1252 // the callback function `newYieldValuesFn` will be called after new init 1253 // operand(s) has already been appended. It will take more refactoring to 1254 // make sure the init operands are added consistently in the future. For 1255 // more details, please refer to: 1256 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 1257 return failure(); 1258 } 1259 1260 // c. calculate offsets and sizes info of all OpResults respectively based 1261 // on iteration Domain Tile 1262 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; 1263 for (const auto &resultNumber : initNumberList) { 1264 if (resultNumber == sliceResultNumber) { 1265 offsetList.push_back(sliceOffset); 1266 sizesList.push_back(sliceSizes); 1267 } else { 1268 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); 1269 // infer result tile according to the iteration domain tile 1270 SmallVector<OpFoldResult> offset, sizes; 1271 if (failed(tilableOp.getResultTilePosition( 1272 rewriter, resultNumber, iterDomainOffset, iterDomainSizes, 1273 offset, sizes))) { 1274 return failure(); 1275 } 1276 offsetList.push_back(offset); 1277 sizesList.push_back(sizes); 1278 } 1279 } 1280 1281 // d. 
create `extract_slice` for `iter_args` for DPS operation if necessary 1282 if (auto tiledDestStyleOp = 1283 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { 1284 rewriter.setInsertionPoint(tiledDestStyleOp); 1285 for (const auto &&[index, newRegionArg] : 1286 llvm::enumerate(newRegionIterArgs)) { 1287 auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 1288 loc, newRegionArg, offsetList[index], sizesList[index], 1289 SmallVector<OpFoldResult>(offsetList[index].size(), 1290 rewriter.getIndexAttr(1))); 1291 generatedSlices.push_back(destSlice); 1292 unsigned resultNumber = initNumberList[index]; 1293 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 1294 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 1295 }); 1296 } 1297 } 1298 1299 // e. prepare tiled offset and sizes for later `insert_slice` creation by 1300 // caller 1301 Block *block = rewriter.getInsertionPoint()->getBlock(); 1302 rewriter.setInsertionPoint(block->getTerminator()); 1303 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { 1304 tiledResult.push_back(tiledOwner->getResult(resultNumber)); 1305 tiledOffset.emplace_back(offsetList[index]); 1306 tiledSizes.emplace_back(sizesList[index]); 1307 } 1308 return success(); 1309 }; 1310 1311 if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, 1312 newYieldValuesFn))) { 1313 return failure(); 1314 } 1315 return generatedSlices; 1316 } 1317 1318 /// Implementation of tile consumer and fuse producer greedily. 1319 FailureOr<scf::SCFTileAndFuseResult> 1320 mlir::scf::tileConsumerAndFuseProducersUsingSCF( 1321 RewriterBase &rewriter, TilingInterface consumer, 1322 const scf::SCFTileAndFuseOptions &options) { 1323 // This transformation is only valid for ops that return values (i.e. not 1324 // valid to use with operations that have memref operands). 
1325 if (!consumer->getNumResults()) { 1326 return rewriter.notifyMatchFailure( 1327 consumer, "invalid pattern for op with no results"); 1328 } 1329 1330 // 1. First tile the consumer. 1331 SetVector<Operation *> fusedProducers, tiledAndFusedOps; 1332 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 1333 1334 FailureOr<scf::SCFTilingResult> tilingResult = 1335 tileUsingSCF(rewriter, consumer, options.tilingOptions); 1336 1337 if (failed(tilingResult)) 1338 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1339 for (auto *tiledOp : tilingResult->tiledOps) 1340 tiledAndFusedOps.insert(tiledOp); 1341 1342 // If there are no loops generated, fusion is immaterial. 1343 auto &loops = tilingResult->loops; 1344 if (loops.empty()) { 1345 DenseMap<Value, Value> replacements; 1346 for (auto [origVal, replacement] : 1347 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 1348 replacements[origVal] = replacement; 1349 } 1350 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1351 replacements}; 1352 } 1353 1354 // To keep track of replacements for now just record the map from the original 1355 // untiled value to the result number of the for loop. Since the loop gets 1356 // potentially replaced during fusion, keeping the value directly wont work. 1357 DenseMap<Value, size_t> origValToResultNumber; 1358 for (auto [index, result] : llvm::enumerate(consumer->getResults())) { 1359 origValToResultNumber[result] = index; 1360 } 1361 1362 // 2. Typically, the operands of the tiled operation are slices of the 1363 // operands of the untiled operation. These are expressed in IR using 1364 // `tensor.extract_slice` operations with source being the operands of the 1365 // untiled operation. Create a worklist of these `tensor.extract_slice` 1366 // operations. 
If the producers of the source of the `tensor.extract_slice` 1367 // can be tiled such that the tiled value is generated in-place, that 1368 // effectively tiles + fuses the operations. 1369 struct WorklistItem { 1370 tensor::ExtractSliceOp candidateSlice; 1371 SCFTileAndFuseOptions::ControlFnResult controlFnResult; 1372 }; 1373 std::deque<WorklistItem> worklist; 1374 auto addCandidateSlices = [&worklist, &options, 1375 &loops](ArrayRef<Operation *> candidates) { 1376 for (auto candidate : candidates) { 1377 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate); 1378 if (!sliceOp || sliceOp.use_empty()) 1379 continue; 1380 1381 auto [fusableProducer, destinationInitArg] = 1382 getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops); 1383 if (!fusableProducer) 1384 continue; 1385 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = 1386 options.fusionControlFn(sliceOp, fusableProducer, 1387 destinationInitArg.has_value()); 1388 if (!controlFnResult) 1389 continue; 1390 worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()}); 1391 } 1392 }; 1393 1394 addCandidateSlices(tilingResult->generatedSlices); 1395 OpBuilder::InsertionGuard g(rewriter); 1396 while (!worklist.empty()) { 1397 // Traverse the slices in BFS fashion. 1398 WorklistItem worklistItem = worklist.front(); 1399 worklist.pop_front(); 1400 1401 // The operands of the fused producer might themselved be slices of 1402 // values produced by operations that implement the `TilingInterface`. 1403 // Add these operations to the worklist. 1404 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 1405 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, 1406 loops); 1407 if (!fusedResult) 1408 continue; 1409 1410 if (worklistItem.controlFnResult.yieldProducerReplacement) { 1411 // Reconstruct and yield all opResult of fusableProducerOp by default. 
The 1412 // caller can specific which one to yield by designating optional argument 1413 // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. 1414 Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); 1415 FailureOr<SmallVector<Operation *>> newSlices = 1416 yieldReplacementForFusedProducer(rewriter, 1417 worklistItem.candidateSlice, 1418 fusedResult.value(), loops); 1419 if (failed(newSlices)) { 1420 return rewriter.notifyMatchFailure( 1421 fusableProducerOp, "failed to replacement value for this " 1422 "operation from within the tiled loop"); 1423 } 1424 addCandidateSlices(newSlices.value()); 1425 for (auto [index, result] : 1426 llvm::enumerate(fusableProducerOp->getResults())) { 1427 origValToResultNumber[result] = loops.front()->getNumResults() - 1428 fusableProducerOp->getNumResults() + 1429 index; 1430 } 1431 } 1432 addCandidateSlices(fusedResult->generatedSlices); 1433 if (Operation *tiledAndFusedOp = 1434 fusedResult->tiledAndFusedProducer.getDefiningOp()) { 1435 fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 1436 tiledAndFusedOps.insert(tiledAndFusedOp); 1437 } 1438 } 1439 1440 DenseMap<Value, Value> replacements; 1441 for (auto [origVal, resultNumber] : origValToResultNumber) { 1442 replacements[origVal] = loops.front()->getResult(resultNumber); 1443 } 1444 1445 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 1446 replacements}; 1447 } 1448 1449 //===----------------------------------------------------------------------===// 1450 // tileAndFuseConsumerUsingSCF implementation. 1451 //===----------------------------------------------------------------------===// 1452 1453 /// A utility function that checks whether the only use of the result of a 1454 /// tensor.insert_slice op is in a scf.yield op. 
1455 static LogicalResult 1456 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { 1457 Value result = candidateSliceOp.getResult(); 1458 Value::use_range uses = result.getUses(); 1459 if (!llvm::hasSingleElement(uses)) { 1460 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); 1461 return failure(); 1462 } 1463 OpOperand &operandUse = (*uses.begin()); 1464 Operation *userOp = operandUse.getOwner(); 1465 if (!isa<scf::YieldOp>(userOp)) { 1466 LLVM_DEBUG(llvm::dbgs() 1467 << "Expected scf.yield to be the only user, but got -> " 1468 << (*userOp)); 1469 return failure(); 1470 } 1471 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { 1472 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " 1473 "be in the same block\n"); 1474 return failure(); 1475 } 1476 return success(); 1477 } 1478 1479 /// Fetches the OpOperand of the only user (and use) of the value `val` which 1480 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns 1481 /// failure otherwise. 1482 static FailureOr<OpOperand *> getConsumerFromUses(Value val, 1483 Block *containingOpBlock) { 1484 // Step 1. Check that the value has exactly one use. 1485 if (!llvm::hasSingleElement(val.getUses())) 1486 return failure(); 1487 // Step 2. Get uses. 1488 OpOperand &operand = (*val.getUses().begin()); 1489 Operation *consumerOp = operand.getOwner(); 1490 // TODO: We have to init result of consumer before scf.for, use 1491 // DestinationStyleOpInterface to get result shape from init for now. 1492 // Add support for other op such as op has InferTypeOpInterface. 1493 if (!isa<TilingInterface>(consumerOp) || 1494 !isa<DestinationStyleOpInterface>(consumerOp)) 1495 return failure(); 1496 if (containingOpBlock != consumerOp->getBlock()) 1497 return failure(); 1498 return &operand; 1499 } 1500 1501 /// Find the perfectly nested loops outside of given loop(included) sorted from 1502 /// outer to inner. 
1503 /// 1504 /// E.g. 1505 /// 1506 /// ``` 1507 /// %0 = scf.for() 1508 /// %1 = scf.for() 1509 /// %2 = scf.for() 1510 /// %3 = ... 1511 /// yield %3 1512 /// yield %2 1513 /// yield %1 1514 /// ``` 1515 /// 1516 /// This function will return three perfectly nested loops: %0 + %1 + %2, when 1517 /// target inner loop is %2. 1518 static SmallVector<scf::ForOp> 1519 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { 1520 SmallVector<scf::ForOp> nestLoops = {loop}; 1521 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp()); 1522 1523 // Check if it is the ForOp that yield the result of inner loop. 1524 auto isForOpYieldResultOfInnerLoop = 1525 [](scf::ForOp outerLoop) -> LogicalResult { 1526 Block *body = outerLoop.getBody(); 1527 if (!llvm::hasSingleElement(body->without_terminator())) 1528 return failure(); 1529 auto yieldOp = cast<scf::YieldOp>(body->getTerminator()); 1530 auto innerForOp = dyn_cast<scf::ForOp>(body->front()); 1531 if (!innerForOp) 1532 return failure(); 1533 // All of innerForOp results should be yielded. 1534 return success(innerForOp->getNumResults() == yieldOp->getNumOperands()); 1535 }; 1536 1537 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) { 1538 nestLoops.push_back(outerLoop); 1539 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp()); 1540 } 1541 // sorted from outer to inner 1542 return {nestLoops.rbegin(), nestLoops.rend()}; 1543 } 1544 1545 /// Fetch the untiled consumer of a scf.for's result which is yielded by a 1546 /// tensor.insert_slice. This function makes the following assumptions : 1547 /// 1. tensor.insert_slice has scf.yield as its only user. 1548 /// 2. scf.for's corresponding result has only one use. 1549 static FailureOr<OpOperand *> 1550 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { 1551 if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) 1552 return failure(); 1553 Value sliceResult = candidateSliceOp.getResult(); 1554 // Step 1. 
Fetch the corresponding output. 1555 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); 1556 unsigned resultNumber = yieldOpOperand.getOperandNumber(); 1557 // Step 2. Check containing op is scf.for. 1558 Operation *containingOp = candidateSliceOp->getParentOp(); 1559 auto forOp = dyn_cast<scf::ForOp>(containingOp); 1560 if (!forOp) 1561 return failure(); 1562 scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); 1563 Value resultingValue = topLevelForOp->getResult(resultNumber); 1564 1565 return getConsumerFromUses(resultingValue, topLevelForOp->getBlock()); 1566 } 1567 1568 /// Fetch the first untiled consumer of a scf.forall's result which is yielded 1569 /// by a tensor.parallel_insert_slice. 1570 static FailureOr<OpOperand *> 1571 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { 1572 // Step 1. Fetch the corresponding output 1573 Value sliceDest = candidateSliceOp.getDest(); 1574 auto iterArg = dyn_cast<BlockArgument>(sliceDest); 1575 if (!iterArg) 1576 return failure(); 1577 Operation *containingOp = iterArg.getOwner()->getParentOp(); 1578 if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) 1579 return failure(); 1580 // Step 2. Check that the containing op is scf.forall. 1581 auto forallOp = dyn_cast<scf::ForallOp>(containingOp); 1582 if (!forallOp) 1583 return failure(); 1584 Value resultingValue = 1585 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); 1586 1587 return getConsumerFromUses(resultingValue, containingOp->getBlock()); 1588 } 1589 1590 /// This utility currently checks whether the loop either :- 1591 /// 1. Yields exactly one result. 1592 /// 2. Has consumer op as its first user and other users to be in the same 1593 /// containing block as that of consumer op's. Currently we clone the loop op 1594 /// right before the consumer op in order to maintain a valid def-use chain. 
///    This utility thus helps ensure that no invalid IR is formed as a result.
/// Returns success if `loopOp` has exactly one result, or if every user of
/// `loopOp` other than `consumerOp` is in the same block as `consumerOp` and
/// comes after it; returns failure otherwise.
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
                                            Operation *consumerOp) {
  // Check if the loop op yields one result.
  if (loopOp->getNumResults() == 1)
    return success();
  // Check if the consumerOp is the first user of the loopOp and if other users
  // are in the same containing block as that of consumer op's.
  Block *parentBlock = consumerOp->getBlock();
  for (Operation *userOp : loopOp->getUsers()) {
    if (userOp == consumerOp)
      continue;
    // Any other user must be in the consumer's block and located after the
    // consumer.
    if (parentBlock != userOp->getBlock() ||
        !consumerOp->isBeforeInBlock(userOp))
      return failure();
  }
  return success();
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice. Dispatches to the
/// overload matching the concrete type of `sliceOp` and returns failure for
/// any other kind of operation.
static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(insertSlice);
  } else if (auto parallelInsertSlice =
                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
    return getUntiledConsumerFromSlice(parallelInsertSlice);
  } else {
    return failure();
  }
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
///
/// `candidateSliceOp` has to be a `tensor.insert_slice` or a
/// `tensor.parallel_insert_slice`; failure is returned for any other
/// operation. On success the returned struct is built from the operand of the
/// original consumer that used the loop result, the matching operand of the
/// first tiled op, and the list of tiled ops.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  // Position of the fused operand in the consumer, and result number of the
  // loop result feeding that operand.
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  // There are two possible cases regarding `oldLoopOp` here:
  // 1. single `scf.forall` or `scf.for`.
  // 2. inner-most `scf.for` inside a nest of `scf.for` loops, where the
  // top-level loop is the outer-most one of these nested loops.
  LoopLikeOpInterface innerMostLoop =
      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
  SmallVector<LoopLikeOpInterface> nestedLoops;
  if (isInsertSliceOp) {
    // tensor.insert_slice case: collect the perfectly nested scf.for loops
    // around the inner-most loop, sorted from outer to inner.
    nestedLoops = llvm::map_to_vector(
        getPerfectlyNestedLoopsOutsideOf(
            cast<scf::ForOp>(innerMostLoop.getOperation())),
        [](scf::ForOp forOp) {
          return cast<LoopLikeOpInterface>(forOp.getOperation());
        });
  } else {
    // tensor.parallel_insert_slice case: only the enclosing loop is used.
    nestedLoops = {innerMostLoop};
  }

  LoopLikeOpInterface outerMostLoop = nestedLoops.front();

  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
    return rewriter.notifyMatchFailure(
        outerMostLoop,
        "containing loop op should either yield just one value or "
        "have the consumer op as its first user");
  }

  // Restores the insertion point of `rewriter` when this function returns.
  OpBuilder::InsertionGuard g(rewriter);

  // 2. Check consumer is not using scf loop's output as init.
  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
  if (!dstOp)
    return rewriter.notifyMatchFailure(consumerOp,
                                       "consumer op is not DPS operation");
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  // The DPS inits of the consumer become the additional inits of the loop
  // nest (see step 14).
  SmallVector<Value> newInits = dpsInits;

  Location loc = outerMostLoop->getLoc();

  // 3. Move the whole loop structure right before consumer Op, the dominance
  // should be already ensured by `checkAssumptionForLoop`.
  // NOTE(review): the IR is modified from here on; the failure returns below
  // do not undo these modifications.
  rewriter.moveOpBefore(outerMostLoop, consumerOp);

  // 4. Set insertion point before terminator op of the loop and create a new
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 5.a. Clone consumer op. The clone is created at the current insertion
  // point, i.e. right after the new tensor.insert_slice.
  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

  // 5.b. Replace the operand of the cloned consumer at `operandNumber` (the
  // one that used the loop result) with the result of the cloned
  // tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 6. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  // NOTE(review): assumes `tiledOps` holds at least one op on success —
  // confirm against `replaceInsertSliceWithTiledConsumer`.
  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
                              clonedInsertSliceOp.getSource());

  // 7. Reconstruct [nested] loop with new inits. The callback below is handed
  // to `addInitOperandsToLoopNest` (step 14) and fills `tiledResult`,
  // `tiledOffset` and `tiledSizes` with one entry per result of the tiled
  // consumer.
  // NOTE(review): apart from the insertion guard and the first
  // `setInsertionPoint`, the callback uses the captured outer `rewriter`
  // rather than `innerRewriter`; this presumably relies on both being the
  // same rewriter — confirm.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);
    // 8. Set inner insertPoint right before tiled consumer op.
    innerRewriter.setInsertionPoint(tiledConsumerOp);

    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

    // 9. Check all insert stride is 1.
    if (llvm::any_of(strides, [](OpFoldResult stride) {
          return !isConstantIntValue(stride, 1);
        })) {
      return rewriter.notifyMatchFailure(
          candidateSliceOp, "containingOp's result yield with stride");
    }

    // 10. Try to get iter domain position from input position.
    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
            iterDomainSizes))) {
      return rewriter.notifyMatchFailure(
          tiledConsumerOp,
          "can't get iter domain position from input position");
    }

    // 11. Try to fetch the offset and size for all results of the cloned
    // consumer. This would then be used to form the corresponding
    // tensor.insert_slice/parallel_insert_slice later.
    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
        totalNumResultsOfConsumer);
    SmallVector<SmallVector<OpFoldResult>> resultSizes(
        totalNumResultsOfConsumer);
    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
      if (failed(tiledConsumerOp.getResultTilePosition(
              rewriter, idx, iterDomainOffsets, iterDomainSizes,
              resultOffsets[idx], resultSizes[idx]))) {
        return rewriter.notifyMatchFailure(
            tiledConsumerOp,
            "can't get result domain position from iter domain position");
      }
    }

    // 12. Create `extract_slice` for `iter_args` for DPS operation if
    // necessary. Each new region iter arg is sliced with the offsets/sizes of
    // the result at the same index (unit strides) and becomes the init at
    // that index of the tiled consumer.
    // NOTE(review): indexes `resultOffsets`/`resultSizes` by iter-arg
    // position, i.e. assumes there are no more new iter args than results of
    // the tiled consumer — confirm.
    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
            tiledConsumerOp.getOperation())) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, resultOffsets[index], resultSizes[index],
            SmallVector<OpFoldResult>(resultOffsets[index].size(),
                                      rewriter.getIndexAttr(1)));
        // Make a copy of index to avoid a capturing structured binding, which
        // is a C++20 extension.
        auto dstNumber = index;
        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
        });
      }
    }

    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
    // caller. The insertion point is moved to the terminator of the current
    // block first.
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, result] :
         llvm::enumerate(tiledConsumerOp->getResults())) {
      tiledResult.push_back(result);
      tiledOffset.emplace_back(resultOffsets[index]);
      tiledSizes.emplace_back(resultSizes[index]);
    }
    return success();
  };
  // 14. Add new inits to [nested] loops.
  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
                                       newYieldValuesFn))) {
    return rewriter.notifyMatchFailure(tiledConsumerOp,
                                       "unable to add new inits to nest loop");
  }

  // 15. Replace the results of the original consumer op with the trailing
  // results of the outer-most loop, one per new init.
  // NOTE(review): reads `nestedLoops.front()` after the call above, so this
  // presumably relies on `addInitOperandsToLoopNest` updating `nestedLoops`
  // to the new loops — confirm.

  for (auto &&[oldResult, newResult] : llvm::zip(
           consumerOp->getResults(),
           nestedLoops.front()->getResults().take_back(newInits.size()))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 16. Erase the cloned consumer op. This is the only op erased here; the
  // original consumer op only had the uses of its results replaced in step 15.
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Lowers `op` to a nest of `scf.for` loops, one loop per range of the
/// iteration domain of `op` (outer-most first), and asks `op` to generate its
/// scalar implementation using the induction variables of these loops.
/// Returns the created loops, or failure if `op` has results or if the scalar
/// implementation could not be generated.
///
/// Note that the insertion point of `rewriter` is not restored: after each
/// loop is created it is set to the terminator of that loop's body.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    // Materialize offset, size and stride as index values; they are passed as
    // the first, second and third bound operands of the new loop.
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    // Whatever is created next (the next loop, or the scalar implementation)
    // goes into the body of this loop.
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}