1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 24 #include "mlir/Interfaces/TilingInterface.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "tile-using-interface" 28 29 using namespace mlir; 30 31 scf::SCFTilingOptions & 32 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 33 assert(!tileSizeComputationFunction && "tile sizes already set"); 34 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 35 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 36 OpBuilder::InsertionGuard guard(b); 37 b.setInsertionPointToStart( 38 &op->getParentOfType<func::FuncOp>().getBody().front()); 39 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 40 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 41 return v; 42 })); 43 }; 44 return *this; 45 } 46 47 /// Helper method to adjust the interchange vector to match the iteration 48 /// domain. 49 static SmallVector<int64_t> 50 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 51 size_t iterationDomainSize) { 52 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 53 if (filledVector.size() < iterationDomainSize) { 54 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 55 filledVector.append(range.begin(), range.end()); 56 } 57 if (filledVector.size() > iterationDomainSize) 58 filledVector.resize(iterationDomainSize); 59 return filledVector; 60 } 61 62 /// Helper method to apply permutation to a vector 63 template <typename T> 64 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 65 ArrayRef<int64_t> interchange) { 66 assert(interchange.size() == vector.size()); 67 return llvm::to_vector( 68 llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); 69 } 70 /// Helper method to apply to invert a permutation. 71 static SmallVector<int64_t> 72 invertPermutationVector(ArrayRef<int64_t> interchange) { 73 SmallVector<int64_t> inversion(interchange.size()); 74 for (const auto &pos : llvm::enumerate(interchange)) { 75 inversion[pos.value()] = pos.index(); 76 } 77 return inversion; 78 } 79 /// Method to check if an interchange vector is a permutation. 80 static bool isPermutation(ArrayRef<int64_t> interchange) { 81 llvm::SmallDenseSet<int64_t, 4> seenVals; 82 for (auto val : interchange) { 83 if (seenVals.count(val)) 84 return false; 85 seenVals.insert(val); 86 } 87 return seenVals.size() == interchange.size(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // tileUsingSCFForOp implementation. 92 //===----------------------------------------------------------------------===// 93 94 // Check if `stride` evenly divides the trip count `size - offset`. 95 static bool tileDividesIterationDomain(Range loopRange) { 96 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 97 if (!offsetAsInt) 98 return false; 99 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 100 if (!sizeAsInt) 101 return false; 102 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 103 if (!strideAsInt) 104 return false; 105 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 106 } 107 108 /// Returns the bounded tile size given the current `iv`, `loopRange` and 109 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. 110 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 111 Range loopRange, Value iv, 112 Value tileSize) { 113 Optional<int64_t> ts = getConstantIntValue(tileSize); 114 if (ts && ts.value() == 1) 115 return getAsOpFoldResult(tileSize); 116 117 if (tileDividesIterationDomain( 118 Range{loopRange.offset, loopRange.size, tileSize})) 119 return tileSize; 120 121 // The tile size to use (to avoid out of bounds access) is minimum of 122 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 123 // loop. 124 AffineExpr s0, s1, d0; 125 bindDims(b.getContext(), d0); 126 bindSymbols(b.getContext(), s0, s1); 127 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); 128 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 129 return b.create<AffineMinOp>(loc, minMap, ValueRange{iv, tileSize, size}) 130 .getResult(); 131 } 132 133 /// Generate an empty loop nest that represents the tiled loop nest shell. 134 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 135 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 136 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 137 /// the 138 /// tile processed within the inner most loop. 139 static SmallVector<scf::ForOp> 140 generateTileLoopNest(OpBuilder &builder, Location loc, 141 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 142 SmallVector<OpFoldResult> &offsets, 143 SmallVector<OpFoldResult> &sizes) { 144 assert(!loopRanges.empty() && "expected at least one loop range"); 145 assert(loopRanges.size() == tileSizeVals.size() && 146 "expected as many tile sizes as loop ranges"); 147 OpBuilder::InsertionGuard guard(builder); 148 SmallVector<scf::ForOp> loops; 149 offsets.resize(loopRanges.size()); 150 sizes.resize(loopRanges.size()); 151 152 for (auto loopRange : llvm::enumerate(loopRanges)) { 153 Value offset = 154 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 155 Value size = 156 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 157 Value tileSize = tileSizeVals[loopRange.index()]; 158 // No loops if tile size is zero. Set offset and size to the loop 159 // offset and size. 160 if (matchPattern(tileSize, m_Zero())) { 161 offsets[loopRange.index()] = offset; 162 sizes[loopRange.index()] = size; 163 continue; 164 } 165 166 auto loop = builder.create<scf::ForOp>( 167 loc, offset, size, tileSize, ValueRange{}, 168 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 169 ValueRange /*iterArgs*/) { 170 sizes[loopRange.index()] = getBoundedTileSize( 171 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); 172 builder.create<scf::YieldOp>(loc); 173 }); 174 offsets[loopRange.index()] = loop.getInductionVar(); 175 loops.push_back(loop); 176 builder.setInsertionPoint(loop.getBody()->getTerminator()); 177 } 178 return loops; 179 } 180 181 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 182 /// construct the destructive update pattern that inserts the yielded 183 /// value into a destination tensor provided by `initValue` at offset 184 /// `tileOffsets` and size `tileSizes`. For example, 185 /// 186 /// ```mlir 187 /// scf.for %iv0 = ... { 188 /// %0 = tiled_op 189 /// } 190 /// ``` 191 /// 192 /// is transformed to 193 /// 194 /// ```mlir 195 /// scf.for %iv0 = ... iter_args(%arg = %0) { 196 /// %1 = tensor.extract_slice %arg 197 /// %2 = tiled_op 198 /// %3 = tensor.insert_slice %2 into %arg 199 /// scf.yield %3 200 /// } 201 /// ``` 202 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 203 static FailureOr<SmallVector<Value>> 204 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 205 ValueRange yieldedValues, 206 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 207 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 208 MutableArrayRef<scf::ForOp> loops) { 209 NewYieldValueFn yieldValueFn = 210 [&](OpBuilder &b, Location loc, 211 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 212 SmallVector<Value> inserts; 213 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 214 ArrayRef<OpFoldResult> tileOffsets = 215 tileOffsetsList[yieldedValue.index()]; 216 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 217 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 218 b.getIndexAttr(1)); 219 Value insert = b.create<tensor::InsertSliceOp>( 220 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 221 tileOffsets, tileSizes, tileStrides); 222 inserts.push_back(insert); 223 } 224 return inserts; 225 }; 226 227 SmallVector<scf::ForOp> newLoops = 228 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 229 /*replaceIterOperandsUsesInLoop =*/false); 230 for (const auto &loop : llvm::enumerate(loops)) { 231 rewriter.eraseOp(loop.value()); 232 loops[loop.index()] = newLoops[loop.index()]; 233 } 234 return llvm::to_vector(llvm::map_range( 235 loops.front().getResults().take_back(yieldedValues.size()), 236 [](OpResult r) -> Value { return r; })); 237 } 238 239 /// If the tiled operation is destination passing style, update the 240 /// slice of the destination used (which refers to the untiled destination) 241 /// to use the corresponding region argument of the innermost loop. 242 /// 243 /// ```mlir 244 /// %0 = 245 /// scf.for %iv0 = ... iter_args(%arg = %0) { 246 /// %1 = tensor.extract_slice %0 247 /// %2 = tiled_op 248 /// %3 = tensor.insert_slice %2 into %arg 249 /// scf.yield %3 250 /// } 251 /// ``` 252 /// 253 /// is transformed to 254 /// 255 /// ```mlir 256 /// scf.for %iv0 = ... iter_args(%arg = %0) { 257 /// %1 = tensor.extract_slice %arg 258 /// %2 = tiled_op 259 /// %3 = tensor.insert_slice %2 into %arg 260 /// scf.yield %3 261 /// } 262 /// ``` 263 static void 264 updateDestinationOperandsForTiledOp(OpBuilder &builder, 265 ValueRange tiledOpDestinationValues, 266 ValueRange bbArgsList) { 267 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 268 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 269 if (!sliceOp) 270 continue; 271 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 272 } 273 } 274 275 /// Implementation of tiling transformation of `op` that implements the 276 /// `TilingInterface` using `scf.for` to iterate over the tiles. 277 FailureOr<scf::SCFTilingResult> 278 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 279 const scf::SCFTilingOptions &options) { 280 OpBuilder::InsertionGuard guard(rewriter); 281 rewriter.setInsertionPointAfter(op); 282 283 if (!options.tileSizeComputationFunction) { 284 return rewriter.notifyMatchFailure( 285 op, "missing tile size computation function"); 286 } 287 288 // Get destination tensors. 289 SmallVector<Value> destinationTensors; 290 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 291 destinationTensors))) 292 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 293 294 // 1. Get the range of the loops that are represented by the operation. 295 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 296 size_t numLoops = iterationDomain.size(); 297 if (numLoops == 0) { 298 return rewriter.notifyMatchFailure( 299 op, "unable to tile op with no iteration domain"); 300 } 301 302 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 303 // skips tiling a particular dimension. This convention is significantly 304 // simpler to handle instead of adjusting affine maps to account for missing 305 // dimensions. 306 SmallVector<Value> tileSizeVector = 307 options.tileSizeComputationFunction(rewriter, op); 308 if (tileSizeVector.size() < iterationDomain.size()) { 309 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 310 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 311 } 312 313 scf::SCFTilingResult tilingResult; 314 SmallVector<OpFoldResult> offsets, sizes; 315 { 316 // If there is an interchange specified, permute the iteration domain and 317 // the tile sizes. 318 SmallVector<int64_t> interchangeVector; 319 if (!options.interchangeVector.empty()) { 320 interchangeVector = fillInterchangeVector(options.interchangeVector, 321 iterationDomain.size()); 322 } 323 if (!interchangeVector.empty()) { 324 if (!isPermutation(interchangeVector)) { 325 return rewriter.notifyMatchFailure( 326 op, "invalid intechange vector, not a permutation of the entire " 327 "iteration space"); 328 } 329 330 iterationDomain = 331 applyPermutationToVector(iterationDomain, interchangeVector); 332 tileSizeVector = 333 applyPermutationToVector(tileSizeVector, interchangeVector); 334 } 335 336 // 3. Materialize an empty loop nest that iterates over the tiles. These 337 // loops for now do not return any values even if the original operation has 338 // results. 339 tilingResult.loops = generateTileLoopNest( 340 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 341 342 if (!interchangeVector.empty()) { 343 auto inversePermutation = invertPermutationVector(interchangeVector); 344 offsets = applyPermutationToVector(offsets, inversePermutation); 345 sizes = applyPermutationToVector(sizes, inversePermutation); 346 } 347 } 348 349 LLVM_DEBUG({ 350 if (!tilingResult.loops.empty()) { 351 llvm::dbgs() << "LoopNest shell :\n"; 352 tilingResult.loops.front().dump(); 353 llvm::dbgs() << "\n"; 354 } 355 }); 356 357 // 4. Generate the tiled implementation within the inner most loop. 358 if (!tilingResult.loops.empty()) 359 rewriter.setInsertionPoint( 360 tilingResult.loops.back().getBody()->getTerminator()); 361 SmallVector<Operation *> tiledImplementation = 362 op.getTiledImplementation(rewriter, offsets, sizes); 363 if (tiledImplementation.size() != 1) { 364 return rewriter.notifyMatchFailure( 365 op, "expected tiled implementation to return a single op"); 366 } 367 tilingResult.tiledOp = tiledImplementation[0]; 368 if (op->getNumResults() == 0) { 369 // nothing more to do. 370 return tilingResult; 371 } 372 373 // If loops are empty, the tiled op is used as the replacement for the untiled 374 // op. 375 if (tilingResult.loops.empty()) { 376 tilingResult.replacements = llvm::to_vector( 377 llvm::map_range(tiledImplementation[0]->getResults(), 378 [](OpResult result) -> Value { return result; })); 379 return tilingResult; 380 } 381 382 // 5. Yield all the results of the tiled operation. The surrounding loop 383 // nest is modified to insert a destructive update pattern to yield 384 // from the loop nest values to replace the untiled op with. 385 int64_t numResults = op->getNumResults(); 386 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 387 resultSizesList(numResults); 388 for (const auto &result : llvm::enumerate(op->getResults())) { 389 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 390 sizes, 391 resultOffsetsList[result.index()], 392 resultSizesList[result.index()]))) { 393 return rewriter.notifyMatchFailure( 394 op, "failed to get slice of result produced"); 395 } 396 } 397 398 FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues( 399 rewriter, destinationTensors, tilingResult.tiledOp->getResults(), 400 resultOffsetsList, resultSizesList, tilingResult.loops); 401 if (failed(replacementOr)) 402 return rewriter.notifyMatchFailure(op, "failed to yield replacement"); 403 404 if (auto dstOp = 405 dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) { 406 auto innerMostLoop = tilingResult.loops.back(); 407 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 408 assert(destinationTensors.size() == 409 innerMostLoop.getRegionIterArgs().size() && 410 "unexpected number of outputs"); 411 updateDestinationOperandsForTiledOp(rewriter, destinationTensors, 412 innerMostLoop.getRegionIterArgs()); 413 } 414 415 tilingResult.replacements = replacementOr.value(); 416 417 LLVM_DEBUG({ 418 if (!tilingResult.loops.empty()) { 419 llvm::dbgs() << "After tiled implementation :\n"; 420 tilingResult.loops.front().dump(); 421 llvm::dbgs() << "\n"; 422 } 423 }); 424 return tilingResult; 425 } 426 427 //===----------------------------------------------------------------------===// 428 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 429 //===----------------------------------------------------------------------===// 430 431 /// Return the untiled producer whose slice is used in a tiled consumer. The 432 /// method traverses the tile loop nest (`loops`) if needed, and returns the 433 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 434 /// indicates that this is a destination operand of the consumer. If there was 435 /// no loop traversal needed, the second value of the returned tuple is empty. 436 static std::tuple<OpResult, Optional<OpOperand *>> 437 getUntiledProducerFromSliceSource(OpOperand *source, 438 ArrayRef<scf::ForOp> loops) { 439 Optional<OpOperand *> destinationIterArg; 440 auto loopIt = loops.rbegin(); 441 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 442 scf::ForOp loop = *loopIt; 443 if (iterArg.getOwner()->getParentOp() != loop) 444 break; 445 source = &loop.getOpOperandForRegionIterArg(iterArg); 446 loopIt++; 447 } 448 if (loopIt == loops.rend()) 449 destinationIterArg = source; 450 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 451 } 452 453 /// Implementation of tile consumer and fuse producer greedily. 454 FailureOr<scf::SCFTileAndFuseResult> 455 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 456 RewriterBase &rewriter, TilingInterface consumer, 457 const scf::SCFTileAndFuseOptions &options) { 458 // This transformation is only valid for ops that return values (i.e. not 459 // valid to use with operations that have memref operands). 460 if (!consumer->getNumResults()) { 461 return rewriter.notifyMatchFailure( 462 consumer, "invalid pattern for op with no results"); 463 } 464 465 // 1. First tile the consumer. 466 scf::SCFTileAndFuseResult tileAndFuseResult; 467 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 468 { 469 FailureOr<scf::SCFTilingResult> tilingResult = 470 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 471 if (failed(tilingResult)) 472 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 473 tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); 474 tileAndFuseResult.loops = std::move(tilingResult->loops); 475 for (const auto &result : llvm::enumerate( 476 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 477 tileAndFuseResult.replacements[std::get<0>(result.value())] = 478 std::get<1>(result.value()); 479 yieldedValueToResultNumber[tilingResult->tiledOp->getResult( 480 result.index())] = result.index(); 481 } 482 } 483 484 // If there are no loops generated, fusion is immaterial. 485 if (tileAndFuseResult.loops.empty()) 486 return tileAndFuseResult; 487 488 // 2. Typically, the operands of the tiled operation are slices of the 489 // operands of the untiled operation. These are expressed in IR using 490 // `tensor.extract_slice` operations with source being the operands of the 491 // untiled operation. Create a worklist of these `tensor.extract_slice` 492 // operations. If the producers of the source of the `tensor.extract_slice` 493 // can be tiled such that the tiled value is generated in-place, that 494 // effectively tiles + fuses the operations. 495 auto addCandidateSlices = [](Operation *fusedOp, 496 std::deque<tensor::ExtractSliceOp> &candidates) { 497 for (Value operand : fusedOp->getOperands()) 498 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 499 candidates.push_back(sliceOp); 500 }; 501 502 std::deque<tensor::ExtractSliceOp> candidates; 503 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 504 OpBuilder::InsertionGuard g(rewriter); 505 while (!candidates.empty()) { 506 // 2a. Traverse the slices in BFS fashion. 507 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 508 candidates.pop_front(); 509 510 // 2b. Get the producer of the source (potentially walking through 511 // `iter_args` of nested `scf.for`) 512 auto [fusableProducer, destinationIterArg] = 513 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 514 tileAndFuseResult.loops); 515 if (!fusableProducer) 516 continue; 517 518 // 2c. Generate the tiled implementation of the producer of the source 519 rewriter.setInsertionPoint(candidateSliceOp); 520 FailureOr<Value> fusedProducerValue = 521 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 522 fusableProducer); 523 if (failed(fusedProducerValue)) 524 continue; 525 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 526 527 // 2d. The operands of the fused producer might themselved be slices of 528 // values produced by operations that implement the `TilingInterface`. 529 // Add these operations to the worklist. 530 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 531 tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); 532 addCandidateSlices(fusedProducer, candidates); 533 534 // 2e. If the slice is for a destination operand, for example, 535 // 536 // ```mlir 537 // %0 = linalg.init 538 // %1 = linalg.fill .. outs(%0 : ) 539 // %2 = scf.for .. iter_args(%arg0 = %1) { 540 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 541 // %4 = tensor.extract_slice %arg1 [..] 542 // .. = linalg.matmul .. outs(%4 : ) 543 // } 544 // } 545 // ``` 546 // 547 // the IR is currently 548 // 549 // ``` 550 // %0 = linalg.init 551 // %1 = linalg.fill 552 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 553 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 554 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 555 // %5 = linalg.fill .. outs(%4 : ) 556 // .. = linalg.matmul .. outs(%5 : ) 557 // } 558 // } 559 // ``` 560 // 561 // The untiled `linalg.fill` is still used as the `init_value` since it 562 // was originally a destination operand of the untiled `linalg.matmul`. 563 // When fusing an operand that is a destination operand. 564 // - Update the iter_arg of the outer most loop to use the destination 565 // of the untiled producer. 566 // - Update the destination of the slice of the tiled producer generated 567 // to use the same basic block argument as the slice that was used to 568 // generate inplace the tiled implementation of the producer. 569 // With this the IR will be. 570 // 571 // ``` 572 // %0 = linalg.init 573 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 574 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 575 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 576 // %4 = linalg.fill .. outs(%3 : ) 577 // .. = linalg.matmul .. outs(%4 : ) 578 // } 579 // } 580 // ``` 581 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 582 // Update to use that when it does become available. 583 scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); 584 Optional<unsigned> iterArgNumber; 585 if (destinationIterArg) { 586 iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand( 587 *destinationIterArg.value()); 588 } 589 if (iterArgNumber) { 590 int64_t resultNumber = fusableProducer.getResultNumber(); 591 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>( 592 fusableProducer.getOwner())) { 593 outerMostLoop.setIterArg( 594 iterArgNumber.value(), 595 dstOp.getTiedOpOperand(fusableProducer)->get()); 596 } 597 if (auto dstOp = fusedProducerValue.value() 598 .getDefiningOp<DestinationStyleOpInterface>()) { 599 scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); 600 updateDestinationOperandsForTiledOp( 601 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 602 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 603 } 604 } 605 } 606 return tileAndFuseResult; 607 } 608 609 //===----------------------------------------------------------------------===// 610 // lowerToLoopsUsingSCFForOp implementation. 611 //===----------------------------------------------------------------------===// 612 613 FailureOr<SmallVector<scf::ForOp>> 614 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 615 TilingInterface op) { 616 // TODO: Handle cases where the op has results if needed. 617 if (op->getNumResults() > 0) { 618 return rewriter.notifyMatchFailure( 619 op, "unable to lower to loops operations with return values"); 620 } 621 622 SmallVector<Range> domain = op.getIterationDomain(rewriter); 623 SmallVector<Value> ivs; 624 SmallVector<scf::ForOp> loops; 625 Location loc = op.getLoc(); 626 for (auto loopRange : domain) { 627 Value offsetVal = 628 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 629 Value sizeVal = 630 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 631 Value strideVal = 632 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 633 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 634 strideVal, ValueRange{}); 635 loops.push_back(loop); 636 ivs.push_back(loop.getInductionVar()); 637 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 638 } 639 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 640 return failure(); 641 } 642 return loops; 643 } 644