1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 24 #include "mlir/Interfaces/TilingInterface.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "tile-using-interface" 28 29 using namespace mlir; 30 31 scf::SCFTilingOptions & 32 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 33 assert(!tileSizeComputationFunction && "tile sizes already set"); 34 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 35 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 36 OpBuilder::InsertionGuard guard(b); 37 b.setInsertionPointToStart( 38 &op->getParentOfType<func::FuncOp>().getBody().front()); 39 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 40 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 41 return v; 42 })); 43 }; 44 return *this; 45 } 46 47 /// Helper method to adjust the interchange vector to match the iteration 48 /// domain. 
49 static SmallVector<int64_t> 50 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 51 size_t iterationDomainSize) { 52 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 53 if (filledVector.size() < iterationDomainSize) { 54 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 55 filledVector.append(range.begin(), range.end()); 56 } 57 if (filledVector.size() > iterationDomainSize) 58 filledVector.resize(iterationDomainSize); 59 return filledVector; 60 } 61 62 /// Helper method to apply permutation to a vector 63 template <typename T> 64 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 65 ArrayRef<int64_t> interchange) { 66 assert(interchange.size() == vector.size()); 67 return llvm::to_vector( 68 llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); 69 } 70 /// Helper method to apply to invert a permutation. 71 static SmallVector<int64_t> 72 invertPermutationVector(ArrayRef<int64_t> interchange) { 73 SmallVector<int64_t> inversion(interchange.size()); 74 for (const auto &pos : llvm::enumerate(interchange)) { 75 inversion[pos.value()] = pos.index(); 76 } 77 return inversion; 78 } 79 /// Method to check if an interchange vector is a permutation. 80 static bool isPermutation(ArrayRef<int64_t> interchange) { 81 llvm::SmallDenseSet<int64_t, 4> seenVals; 82 for (auto val : interchange) { 83 if (seenVals.count(val)) 84 return false; 85 seenVals.insert(val); 86 } 87 return seenVals.size() == interchange.size(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // tileUsingSCFForOp implementation. 92 //===----------------------------------------------------------------------===// 93 94 // Check if `stride` evenly divides the trip count `size - offset`. 
95 static bool tileDividesIterationDomain(Range loopRange) { 96 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 97 if (!offsetAsInt) 98 return false; 99 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 100 if (!sizeAsInt) 101 return false; 102 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 103 if (!strideAsInt) 104 return false; 105 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 106 } 107 108 /// Returns the bounded tile size given the current `iv`, `loopRange` and 109 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 110 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 111 Range loopRange, Value iv, 112 Value tileSize) { 113 Optional<int64_t> ts = getConstantIntValue(tileSize); 114 if (ts && ts.value() == 1) 115 return getAsOpFoldResult(tileSize); 116 117 if (tileDividesIterationDomain( 118 Range{loopRange.offset, loopRange.size, tileSize})) 119 return tileSize; 120 121 // The tile size to use (to avoid out of bounds access) is minimum of 122 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 123 // loop. 124 AffineExpr s0, s1, d0; 125 bindDims(b.getContext(), d0); 126 bindSymbols(b.getContext(), s0, s1); 127 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 128 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 129 return makeComposedFoldedAffineMin( 130 b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); 131 } 132 133 /// Generate an empty loop nest that represents the tiled loop nest shell. 134 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 135 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 136 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 137 /// the 138 /// tile processed within the inner most loop. 
static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder &builder, Location loc,
                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
                     SmallVector<OpFoldResult> &offsets,
                     SmallVector<OpFoldResult> &sizes) {
  assert(!loopRanges.empty() && "expected at least one loop range");
  assert(loopRanges.size() == tileSizeVals.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  // One offset/size entry per loop range, filled in as the nest is built.
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  for (auto loopRange : llvm::enumerate(loopRanges)) {
    Value offset =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
    Value size =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
    Value tileSize = tileSizeVals[loopRange.index()];
    // No loops if tile size is zero. Set offset and size to the loop
    // offset and size.
    if (matchPattern(tileSize, m_Zero())) {
      offsets[loopRange.index()] = offset;
      sizes[loopRange.index()] = size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, offset, size, tileSize, ValueRange{},
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          // Clamp the tile size so the last (possibly partial) tile stays in
          // bounds.
          sizes[loopRange.index()] = getBoundedTileSize(
              bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
          // NOTE(review): the terminator is created via the outer `builder`,
          // not `bodyBuilder`. This presumably relies on scf::ForOp::build
          // invoking the callback with the same builder positioned inside the
          // loop body — confirm before refactoring (`bodyBuilder`/`bodyLoc`
          // would be the more obvious spelling).
          builder.create<scf::YieldOp>(loc);
        });
    offsets[loopRange.index()] = loop.getInductionVar();
    loops.push_back(loop);
    // Build the next (inner) loop right before this loop's terminator.
    builder.setInsertionPoint(loop.getBody()->getTerminator());
  }
  return loops;
}

/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
/// construct the destructive update pattern that inserts the yielded
/// value into a destination tensor provided by `initValue` at offset
/// `tileOffsets` and size `tileSizes`. For example,
///
/// ```mlir
/// scf.for %iv0 = ... {
///   %0 = tiled_op
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
static FailureOr<SmallVector<Value>>
yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                 ValueRange yieldedValues,
                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                 MutableArrayRef<scf::ForOp> loops) {
  // Given the new iter_args of the innermost loop, create one
  // tensor.insert_slice per yielded value (all strides are 1) and yield the
  // inserted tensors.
  NewYieldValueFn yieldValueFn =
      [&](OpBuilder &b, Location loc,
          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
    SmallVector<Value> inserts;
    for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
      ArrayRef<OpFoldResult> tileOffsets =
          tileOffsetsList[yieldedValue.index()];
      ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
      SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
                                            b.getIndexAttr(1));
      Value insert = b.create<tensor::InsertSliceOp>(
          loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
          tileOffsets, tileSizes, tileStrides);
      inserts.push_back(insert);
    }
    return inserts;
  };

  // Rebuild the loop nest so every loop threads the destinations as
  // iter_args, then replace the old loops in `loops` in place.
  SmallVector<scf::ForOp> newLoops =
      replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
                                   /*replaceIterOperandsUsesInLoop =*/false);
  for (const auto &loop : llvm::enumerate(loops)) {
    rewriter.eraseOp(loop.value());
    loops[loop.index()] = newLoops[loop.index()];
  }
  // The newly appended results of the outermost loop are the replacement
  // values for the untiled op.
  return llvm::to_vector(llvm::map_range(
      loops.front().getResults().take_back(yieldedValues.size()),
      [](OpResult r) -> Value { return r; }));
}

