1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/TilingInterface.h" 24 #include "llvm/Support/Debug.h" 25 26 #define DEBUG_TYPE "tile-using-interface" 27 28 using namespace mlir; 29 30 scf::SCFTilingOptions & 31 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 32 assert(!tileSizeComputationFunction && "tile sizes already set"); 33 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 34 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 35 OpBuilder::InsertionGuard guard(b); 36 b.setInsertionPointToStart( 37 &op->getParentOfType<func::FuncOp>().getBody().front()); 38 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 39 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 40 return v; 41 })); 42 }; 43 return *this; 44 } 45 46 /// Helper method to adjust the interchange vector to match the iteration 47 /// domain. 48 static SmallVector<unsigned> 49 fillInterchangeVector(ArrayRef<unsigned> interchangeVector, 50 size_t iterationDomainSize) { 51 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector); 52 if (filledVector.size() < iterationDomainSize) { 53 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize); 54 filledVector.append(range.begin(), range.end()); 55 } 56 if (filledVector.size() > iterationDomainSize) 57 filledVector.resize(iterationDomainSize); 58 return filledVector; 59 } 60 61 /// Helper method to apply permutation to a vector 62 template <typename T> 63 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 64 ArrayRef<unsigned> interchange) { 65 assert(interchange.size() == vector.size()); 66 return llvm::to_vector( 67 llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); 68 } 69 /// Helper method to apply to invert a permutation. 70 static SmallVector<unsigned> 71 invertPermutationVector(ArrayRef<unsigned> interchange) { 72 SmallVector<unsigned> inversion(interchange.size()); 73 for (const auto &pos : llvm::enumerate(interchange)) { 74 inversion[pos.value()] = pos.index(); 75 } 76 return inversion; 77 } 78 /// Method to check if an interchange vector is a permutation. 79 static bool isPermutation(ArrayRef<unsigned> interchange) { 80 llvm::SmallDenseSet<unsigned, 4> seenVals; 81 for (auto val : interchange) { 82 if (seenVals.count(val)) 83 return false; 84 seenVals.insert(val); 85 } 86 return seenVals.size() == interchange.size(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // tileUsingSCFForOp implementation. 91 //===----------------------------------------------------------------------===// 92 93 // Check if `stride` evenly divides the trip count `size - offset`. 94 static bool tileDividesIterationDomain(Range loopRange) { 95 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 96 if (!offsetAsInt) 97 return false; 98 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 99 if (!sizeAsInt) 100 return false; 101 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 102 if (!strideAsInt) 103 return false; 104 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 105 } 106 107 /// Generate an empty loop nest that represents the tiled loop nest shell. 108 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 109 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 110 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 111 /// the 112 /// tile processed within the inner most loop. 113 static SmallVector<scf::ForOp> 114 generateTileLoopNest(OpBuilder &builder, Location loc, 115 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 116 SmallVector<OpFoldResult> &offsets, 117 SmallVector<OpFoldResult> &sizes) { 118 assert(!loopRanges.empty() && "expected at least one loop range"); 119 assert(loopRanges.size() == tileSizeVals.size() && 120 "expected as many tile sizes as loop ranges"); 121 OpBuilder::InsertionGuard guard(builder); 122 SmallVector<scf::ForOp> loops; 123 offsets.resize(loopRanges.size()); 124 sizes.resize(loopRanges.size()); 125 126 // The tile size to use (to avoid out of bounds access) is minimum of 127 // `tileSize` and `ub - iv`, where `iv` is the induction variable 128 // of the tiled loop. 129 AffineExpr s0, s1, d0; 130 bindDims(builder.getContext(), d0); 131 bindSymbols(builder.getContext(), s0, s1); 132 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 133 134 for (auto loopRange : llvm::enumerate(loopRanges)) { 135 Value offset = 136 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 137 Value size = 138 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 139 // No loops if tile size is zero. Set offset and size to the loop 140 // offset and size. 141 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 142 offsets[loopRange.index()] = offset; 143 sizes[loopRange.index()] = size; 144 continue; 145 } 146 147 auto loop = builder.create<scf::ForOp>( 148 loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, 149 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 150 ValueRange /*iterArgs*/) { 151 bool canAvoidMap = tileDividesIterationDomain( 152 Range{loopRange.value().offset, loopRange.value().size, 153 tileSizeVals[loopRange.index()]}); 154 Value boundedTileSize = 155 (canAvoidMap) 156 ? tileSizeVals[loopRange.index()] 157 : builder.create<AffineMinOp>( 158 bodyLoc, minMap, 159 ValueRange{iv, tileSizeVals[loopRange.index()], size}); 160 sizes[loopRange.index()] = boundedTileSize; 161 builder.create<scf::YieldOp>(loc); 162 }); 163 offsets[loopRange.index()] = loop.getInductionVar(); 164 loops.push_back(loop); 165 builder.setInsertionPoint(loop.getBody()->getTerminator()); 166 } 167 return loops; 168 } 169 170 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 171 /// construct the destructive update pattern that inserts the yielded 172 /// value into a destination tensor provided by `initValue` at offset 173 /// `tileOffsets` and size `tileSizes`. For example, 174 /// 175 /// ```mlir 176 /// scf.for %iv0 = ... { 177 /// %0 = tiled_op 178 /// } 179 /// ``` 180 /// 181 /// is transformed to 182 /// 183 /// ```mlir 184 /// scf.for %iv0 = ... iter_args(%arg = %0) { 185 /// %1 = tensor.extract_slice %arg 186 /// %2 = tiled_op 187 /// %3 = tensor.insert_slice %2 into %arg 188 /// scf.yield %3 189 /// } 190 /// ``` 191 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 192 static FailureOr<SmallVector<Value>> 193 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 194 ValueRange yieldedValues, 195 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 196 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 197 MutableArrayRef<scf::ForOp> loops) { 198 NewYieldValueFn yieldValueFn = 199 [&](OpBuilder &b, Location loc, 200 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 201 SmallVector<Value> inserts; 202 for (auto yieldedValue : llvm::enumerate(yieldedValues)) { 203 ArrayRef<OpFoldResult> tileOffsets = 204 tileOffsetsList[yieldedValue.index()]; 205 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 206 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 207 b.getIndexAttr(1)); 208 Value insert = b.create<tensor::InsertSliceOp>( 209 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 210 tileOffsets, tileSizes, tileStrides); 211 inserts.push_back(insert); 212 } 213 return inserts; 214 }; 215 216 SmallVector<scf::ForOp> newLoops = 217 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 218 /*replaceIterOperandsUsesInLoop =*/false); 219 for (const auto &loop : llvm::enumerate(loops)) { 220 rewriter.eraseOp(loop.value()); 221 loops[loop.index()] = newLoops[loop.index()]; 222 } 223 return llvm::to_vector(llvm::map_range( 224 loops.front().getResults().take_back(yieldedValues.size()), 225 [](OpResult r) -> Value { return r; })); 226 } 227 228 /// If the tiled operation is destination passing style, update the 229 /// slice of the destination used (which refers to the untiled destination) 230 /// to use the corresponding region argument of the innermost loop. 231 /// 232 /// ```mlir 233 /// %0 = 234 /// scf.for %iv0 = ... iter_args(%arg = %0) { 235 /// %1 = tensor.extract_slice %0 236 /// %2 = tiled_op 237 /// %3 = tensor.insert_slice %2 into %arg 238 /// scf.yield %3 239 /// } 240 /// ``` 241 /// 242 /// is transformed to 243 /// 244 /// ```mlir 245 /// scf.for %iv0 = ... iter_args(%arg = %0) { 246 /// %1 = tensor.extract_slice %arg 247 /// %2 = tiled_op 248 /// %3 = tensor.insert_slice %2 into %arg 249 /// scf.yield %3 250 /// } 251 /// ``` 252 static void 253 updateDestinationOperandsForTiledOp(OpBuilder &builder, 254 ValueRange tiledOpDestinationValues, 255 ValueRange bbArgsList) { 256 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 257 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 258 if (!sliceOp) 259 continue; 260 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 261 } 262 } 263 264 /// Implementation of tiling transformation of `op` that implements the 265 /// `TilingInterface` using `scf.for` to iterate over the tiles. 266 FailureOr<scf::SCFTilingResult> 267 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 268 scf::SCFTilingOptions options) { 269 OpBuilder::InsertionGuard guard(rewriter); 270 rewriter.setInsertionPointAfter(op); 271 272 if (!options.tileSizeComputationFunction) { 273 return rewriter.notifyMatchFailure( 274 op, "missing tile size computation function"); 275 } 276 277 // 1. Get the range of the loops that are represented by the operation. 278 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 279 size_t numLoops = iterationDomain.size(); 280 if (numLoops == 0) { 281 return rewriter.notifyMatchFailure( 282 op, "unable to tile op with no iteration domain"); 283 } 284 285 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 286 // skips tiling a particular dimension. This convention is significantly 287 // simpler to handle instead of adjusting affine maps to account for missing 288 // dimensions. 289 SmallVector<Value> tileSizeVector = 290 options.tileSizeComputationFunction(rewriter, op); 291 if (tileSizeVector.size() < iterationDomain.size()) { 292 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 293 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 294 } 295 296 scf::SCFTilingResult tilingResult; 297 SmallVector<OpFoldResult> offsets, sizes; 298 { 299 // If there is an interchange specified, permute the iteration domain and 300 // the tile sizes. 301 SmallVector<unsigned> interchangeVector; 302 if (!options.interchangeVector.empty()) { 303 interchangeVector = fillInterchangeVector(options.interchangeVector, 304 iterationDomain.size()); 305 } 306 if (!interchangeVector.empty()) { 307 if (!isPermutation(interchangeVector)) { 308 return rewriter.notifyMatchFailure( 309 op, "invalid intechange vector, not a permutation of the entire " 310 "iteration space"); 311 } 312 313 iterationDomain = 314 applyPermutationToVector(iterationDomain, interchangeVector); 315 tileSizeVector = 316 applyPermutationToVector(tileSizeVector, interchangeVector); 317 } 318 319 // 3. Materialize an empty loop nest that iterates over the tiles. These 320 // loops for now do not return any values even if the original operation has 321 // results. 322 tilingResult.loops = generateTileLoopNest( 323 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 324 325 if (!interchangeVector.empty()) { 326 auto inversePermutation = invertPermutationVector(interchangeVector); 327 offsets = applyPermutationToVector(offsets, inversePermutation); 328 sizes = applyPermutationToVector(sizes, inversePermutation); 329 } 330 } 331 332 LLVM_DEBUG({ 333 if (!tilingResult.loops.empty()) { 334 llvm::dbgs() << "LoopNest shell :\n"; 335 tilingResult.loops.front().dump(); 336 llvm::dbgs() << "\n"; 337 } 338 }); 339 340 // 4. Generate the tiled implementation within the inner most loop. 341 if (!tilingResult.loops.empty()) 342 rewriter.setInsertionPoint( 343 tilingResult.loops.back().getBody()->getTerminator()); 344 SmallVector<Operation *> tiledImplementation = 345 op.getTiledImplementation(rewriter, offsets, sizes); 346 if (tiledImplementation.size() != 1) { 347 return rewriter.notifyMatchFailure( 348 op, "expected tiled implementation to return a single op"); 349 } 350 tilingResult.tiledOp = tiledImplementation[0]; 351 if (op->getNumResults() == 0) { 352 // nothing more to do. 353 return tilingResult; 354 } 355 356 // If loops are empty, the tiled op is used as the replacement for the untiled 357 // op. 358 if (tilingResult.loops.empty()) { 359 tilingResult.replacements = llvm::to_vector( 360 llvm::map_range(tiledImplementation[0]->getResults(), 361 [](OpResult result) -> Value { return result; })); 362 return tilingResult; 363 } 364 365 // 5. Yield all the results of the tiled operation. The surrounding loop 366 // nest is modified to insert a destructive update pattern to yield 367 // from the loop nest values to replace the untiled op with. 368 unsigned numResults = op->getNumResults(); 369 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 370 resultSizesList(numResults); 371 for (auto result : llvm::enumerate(op->getResults())) { 372 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 373 sizes, 374 resultOffsetsList[result.index()], 375 resultSizesList[result.index()]))) { 376 return rewriter.notifyMatchFailure( 377 op, "failed to get slice of result produced"); 378 } 379 } 380 381 FailureOr<SmallVector<Value>> replacementOr = 382 yieldTiledValues(rewriter, op.getDestinationOperands(rewriter), 383 tilingResult.tiledOp->getResults(), resultOffsetsList, 384 resultSizesList, tilingResult.loops); 385 if (failed(replacementOr)) 386 return rewriter.notifyMatchFailure(op, "failed to yield replacement"); 387 if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) { 388 auto innerMostLoop = tilingResult.loops.back(); 389 updateDestinationOperandsForTiledOp( 390 rewriter, tiledInterfaceOp.getDestinationOperands(rewriter), 391 innerMostLoop.getRegionIterArgs()); 392 } 393 394 tilingResult.replacements = replacementOr.value(); 395 396 LLVM_DEBUG({ 397 if (!tilingResult.loops.empty()) { 398 llvm::dbgs() << "After tiled implementation :\n"; 399 tilingResult.loops.front().dump(); 400 llvm::dbgs() << "\n"; 401 } 402 }); 403 return tilingResult; 404 } 405 406 //===----------------------------------------------------------------------===// 407 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 408 //===----------------------------------------------------------------------===// 409 410 /// Return the untiled producer whose slice is used in a tiled consumer. The 411 /// method traverses the tile loop nest (`loops`) if needed, and returns the 412 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 413 /// indicates that this is a destination operand of the consumer. If there was 414 /// no loop traversal needed, the second value of the returned tuple is empty. 415 static std::tuple<OpResult, Optional<OpOperand *>> 416 getUntiledProducerFromSliceSource(OpOperand *source, 417 ArrayRef<scf::ForOp> loops) { 418 Optional<OpOperand *> destinationIterArg; 419 auto loopIt = loops.rbegin(); 420 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 421 scf::ForOp loop = *loopIt; 422 if (iterArg.getOwner()->getParentOp() != loop) 423 break; 424 source = &loop.getOpOperandForRegionIterArg(iterArg); 425 loopIt++; 426 } 427 if (loopIt == loops.rend()) 428 destinationIterArg = source; 429 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 430 } 431 432 /// Implementation of tile consumer and fuse producer greedily. 433 FailureOr<scf::SCFTileAndFuseResult> 434 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 435 RewriterBase &rewriter, TilingInterface consumer, 436 scf::SCFTileAndFuseOptions options) { 437 // This transformation is only valid for ops that return values (i.e. not 438 // valid to use with operations that have memref operands). 439 if (!consumer->getNumResults()) { 440 return rewriter.notifyMatchFailure( 441 consumer, "invalid pattern for op with no results"); 442 } 443 444 // 1. First tile the consumer. 445 scf::SCFTileAndFuseResult tileAndFuseResult; 446 llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber; 447 { 448 FailureOr<scf::SCFTilingResult> tilingResult = 449 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 450 if (failed(tilingResult)) 451 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 452 tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); 453 tileAndFuseResult.loops = std::move(tilingResult->loops); 454 for (auto result : llvm::enumerate( 455 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 456 tileAndFuseResult.replacements[std::get<0>(result.value())] = 457 std::get<1>(result.value()); 458 yieldedValueToResultNumber[tilingResult->tiledOp->getResult( 459 result.index())] = result.index(); 460 } 461 } 462 463 // If there are no loops generated, fusion is immaterial. 464 if (tileAndFuseResult.loops.empty()) 465 return tileAndFuseResult; 466 467 // 2. Typically, the operands of the tiled operation are slices of the 468 // operands of the untiled operation. These are expressed in IR using 469 // `tensor.extract_slice` operations with source being the operands of the 470 // untiled operation. Create a worklist of these `tensor.extract_slice` 471 // operations. If the producers of the source of the `tensor.extract_slice` 472 // can be tiled such that the tiled value is generated in-place, that 473 // effectively tiles + fuses the operations. 474 auto addCandidateSlices = [](Operation *fusedOp, 475 std::deque<tensor::ExtractSliceOp> &candidates) { 476 for (Value operand : fusedOp->getOperands()) 477 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 478 candidates.push_back(sliceOp); 479 }; 480 481 std::deque<tensor::ExtractSliceOp> candidates; 482 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 483 OpBuilder::InsertionGuard g(rewriter); 484 while (!candidates.empty()) { 485 // 2a. Traverse the slices in BFS fashion. 486 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 487 candidates.pop_front(); 488 489 // 2b. Get the producer of the source (potentially walking through 490 // `iter_args` of nested `scf.for`) 491 auto [fusableProducer, destinationIterArg] = 492 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 493 tileAndFuseResult.loops); 494 if (!fusableProducer) 495 continue; 496 497 // 2c. Generate the tiled implementation of the producer of the source 498 rewriter.setInsertionPoint(candidateSliceOp); 499 FailureOr<Value> fusedProducerValue = 500 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 501 fusableProducer); 502 if (failed(fusedProducerValue)) 503 continue; 504 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 505 506 // 2d. The operands of the fused producer might themselved be slices of 507 // values produced by operations that implement the `TilingInterface`. 508 // Add these operations to the worklist. 509 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 510 tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); 511 addCandidateSlices(fusedProducer, candidates); 512 513 // 2e. If the slice is for a destination operand, for example, 514 // 515 // ```mlir 516 // %0 = linalg.init 517 // %1 = linalg.fill .. outs(%0 : ) 518 // %2 = scf.for .. iter_args(%arg0 = %1) { 519 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 520 // %4 = tensor.extract_slice %arg1 [..] 521 // .. = linalg.matmul .. outs(%4 : ) 522 // } 523 // } 524 // ``` 525 // 526 // the IR is currently 527 // 528 // ``` 529 // %0 = linalg.init 530 // %1 = linalg.fill 531 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 532 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 533 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 534 // %5 = linalg.fill .. outs(%4 : ) 535 // .. = linalg.matmul .. outs(%5 : ) 536 // } 537 // } 538 // ``` 539 // 540 // The untiled `linalg.fill` is still used as the `init_value` since it 541 // was originally a destination operand of the untiled `linalg.matmul`. 542 // When fusing an operand that is a destination operand. 543 // - Update the iter_arg of the outer most loop to use the destination 544 // of the untiled producer. 545 // - Update the destination of the slice of the tiled producer generated 546 // to use the same basic block argument as the slice that was used to 547 // generate inplace the tiled implementation of the producer. 548 // With this the IR will be. 549 // 550 // ``` 551 // %0 = linalg.init 552 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 553 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 554 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 555 // %4 = linalg.fill .. outs(%3 : ) 556 // .. = linalg.matmul .. outs(%4 : ) 557 // } 558 // } 559 // ``` 560 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 561 // Update to use that when it does become available. 562 scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); 563 Optional<unsigned> iterArgNumber; 564 if (destinationIterArg) { 565 iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand( 566 *destinationIterArg.value()); 567 } 568 if (iterArgNumber) { 569 unsigned resultNumber = fusableProducer.getResultNumber(); 570 if (auto producerOp = 571 dyn_cast<TilingInterface>(fusableProducer.getOwner())) { 572 SmallVector<Value> destination = 573 producerOp.getDestinationOperands(rewriter); 574 outerMostLoop.setIterArg(iterArgNumber.value(), 575 destination[resultNumber]); 576 } 577 if (auto tiledAndFusedInterfaceOp = 578 fusedProducerValue.value().getDefiningOp<TilingInterface>()) { 579 scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); 580 SmallVector<Value> destination = 581 tiledAndFusedInterfaceOp.getDestinationOperands(rewriter); 582 updateDestinationOperandsForTiledOp( 583 rewriter, destination[resultNumber], 584 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 585 } 586 } 587 } 588 return tileAndFuseResult; 589 } 590 591 //===----------------------------------------------------------------------===// 592 // lowerToLoopsUsingSCFForOp implementation. 593 //===----------------------------------------------------------------------===// 594 595 FailureOr<SmallVector<scf::ForOp>> 596 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 597 TilingInterface op) { 598 // TODO: Handle cases where the op has results if needed. 599 if (op->getNumResults() > 0) { 600 return rewriter.notifyMatchFailure( 601 op, "unable to lower to loops operations with return values"); 602 } 603 604 SmallVector<Range> domain = op.getIterationDomain(rewriter); 605 SmallVector<Value> ivs; 606 SmallVector<scf::ForOp> loops; 607 Location loc = op.getLoc(); 608 for (auto loopRange : domain) { 609 Value offsetVal = 610 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 611 Value sizeVal = 612 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 613 Value strideVal = 614 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 615 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 616 strideVal, ValueRange{}); 617 loops.push_back(loop); 618 ivs.push_back(loop.getInductionVar()); 619 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 620 } 621 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 622 return failure(); 623 } 624 return loops; 625 } 626