1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 24 #include "mlir/Interfaces/TilingInterface.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "tile-using-interface" 28 29 using namespace mlir; 30 31 scf::SCFTilingOptions & 32 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 33 assert(!tileSizeComputationFunction && "tile sizes already set"); 34 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 35 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 36 OpBuilder::InsertionGuard guard(b); 37 b.setInsertionPointToStart( 38 &op->getParentOfType<func::FuncOp>().getBody().front()); 39 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 40 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 41 return v; 42 })); 43 }; 44 return *this; 45 } 46 47 /// Helper method to adjust the interchange vector to match the iteration 48 /// domain. 49 static SmallVector<int64_t> 50 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 51 size_t iterationDomainSize) { 52 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 53 if (filledVector.size() < iterationDomainSize) { 54 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 55 filledVector.append(range.begin(), range.end()); 56 } 57 if (filledVector.size() > iterationDomainSize) 58 filledVector.resize(iterationDomainSize); 59 return filledVector; 60 } 61 62 /// Helper method to apply permutation to a vector 63 template <typename T> 64 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 65 ArrayRef<int64_t> interchange) { 66 assert(interchange.size() == vector.size()); 67 return llvm::to_vector( 68 llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); 69 } 70 /// Helper method to apply to invert a permutation. 71 static SmallVector<int64_t> 72 invertPermutationVector(ArrayRef<int64_t> interchange) { 73 SmallVector<int64_t> inversion(interchange.size()); 74 for (const auto &pos : llvm::enumerate(interchange)) { 75 inversion[pos.value()] = pos.index(); 76 } 77 return inversion; 78 } 79 /// Method to check if an interchange vector is a permutation. 80 static bool isPermutation(ArrayRef<int64_t> interchange) { 81 llvm::SmallDenseSet<int64_t, 4> seenVals; 82 for (auto val : interchange) { 83 if (seenVals.count(val)) 84 return false; 85 seenVals.insert(val); 86 } 87 return seenVals.size() == interchange.size(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // tileUsingSCFForOp implementation. 92 //===----------------------------------------------------------------------===// 93 94 // Check if `stride` evenly divides the trip count `size - offset`. 95 static bool tileDividesIterationDomain(Range loopRange) { 96 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 97 if (!offsetAsInt) 98 return false; 99 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 100 if (!sizeAsInt) 101 return false; 102 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 103 if (!strideAsInt) 104 return false; 105 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 106 } 107 108 /// Returns the bounded tile size given the current `iv`, `loopRange` and 109 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 110 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 111 Range loopRange, Value iv, 112 Value tileSize) { 113 Optional<int64_t> ts = getConstantIntValue(tileSize); 114 if (ts && ts.value() == 1) 115 return getAsOpFoldResult(tileSize); 116 117 if (tileDividesIterationDomain( 118 Range{loopRange.offset, loopRange.size, tileSize})) 119 return tileSize; 120 121 // The tile size to use (to avoid out of bounds access) is minimum of 122 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 123 // loop. 124 AffineExpr s0, s1, d0; 125 bindDims(b.getContext(), d0); 126 bindSymbols(b.getContext(), s0, s1); 127 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 128 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 129 return makeComposedFoldedAffineMin( 130 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 131 } 132 133 /// Generate an empty loop nest that represents the tiled loop nest shell. 134 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 135 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 136 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 137 /// the 138 /// tile processed within the inner most loop. 139 static SmallVector<scf::ForOp> 140 generateTileLoopNest(OpBuilder &builder, Location loc, 141 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 142 SmallVector<OpFoldResult> &offsets, 143 SmallVector<OpFoldResult> &sizes) { 144 assert(!loopRanges.empty() && "expected at least one loop range"); 145 assert(loopRanges.size() == tileSizeVals.size() && 146 "expected as many tile sizes as loop ranges"); 147 OpBuilder::InsertionGuard guard(builder); 148 SmallVector<scf::ForOp> loops; 149 offsets.resize(loopRanges.size()); 150 sizes.resize(loopRanges.size()); 151 152 for (auto loopRange : llvm::enumerate(loopRanges)) { 153 Value offset = 154 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 155 Value size = 156 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 157 Value tileSize = tileSizeVals[loopRange.index()]; 158 // No loops if tile size is zero. Set offset and size to the loop 159 // offset and size. 160 if (matchPattern(tileSize, m_Zero())) { 161 offsets[loopRange.index()] = offset; 162 sizes[loopRange.index()] = size; 163 continue; 164 } 165 166 auto loop = builder.create<scf::ForOp>( 167 loc, offset, size, tileSize, ValueRange{}, 168 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 169 ValueRange /*iterArgs*/) { 170 sizes[loopRange.index()] = getBoundedTileSize( 171 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 172 builder.create<scf::YieldOp>(loc); 173 }); 174 offsets[loopRange.index()] = loop.getInductionVar(); 175 loops.push_back(loop); 176 builder.setInsertionPoint(loop.getBody()->getTerminator()); 177 } 178 return loops; 179 } 180 181 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 182 /// construct the destructive update pattern that inserts the yielded 183 /// value into a destination tensor provided by `initValue` at offset 184 /// `tileOffsets` and size `tileSizes`. For example, 185 /// 186 /// ```mlir 187 /// scf.for %iv0 = ... { 188 /// %0 = tiled_op 189 /// } 190 /// ``` 191 /// 192 /// is transformed to 193 /// 194 /// ```mlir 195 /// scf.for %iv0 = ... iter_args(%arg = %0) { 196 /// %1 = tensor.extract_slice %arg 197 /// %2 = tiled_op 198 /// %3 = tensor.insert_slice %2 into %arg 199 /// scf.yield %3 200 /// } 201 /// ``` 202 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 203 static FailureOr<SmallVector<Value>> 204 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 205 ValueRange yieldedValues, 206 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 207 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 208 MutableArrayRef<scf::ForOp> loops) { 209 NewYieldValueFn yieldValueFn = 210 [&](OpBuilder &b, Location loc, 211 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 212 SmallVector<Value> inserts; 213 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 214 ArrayRef<OpFoldResult> tileOffsets = 215 tileOffsetsList[yieldedValue.index()]; 216 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 217 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 218 b.getIndexAttr(1)); 219 Value insert = b.create<tensor::InsertSliceOp>( 220 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 221 tileOffsets, tileSizes, tileStrides); 222 inserts.push_back(insert); 223 } 224 return inserts; 225 }; 226 227 SmallVector<scf::ForOp> newLoops = 228 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 229 /*replaceIterOperandsUsesInLoop =*/false); 230 for (const auto &loop : llvm::enumerate(loops)) { 231 rewriter.eraseOp(loop.value()); 232 loops[loop.index()] = newLoops[loop.index()]; 233 } 234 return llvm::to_vector(llvm::map_range( 235 loops.front().getResults().take_back(yieldedValues.size()), 236 [](OpResult r) -> Value { return r; })); 237 } 238 239 /// If the tiled operation is destination passing style, update the 240 /// slice of the destination used (which refers to the untiled destination) 241 /// to use the corresponding region argument of the innermost loop. 242 /// 243 /// ```mlir 244 /// %0 = 245 /// scf.for %iv0 = ... iter_args(%arg = %0) { 246 /// %1 = tensor.extract_slice %0 247 /// %2 = tiled_op 248 /// %3 = tensor.insert_slice %2 into %arg 249 /// scf.yield %3 250 /// } 251 /// ``` 252 /// 253 /// is transformed to 254 /// 255 /// ```mlir 256 /// scf.for %iv0 = ... iter_args(%arg = %0) { 257 /// %1 = tensor.extract_slice %arg 258 /// %2 = tiled_op 259 /// %3 = tensor.insert_slice %2 into %arg 260 /// scf.yield %3 261 /// } 262 /// ``` 263 static void 264 updateDestinationOperandsForTiledOp(OpBuilder &builder, 265 ValueRange tiledOpDestinationValues, 266 ValueRange bbArgsList) { 267 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 268 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 269 if (!sliceOp) 270 continue; 271 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 272 } 273 } 274 275 /// Implementation of tiling transformation of `op` that implements the 276 /// `TilingInterface` using `scf.for` to iterate over the tiles. 277 FailureOr<scf::SCFTilingResult> 278 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 279 const scf::SCFTilingOptions &options) { 280 OpBuilder::InsertionGuard guard(rewriter); 281 rewriter.setInsertionPointAfter(op); 282 283 if (!options.tileSizeComputationFunction) { 284 return rewriter.notifyMatchFailure( 285 op, "missing tile size computation function"); 286 } 287 288 // Get destination tensors. 289 SmallVector<Value> destinationTensors; 290 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 291 destinationTensors))) 292 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 293 294 // 1. Get the range of the loops that are represented by the operation. 295 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 296 size_t numLoops = iterationDomain.size(); 297 if (numLoops == 0) { 298 return rewriter.notifyMatchFailure( 299 op, "unable to tile op with no iteration domain"); 300 } 301 302 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 303 // skips tiling a particular dimension. This convention is significantly 304 // simpler to handle instead of adjusting affine maps to account for missing 305 // dimensions. 306 SmallVector<Value> tileSizeVector = 307 options.tileSizeComputationFunction(rewriter, op); 308 if (tileSizeVector.size() < iterationDomain.size()) { 309 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 310 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 311 } 312 313 scf::SCFTilingResult tilingResult; 314 SmallVector<OpFoldResult> offsets, sizes; 315 { 316 // If there is an interchange specified, permute the iteration domain and 317 // the tile sizes. 318 SmallVector<int64_t> interchangeVector; 319 if (!options.interchangeVector.empty()) { 320 interchangeVector = fillInterchangeVector(options.interchangeVector, 321 iterationDomain.size()); 322 } 323 if (!interchangeVector.empty()) { 324 if (!isPermutation(interchangeVector)) { 325 return rewriter.notifyMatchFailure( 326 op, "invalid intechange vector, not a permutation of the entire " 327 "iteration space"); 328 } 329 330 iterationDomain = 331 applyPermutationToVector(iterationDomain, interchangeVector); 332 tileSizeVector = 333 applyPermutationToVector(tileSizeVector, interchangeVector); 334 } 335 336 // 3. Materialize an empty loop nest that iterates over the tiles. These 337 // loops for now do not return any values even if the original operation has 338 // results. 339 tilingResult.loops = generateTileLoopNest( 340 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 341 342 if (!interchangeVector.empty()) { 343 auto inversePermutation = invertPermutationVector(interchangeVector); 344 offsets = applyPermutationToVector(offsets, inversePermutation); 345 sizes = applyPermutationToVector(sizes, inversePermutation); 346 } 347 } 348 349 LLVM_DEBUG({ 350 if (!tilingResult.loops.empty()) { 351 llvm::dbgs() << "LoopNest shell :\n"; 352 tilingResult.loops.front().dump(); 353 llvm::dbgs() << "\n"; 354 } 355 }); 356 357 // 4. Generate the tiled implementation within the inner most loop. 358 if (!tilingResult.loops.empty()) 359 rewriter.setInsertionPoint( 360 tilingResult.loops.back().getBody()->getTerminator()); 361 SmallVector<Operation *> tiledImplementation = 362 op.getTiledImplementation(rewriter, offsets, sizes); 363 if (tiledImplementation.size() != 1) { 364 return rewriter.notifyMatchFailure( 365 op, "expected tiled implementation to return a single op"); 366 } 367 tilingResult.tiledOp = tiledImplementation[0]; 368 if (op->getNumResults() == 0) { 369 // nothing more to do. 370 return tilingResult; 371 } 372 373 // If loops are empty, the tiled op is used as the replacement for the untiled 374 // op. 375 if (tilingResult.loops.empty()) { 376 tilingResult.replacements = llvm::to_vector( 377 llvm::map_range(tiledImplementation[0]->getResults(), 378 [](OpResult result) -> Value { return result; })); 379 return tilingResult; 380 } 381 382 // 5. Yield all the results of the tiled operation. The surrounding loop 383 // nest is modified to insert a destructive update pattern to yield 384 // from the loop nest values to replace the untiled op with. 385 int64_t numResults = op->getNumResults(); 386 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 387 resultSizesList(numResults); 388 for (const auto &result : llvm::enumerate(op->getResults())) { 389 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 390 sizes, 391 resultOffsetsList[result.index()], 392 resultSizesList[result.index()]))) { 393 return rewriter.notifyMatchFailure( 394 op, "failed to get slice of result produced"); 395 } 396 } 397 398 FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues( 399 rewriter, destinationTensors, tilingResult.tiledOp->getResults(), 400 resultOffsetsList, resultSizesList, tilingResult.loops); 401 if (failed(replacementOr)) 402 return rewriter.notifyMatchFailure(op, "failed to yield replacement"); 403 404 if (auto dstOp = 405 dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) { 406 auto innerMostLoop = tilingResult.loops.back(); 407 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 408 assert(destinationTensors.size() == 409 innerMostLoop.getRegionIterArgs().size() && 410 "unexpected number of outputs"); 411 updateDestinationOperandsForTiledOp(rewriter, destinationTensors, 412 innerMostLoop.getRegionIterArgs()); 413 } 414 415 tilingResult.replacements = replacementOr.value(); 416 417 LLVM_DEBUG({ 418 if (!tilingResult.loops.empty()) { 419 llvm::dbgs() << "After tiled implementation :\n"; 420 tilingResult.loops.front().dump(); 421 llvm::dbgs() << "\n"; 422 } 423 }); 424 return tilingResult; 425 } 426 427 FailureOr<scf::SCFReductionTilingResult> 428 mlir::scf::tileReductionUsingScf(PatternRewriter &b, 429 PartialReductionOpInterface op, 430 ArrayRef<OpFoldResult> tileSize) { 431 Location loc = op.getLoc(); 432 // Ops implementing PartialReductionOpInterface are expected to implement 433 // TilingInterface. 434 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); 435 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); 436 SmallVector<Value> tileSizeVector = 437 getValueOrCreateConstantIndexOp(b, loc, tileSize); 438 if (tileSizeVector.size() < iterationDomain.size()) { 439 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 440 tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); 441 } 442 if (op->getNumResults() != 1) 443 return b.notifyMatchFailure( 444 op, "don't support ops with multiple results for now"); 445 SmallVector<utils::IteratorType> iterators = 446 tilingInterfaceOp.getLoopIteratorTypes(); 447 int64_t numReductionDims = llvm::count( 448 tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction); 449 if (numReductionDims != 1) 450 return b.notifyMatchFailure( 451 op, "only support ops with one reduction dimension."); 452 int reductionDim; 453 for (auto &[idx, iteratorType] : 454 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { 455 if (iteratorType == utils::IteratorType::reduction) { 456 reductionDim = idx; 457 break; 458 } 459 } 460 // 1. create the inital tensor value. 461 FailureOr<Operation *> identityTensor = 462 op.generateInitialTensorForPartialReduction(b, loc, tileSize, 463 reductionDim); 464 if (failed(identityTensor)) 465 return b.notifyMatchFailure(op, 466 "cannot create a tensor of identity value."); 467 // 2. Create the nested loops. 468 SmallVector<OpFoldResult> offsets, sizes; 469 SmallVector<scf::ForOp> loops = generateTileLoopNest( 470 b, loc, iterationDomain, tileSizeVector, offsets, sizes); 471 472 // 3. Generate the tiled implementation within the inner most loop. 473 b.setInsertionPoint(loops.back().getBody()->getTerminator()); 474 Operation *parallelOp = 475 op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(), 476 offsets, sizes, reductionDim); 477 478 SmallVector<OpFoldResult> resultSizesList; 479 for (size_t i = 0; i < offsets.size(); i++) 480 resultSizesList.push_back( 481 b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i)); 482 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); 483 FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues( 484 b, identityTensor.value()->getResults(), parallelOp->getResults(), 485 outOffsets, resultSizesList, loops); 486 if (failed(replacementOr)) 487 return b.notifyMatchFailure(op, "failed to yield replacement"); 488 489 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); 490 auto innerMostLoop = loops.back(); 491 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 492 assert(destinationTensors.size() == 493 innerMostLoop.getRegionIterArgs().size() && 494 "unexpected number of outputs"); 495 updateDestinationOperandsForTiledOp(b, destinationTensors, 496 innerMostLoop.getRegionIterArgs()); 497 498 // 4. Apply the merge reduction to combine all the partial values. 499 b.setInsertionPointAfter(*loops.begin()); 500 Operation *mergeOp = 501 op.mergeReductions(b, loc, replacementOr.value(), reductionDim); 502 b.replaceOp(op, mergeOp->getResults()); 503 504 SCFReductionTilingResult results; 505 results.initialOp = identityTensor.value(); 506 results.loops = std::move(loops); 507 results.parallelTiledOp = parallelOp; 508 results.mergeOp = mergeOp; 509 return results; 510 } 511 //===----------------------------------------------------------------------===// 512 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 513 //===----------------------------------------------------------------------===// 514 515 /// Return the untiled producer whose slice is used in a tiled consumer. The 516 /// method traverses the tile loop nest (`loops`) if needed, and returns the 517 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 518 /// indicates that this is a destination operand of the consumer. If there was 519 /// no loop traversal needed, the second value of the returned tuple is empty. 520 static std::tuple<OpResult, Optional<OpOperand *>> 521 getUntiledProducerFromSliceSource(OpOperand *source, 522 ArrayRef<scf::ForOp> loops) { 523 Optional<OpOperand *> destinationIterArg; 524 auto loopIt = loops.rbegin(); 525 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 526 scf::ForOp loop = *loopIt; 527 if (iterArg.getOwner()->getParentOp() != loop) 528 break; 529 source = &loop.getOpOperandForRegionIterArg(iterArg); 530 loopIt++; 531 } 532 if (loopIt == loops.rend()) 533 destinationIterArg = source; 534 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 535 } 536 537 /// Implementation of tile consumer and fuse producer greedily. 538 FailureOr<scf::SCFTileAndFuseResult> 539 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 540 RewriterBase &rewriter, TilingInterface consumer, 541 const scf::SCFTileAndFuseOptions &options) { 542 // This transformation is only valid for ops that return values (i.e. not 543 // valid to use with operations that have memref operands). 544 if (!consumer->getNumResults()) { 545 return rewriter.notifyMatchFailure( 546 consumer, "invalid pattern for op with no results"); 547 } 548 549 // 1. First tile the consumer. 550 scf::SCFTileAndFuseResult tileAndFuseResult; 551 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 552 { 553 FailureOr<scf::SCFTilingResult> tilingResult = 554 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 555 if (failed(tilingResult)) 556 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 557 tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); 558 tileAndFuseResult.loops = std::move(tilingResult->loops); 559 for (const auto &result : llvm::enumerate( 560 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 561 tileAndFuseResult.replacements[std::get<0>(result.value())] = 562 std::get<1>(result.value()); 563 yieldedValueToResultNumber[tilingResult->tiledOp->getResult( 564 result.index())] = result.index(); 565 } 566 } 567 568 // If there are no loops generated, fusion is immaterial. 569 if (tileAndFuseResult.loops.empty()) 570 return tileAndFuseResult; 571 572 // 2. Typically, the operands of the tiled operation are slices of the 573 // operands of the untiled operation. These are expressed in IR using 574 // `tensor.extract_slice` operations with source being the operands of the 575 // untiled operation. Create a worklist of these `tensor.extract_slice` 576 // operations. If the producers of the source of the `tensor.extract_slice` 577 // can be tiled such that the tiled value is generated in-place, that 578 // effectively tiles + fuses the operations. 579 auto addCandidateSlices = [](Operation *fusedOp, 580 std::deque<tensor::ExtractSliceOp> &candidates) { 581 for (Value operand : fusedOp->getOperands()) 582 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 583 candidates.push_back(sliceOp); 584 }; 585 586 std::deque<tensor::ExtractSliceOp> candidates; 587 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 588 OpBuilder::InsertionGuard g(rewriter); 589 while (!candidates.empty()) { 590 // 2a. Traverse the slices in BFS fashion. 591 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 592 candidates.pop_front(); 593 594 // 2b. Get the producer of the source (potentially walking through 595 // `iter_args` of nested `scf.for`) 596 auto [fusableProducer, destinationIterArg] = 597 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 598 tileAndFuseResult.loops); 599 if (!fusableProducer) 600 continue; 601 602 // 2c. Generate the tiled implementation of the producer of the source 603 rewriter.setInsertionPoint(candidateSliceOp); 604 FailureOr<Value> fusedProducerValue = 605 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 606 fusableProducer); 607 if (failed(fusedProducerValue)) 608 continue; 609 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 610 611 // 2d. The operands of the fused producer might themselved be slices of 612 // values produced by operations that implement the `TilingInterface`. 613 // Add these operations to the worklist. 614 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 615 tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); 616 addCandidateSlices(fusedProducer, candidates); 617 618 // 2e. If the slice is for a destination operand, for example, 619 // 620 // ```mlir 621 // %0 = linalg.init 622 // %1 = linalg.fill .. outs(%0 : ) 623 // %2 = scf.for .. iter_args(%arg0 = %1) { 624 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 625 // %4 = tensor.extract_slice %arg1 [..] 626 // .. = linalg.matmul .. outs(%4 : ) 627 // } 628 // } 629 // ``` 630 // 631 // the IR is currently 632 // 633 // ``` 634 // %0 = linalg.init 635 // %1 = linalg.fill 636 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 637 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 638 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 639 // %5 = linalg.fill .. outs(%4 : ) 640 // .. = linalg.matmul .. outs(%5 : ) 641 // } 642 // } 643 // ``` 644 // 645 // The untiled `linalg.fill` is still used as the `init_value` since it 646 // was originally a destination operand of the untiled `linalg.matmul`. 647 // When fusing an operand that is a destination operand. 648 // - Update the iter_arg of the outer most loop to use the destination 649 // of the untiled producer. 650 // - Update the destination of the slice of the tiled producer generated 651 // to use the same basic block argument as the slice that was used to 652 // generate inplace the tiled implementation of the producer. 653 // With this the IR will be. 654 // 655 // ``` 656 // %0 = linalg.init 657 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 658 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 659 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 660 // %4 = linalg.fill .. outs(%3 : ) 661 // .. = linalg.matmul .. outs(%4 : ) 662 // } 663 // } 664 // ``` 665 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 666 // Update to use that when it does become available. 667 scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); 668 Optional<unsigned> iterArgNumber; 669 if (destinationIterArg) { 670 iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand( 671 *destinationIterArg.value()); 672 } 673 if (iterArgNumber) { 674 int64_t resultNumber = fusableProducer.getResultNumber(); 675 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>( 676 fusableProducer.getOwner())) { 677 outerMostLoop.setIterArg( 678 iterArgNumber.value(), 679 dstOp.getTiedOpOperand(fusableProducer)->get()); 680 } 681 if (auto dstOp = fusedProducerValue.value() 682 .getDefiningOp<DestinationStyleOpInterface>()) { 683 scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); 684 updateDestinationOperandsForTiledOp( 685 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 686 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 687 } 688 } 689 } 690 return tileAndFuseResult; 691 } 692 693 //===----------------------------------------------------------------------===// 694 // lowerToLoopsUsingSCFForOp implementation. 695 //===----------------------------------------------------------------------===// 696 697 FailureOr<SmallVector<scf::ForOp>> 698 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 699 TilingInterface op) { 700 // TODO: Handle cases where the op has results if needed. 701 if (op->getNumResults() > 0) { 702 return rewriter.notifyMatchFailure( 703 op, "unable to lower to loops operations with return values"); 704 } 705 706 SmallVector<Range> domain = op.getIterationDomain(rewriter); 707 SmallVector<Value> ivs; 708 SmallVector<scf::ForOp> loops; 709 Location loc = op.getLoc(); 710 for (auto loopRange : domain) { 711 Value offsetVal = 712 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 713 Value sizeVal = 714 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 715 Value strideVal = 716 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 717 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 718 strideVal, ValueRange{}); 719 loops.push_back(loop); 720 ivs.push_back(loop.getInductionVar()); 721 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 722 } 723 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 724 return failure(); 725 } 726 return loops; 727 } 728