/// If the tiled operation is destination passing style, update the
/// slice of the destination used (which refers to the untiled destination)
/// to use the corresponding region argument of the innermost loop.
///
/// ```mlir
/// %0 =
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %0
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
static void
updateDestinationOperandsForTiledOp(OpBuilder &builder,
                                    ValueRange tiledOpDestinationValues,
                                    ValueRange bbArgsList) {
  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
    // Only destinations produced by a tensor.extract_slice are redirected;
    // anything else is left untouched.
    auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      continue;
    sliceOp.setOperand(0, bbArgsList[destValue.index()]);
  }
}

/// Implementation of tiling transformation of `op` that implements the
/// `TilingInterface` using `scf.for` to iterate over the tiles.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                             const scf::SCFTilingOptions &options) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  // A tile size computation function is mandatory; bail out early otherwise.
  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // Get destination tensors.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
                                             destinationTensors)))
    return rewriter.notifyMatchFailure(op, "failed to get destinations");

  // 1. Get the range of the loops that are represented by the operation.
295 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 296 size_t numLoops = iterationDomain.size(); 297 if (numLoops == 0) { 298 return rewriter.notifyMatchFailure( 299 op, "unable to tile op with no iteration domain"); 300 } 301 302 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 303 // skips tiling a particular dimension. This convention is significantly 304 // simpler to handle instead of adjusting affine maps to account for missing 305 // dimensions. 306 SmallVector<Value> tileSizeVector = 307 options.tileSizeComputationFunction(rewriter, op); 308 if (tileSizeVector.size() < iterationDomain.size()) { 309 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 310 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 311 } 312 313 scf::SCFTilingResult tilingResult; 314 SmallVector<OpFoldResult> offsets, sizes; 315 { 316 // If there is an interchange specified, permute the iteration domain and 317 // the tile sizes. 318 SmallVector<int64_t> interchangeVector; 319 if (!options.interchangeVector.empty()) { 320 interchangeVector = fillInterchangeVector(options.interchangeVector, 321 iterationDomain.size()); 322 } 323 if (!interchangeVector.empty()) { 324 if (!isPermutation(interchangeVector)) { 325 return rewriter.notifyMatchFailure( 326 op, "invalid intechange vector, not a permutation of the entire " 327 "iteration space"); 328 } 329 330 iterationDomain = 331 applyPermutationToVector(iterationDomain, interchangeVector); 332 tileSizeVector = 333 applyPermutationToVector(tileSizeVector, interchangeVector); 334 } 335 336 // 3. Materialize an empty loop nest that iterates over the tiles. These 337 // loops for now do not return any values even if the original operation has 338 // results. 
339 tilingResult.loops = generateTileLoopNest( 340 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 341 342 if (!interchangeVector.empty()) { 343 auto inversePermutation = invertPermutationVector(interchangeVector); 344 offsets = applyPermutationToVector(offsets, inversePermutation); 345 sizes = applyPermutationToVector(sizes, inversePermutation); 346 } 347 } 348 349 LLVM_DEBUG({ 350 if (!tilingResult.loops.empty()) { 351 llvm::dbgs() << "LoopNest shell :\n"; 352 tilingResult.loops.front().dump(); 353 llvm::dbgs() << "\n"; 354 } 355 }); 356 357 // 4. Generate the tiled implementation within the inner most loop. 358 if (!tilingResult.loops.empty()) 359 rewriter.setInsertionPoint( 360 tilingResult.loops.back().getBody()->getTerminator()); 361 SmallVector<Operation *> tiledImplementation = 362 op.getTiledImplementation(rewriter, offsets, sizes); 363 tilingResult.tiledOps.append(tiledImplementation); 364 if (op->getNumResults() == 0) { 365 // nothing more to do. 366 return tilingResult; 367 } 368 369 // If loops are empty, the tiled op is used as the replacement for the untiled 370 // op. 371 if (tilingResult.loops.empty()) { 372 tilingResult.replacements = llvm::to_vector( 373 llvm::map_range(tiledImplementation[0]->getResults(), 374 [](OpResult result) -> Value { return result; })); 375 return tilingResult; 376 } 377 378 // 5. Yield all the results of the tiled operation. The surrounding loop 379 // nest is modified to insert a destructive update pattern to yield 380 // from the loop nest values to replace the untiled op with. 
  // Compute, for every result of the untiled op, the offset and size of the
  // tile within the (untiled) result tensor.
  int64_t numResults = op->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
      resultSizesList(numResults);
  for (const auto &result : llvm::enumerate(op->getResults())) {
    if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
                                        sizes,
                                        resultOffsetsList[result.index()],
                                        resultSizesList[result.index()]))) {
      return rewriter.notifyMatchFailure(
          op, "failed to get slice of result produced");
    }
  }

  // Thread the tiled results out of the loop nest through iter_args and
  // tensor.insert_slice (destructive update).
  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
      resultOffsetsList, resultSizesList, tilingResult.loops);
  if (failed(replacementOr))
    return rewriter.notifyMatchFailure(op, "failed to yield replacement");

  // For destination-style tiled ops, redirect the slice of the untiled
  // destination to the region iter_args of the innermost loop.
  if (auto dstOp =
          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
    auto innerMostLoop = tilingResult.loops.back();
    SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
    assert(destinationTensors.size() ==
               innerMostLoop.getRegionIterArgs().size() &&
           "unexpected number of outputs");
    updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
                                        innerMostLoop.getRegionIterArgs());
  }

  tilingResult.replacements = replacementOr.value();

  LLVM_DEBUG({
    if (!tilingResult.loops.empty()) {
      llvm::dbgs() << "After tiled implementation :\n";
      tilingResult.loops.front().dump();
      llvm::dbgs() << "\n";
    }
  });
  return tilingResult;
}

/// Tile a PartialReductionOpInterface op along its (single) reduction
/// dimension, producing a partial-reduction tensor inside the loop nest and a
/// final merge of the partial values after it.
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(PatternRewriter &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSize) {
  Location loc = op.getLoc();
  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  SmallVector<Value> tileSizeVector =
      getValueOrCreateConstantIndexOp(b, loc, tileSize);
  // Pad missing tile sizes with zero ("do not tile this dimension").
  if (tileSizeVector.size() < iterationDomain.size()) {
    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
  }
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");
  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  int64_t numReductionDims = llvm::count(
      tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
  if (numReductionDims != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  // Locate the single reduction dimension; guaranteed to be found by the
  // numReductionDims == 1 check above, so `reductionDim` is always assigned.
  int reductionDim;
  for (auto &[idx, iteratorType] :
       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
    if (iteratorType == utils::IteratorType::reduction) {
      reductionDim = idx;
      break;
    }
  }
  // 1. Create the initial tensor value (per the interface, a tensor holding
  // the identity value of the reduction).
  FailureOr<Operation *> identityTensor =
      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
                                                  reductionDim);
  if (failed(identityTensor))
    return b.notifyMatchFailure(op,
                                "cannot create a tensor of identity value.");
  // 2. Create the nested loops.
  SmallVector<OpFoldResult> offsets, sizes;
  SmallVector<scf::ForOp> loops = generateTileLoopNest(
      b, loc, iterationDomain, tileSizeVector, offsets, sizes);

  // 3. Generate the tiled implementation within the inner most loop.
  b.setInsertionPoint(loops.back().getBody()->getTerminator());
  // Compute the partial reduction for the current tile, accumulating into the
  // identity-initialized tensor.
  Operation *parallelOp =
      op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(),
                                offsets, sizes, reductionDim);

  // The whole partial-result tensor is yielded each iteration: offsets are all
  // zero and the sizes are the dynamic dims of the tiled op's result.
  SmallVector<OpFoldResult> resultSizesList;
  for (size_t i = 0; i < offsets.size(); i++)
    resultSizesList.push_back(
        b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
      b, identityTensor.value()->getResults(), parallelOp->getResults(),
      outOffsets, resultSizesList, loops);
  if (failed(replacementOr))
    return b.notifyMatchFailure(op, "failed to yield replacement");

  // Redirect the tiled op's destination slices to the innermost loop's region
  // iter_args.
  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
  auto innerMostLoop = loops.back();
  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
  assert(destinationTensors.size() ==
             innerMostLoop.getRegionIterArgs().size() &&
         "unexpected number of outputs");
  updateDestinationOperandsForTiledOp(b, destinationTensors,
                                      innerMostLoop.getRegionIterArgs());

  // 4. Apply the merge reduction to combine all the partial values.
  b.setInsertionPointAfter(*loops.begin());
  Operation *mergeOp =
      op.mergeReductions(b, loc, replacementOr.value(), reductionDim);
  b.replaceOp(op, mergeOp->getResults());

  SCFReductionTilingResult results;
  results.initialOp = identityTensor.value();
  results.loops = std::move(loops);
  results.parallelTiledOp = parallelOp;
  results.mergeOp = mergeOp;
  return results;
}
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outer most that is encountered. Traversing the iter_args
/// indicates that this is a destination operand of the consumer. If there was
/// no loop traversal needed, the second value of the returned tuple is empty.
static std::tuple<OpResult, Optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<scf::ForOp> loops) {
  Optional<OpOperand *> destinationIterArg;
  // Walk inner-to-outer: each time the source is a region iter_arg of the
  // current loop, hop to the corresponding init operand of that loop.
  auto loopIt = loops.rbegin();
  while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
    scf::ForOp loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
    source = &loop.getOpOperandForRegionIterArg(iterArg);
    loopIt++;
  }
  // If the walk consumed the entire nest, `source` is an init operand of the
  // outermost loop, i.e. a destination operand of the consumer.
  if (loopIt == loops.rend())
    destinationIterArg = source;
  return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}

/// Implementation of tile consumer and fuse producer greedily.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  // This transformation is only valid for ops that return values (i.e. not
  // valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  scf::SCFTileAndFuseResult tileAndFuseResult;
  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
  {
    FailureOr<scf::SCFTilingResult> tilingResult =
        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
    if (failed(tilingResult))
      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
    for (auto *tiledOp : tilingResult->tiledOps)
      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
    tileAndFuseResult.loops = std::move(tilingResult->loops);
    // Record, per consumer result, its replacement value and which result of
    // the tiled op yields it.
    for (const auto &result : llvm::enumerate(
             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
      tileAndFuseResult.replacements[std::get<0>(result.value())] =
          std::get<1>(result.value());
      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
          result.index())] = result.index();
    }
  }

  // If there are no loops generated, fusion is immaterial.
  if (tileAndFuseResult.loops.empty())
    return tileAndFuseResult;

  // 2. Typically, the operands of the tiled operation are slices of the
  //    operands of the untiled operation. These are expressed in IR using
  //    `tensor.extract_slice` operations with source being the operands of the
  //    untiled operation. Create a worklist of these `tensor.extract_slice`
  //    operations. If the producers of the source of the `tensor.extract_slice`
  //    can be tiled such that the tiled value is generated in-place, that
  //    effectively tiles + fuses the operations.
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // 2a. Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // 2b. Get the producer of the source (potentially walking through
    // `iter_args` of nested `scf.for`)
    auto [fusableProducer, destinationIterArg] =
        getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
                                          tileAndFuseResult.loops);
    if (!fusableProducer)
      continue;

    // 2c. Generate the tiled implementation of the producer of the source
    rewriter.setInsertionPoint(candidateSliceOp);
    FailureOr<Value> fusedProducerValue =
        tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
                                                     fusableProducer);
    // Fusion of this producer is best-effort; skip it if it cannot be tiled.
    if (failed(fusedProducerValue))
      continue;
    rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());

    // 2d. The operands of the fused producer might themselves be slices of
    //     values produced by operations that implement the `TilingInterface`.
    //     Add these operations to the worklist.
    Operation *fusedProducer = fusedProducerValue->getDefiningOp();
    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
    addCandidateSlices(fusedProducer, candidates);

    // 2e. If the slice is for a destination operand, for example,
    //
    // ```mlir
    // %0 = linalg.init
    // %1 = linalg.fill .. outs(%0 : )
    // %2 = scf.for .. iter_args(%arg0 = %1) {
    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
    //     %4 = tensor.extract_slice %arg1 [..]
    //     .. = linalg.matmul .. outs(%4 : )
    //   }
    // }
    // ```
    //
    // the IR is currently
    //
    // ```
    // %0 = linalg.init
    // %1 = linalg.fill
    // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
    //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
    //     %5 = linalg.fill .. outs(%4 : )
    //     .. = linalg.matmul .. outs(%5 : )
    //   }
    // }
    // ```
    //
    // The untiled `linalg.fill` is still used as the `init_value` since it
    // was originally a destination operand of the untiled `linalg.matmul`.
    // When fusing an operand that is a destination operand.
    //   - Update the iter_arg of the outer most loop to use the destination
    //     of the untiled producer.
    //   - Update the destination of the slice of the tiled producer generated
    //     to use the same basic block argument as the slice that was used to
    //     generate inplace the tiled implementation of the producer.
    // With this the IR will be.
    //
    // ```
    // %0 = linalg.init
    // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
    //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
    //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
    //     %4 = linalg.fill .. outs(%3 : )
    //     .. = linalg.matmul .. outs(%4 : )
    //   }
    // }
    // ```
    // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
    // Update to use that when it does become available.
    // If the fused slice traced back to an init operand of the outermost
    // loop, rewire both ends of the destination chain (see the worked
    // example in the preceding comment).
    scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
    Optional<unsigned> iterArgNumber;
    if (destinationIterArg) {
      iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
          *destinationIterArg.value());
    }
    if (iterArgNumber) {
      int64_t resultNumber = fusableProducer.getResultNumber();
      // Point the outermost loop's iter_arg at the untiled producer's own
      // destination instead of the producer's (now fused-away) result.
      if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
              fusableProducer.getOwner())) {
        outerMostLoop.setIterArg(
            iterArgNumber.value(),
            dstOp.getTiedOpOperand(fusableProducer)->get());
      }
      // Point the tiled producer's destination slice at the matching region
      // iter_arg of the innermost loop.
      if (auto dstOp = fusedProducerValue.value()
                           .getDefiningOp<DestinationStyleOpInterface>()) {
        scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
        updateDestinationOperandsForTiledOp(
            rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
            innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
      }
    }
  }
  return tileAndFuseResult;
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Lower `op` to an explicit `scf.for` loop nest over its whole iteration
/// domain and emit the op's scalar implementation in the innermost loop.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  // Ops with results would require the destructive-update machinery; only
  // memref-semantics (zero-result) ops are handled here.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  // One scf.for per dimension of the iteration domain, nested in order.
  for (auto loopRange : domain) {
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    // Nest the next loop (and finally the scalar body) inside this one.
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  // Emit the scalar body at the current (innermost) insertion point using the
  // collected induction variables as indices.
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}