1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 24 #include "mlir/Interfaces/TilingInterface.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "tile-using-interface" 28 29 using namespace mlir; 30 31 scf::SCFTilingOptions & 32 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 33 assert(!tileSizeComputationFunction && "tile sizes already set"); 34 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 35 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 36 OpBuilder::InsertionGuard guard(b); 37 b.setInsertionPointToStart( 38 &op->getParentOfType<func::FuncOp>().getBody().front()); 39 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 40 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 41 return v; 42 })); 43 }; 44 return *this; 45 } 46 47 /// Helper method to adjust the interchange vector to match the iteration 48 /// domain. 49 static SmallVector<int64_t> 50 fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 51 size_t iterationDomainSize) { 52 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 53 if (filledVector.size() < iterationDomainSize) { 54 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 55 filledVector.append(range.begin(), range.end()); 56 } 57 if (filledVector.size() > iterationDomainSize) 58 filledVector.resize(iterationDomainSize); 59 return filledVector; 60 } 61 62 /// Helper method to apply permutation to a vector 63 template <typename T> 64 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 65 ArrayRef<int64_t> interchange) { 66 assert(interchange.size() == vector.size()); 67 return llvm::to_vector( 68 llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); 69 } 70 /// Helper method to apply to invert a permutation. 71 static SmallVector<int64_t> 72 invertPermutationVector(ArrayRef<int64_t> interchange) { 73 SmallVector<int64_t> inversion(interchange.size()); 74 for (const auto &pos : llvm::enumerate(interchange)) { 75 inversion[pos.value()] = pos.index(); 76 } 77 return inversion; 78 } 79 /// Method to check if an interchange vector is a permutation. 80 static bool isPermutation(ArrayRef<int64_t> interchange) { 81 llvm::SmallDenseSet<int64_t, 4> seenVals; 82 for (auto val : interchange) { 83 if (seenVals.count(val)) 84 return false; 85 seenVals.insert(val); 86 } 87 return seenVals.size() == interchange.size(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // tileUsingSCFForOp implementation. 92 //===----------------------------------------------------------------------===// 93 94 // Check if `stride` evenly divides the trip count `size - offset`. 95 static bool tileDividesIterationDomain(Range loopRange) { 96 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 97 if (!offsetAsInt) 98 return false; 99 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 100 if (!sizeAsInt) 101 return false; 102 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 103 if (!strideAsInt) 104 return false; 105 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 106 } 107 108 /// Generate an empty loop nest that represents the tiled loop nest shell. 109 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 110 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 111 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 112 /// the 113 /// tile processed within the inner most loop. 114 static SmallVector<scf::ForOp> 115 generateTileLoopNest(OpBuilder &builder, Location loc, 116 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 117 SmallVector<OpFoldResult> &offsets, 118 SmallVector<OpFoldResult> &sizes) { 119 assert(!loopRanges.empty() && "expected at least one loop range"); 120 assert(loopRanges.size() == tileSizeVals.size() && 121 "expected as many tile sizes as loop ranges"); 122 OpBuilder::InsertionGuard guard(builder); 123 SmallVector<scf::ForOp> loops; 124 offsets.resize(loopRanges.size()); 125 sizes.resize(loopRanges.size()); 126 127 // The tile size to use (to avoid out of bounds access) is minimum of 128 // `tileSize` and `ub - iv`, where `iv` is the induction variable 129 // of the tiled loop. 130 AffineExpr s0, s1, d0; 131 bindDims(builder.getContext(), d0); 132 bindSymbols(builder.getContext(), s0, s1); 133 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 134 135 for (auto loopRange : llvm::enumerate(loopRanges)) { 136 Value offset = 137 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 138 Value size = 139 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 140 // No loops if tile size is zero. Set offset and size to the loop 141 // offset and size. 142 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 143 offsets[loopRange.index()] = offset; 144 sizes[loopRange.index()] = size; 145 continue; 146 } 147 148 auto loop = builder.create<scf::ForOp>( 149 loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, 150 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 151 ValueRange /*iterArgs*/) { 152 bool canAvoidMap = tileDividesIterationDomain( 153 Range{loopRange.value().offset, loopRange.value().size, 154 tileSizeVals[loopRange.index()]}); 155 Value boundedTileSize = 156 (canAvoidMap) 157 ? tileSizeVals[loopRange.index()] 158 : builder.create<AffineMinOp>( 159 bodyLoc, minMap, 160 ValueRange{iv, tileSizeVals[loopRange.index()], size}); 161 sizes[loopRange.index()] = boundedTileSize; 162 builder.create<scf::YieldOp>(loc); 163 }); 164 offsets[loopRange.index()] = loop.getInductionVar(); 165 loops.push_back(loop); 166 builder.setInsertionPoint(loop.getBody()->getTerminator()); 167 } 168 return loops; 169 } 170 171 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, 172 /// construct the destructive update pattern that inserts the yielded 173 /// value into a destination tensor provided by `initValue` at offset 174 /// `tileOffsets` and size `tileSizes`. For example, 175 /// 176 /// ```mlir 177 /// scf.for %iv0 = ... { 178 /// %0 = tiled_op 179 /// } 180 /// ``` 181 /// 182 /// is transformed to 183 /// 184 /// ```mlir 185 /// scf.for %iv0 = ... iter_args(%arg = %0) { 186 /// %1 = tensor.extract_slice %arg 187 /// %2 = tiled_op 188 /// %3 = tensor.insert_slice %2 into %arg 189 /// scf.yield %3 190 /// } 191 /// ``` 192 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 193 static FailureOr<SmallVector<Value>> 194 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, 195 ValueRange yieldedValues, 196 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, 197 ArrayRef<SmallVector<OpFoldResult>> tileSizesList, 198 MutableArrayRef<scf::ForOp> loops) { 199 NewYieldValueFn yieldValueFn = 200 [&](OpBuilder &b, Location loc, 201 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 202 SmallVector<Value> inserts; 203 for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { 204 ArrayRef<OpFoldResult> tileOffsets = 205 tileOffsetsList[yieldedValue.index()]; 206 ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; 207 SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), 208 b.getIndexAttr(1)); 209 Value insert = b.create<tensor::InsertSliceOp>( 210 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], 211 tileOffsets, tileSizes, tileStrides); 212 inserts.push_back(insert); 213 } 214 return inserts; 215 }; 216 217 SmallVector<scf::ForOp> newLoops = 218 replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, 219 /*replaceIterOperandsUsesInLoop =*/false); 220 for (const auto &loop : llvm::enumerate(loops)) { 221 rewriter.eraseOp(loop.value()); 222 loops[loop.index()] = newLoops[loop.index()]; 223 } 224 return llvm::to_vector(llvm::map_range( 225 loops.front().getResults().take_back(yieldedValues.size()), 226 [](OpResult r) -> Value { return r; })); 227 } 228 229 /// If the tiled operation is destination passing style, update the 230 /// slice of the destination used (which refers to the untiled destination) 231 /// to use the corresponding region argument of the innermost loop. 232 /// 233 /// ```mlir 234 /// %0 = 235 /// scf.for %iv0 = ... iter_args(%arg = %0) { 236 /// %1 = tensor.extract_slice %0 237 /// %2 = tiled_op 238 /// %3 = tensor.insert_slice %2 into %arg 239 /// scf.yield %3 240 /// } 241 /// ``` 242 /// 243 /// is transformed to 244 /// 245 /// ```mlir 246 /// scf.for %iv0 = ... iter_args(%arg = %0) { 247 /// %1 = tensor.extract_slice %arg 248 /// %2 = tiled_op 249 /// %3 = tensor.insert_slice %2 into %arg 250 /// scf.yield %3 251 /// } 252 /// ``` 253 static void 254 updateDestinationOperandsForTiledOp(OpBuilder &builder, 255 ValueRange tiledOpDestinationValues, 256 ValueRange bbArgsList) { 257 for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { 258 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); 259 if (!sliceOp) 260 continue; 261 sliceOp.setOperand(0, bbArgsList[destValue.index()]); 262 } 263 } 264 265 /// Implementation of tiling transformation of `op` that implements the 266 /// `TilingInterface` using `scf.for` to iterate over the tiles. 267 FailureOr<scf::SCFTilingResult> 268 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, 269 const scf::SCFTilingOptions &options) { 270 OpBuilder::InsertionGuard guard(rewriter); 271 rewriter.setInsertionPointAfter(op); 272 273 if (!options.tileSizeComputationFunction) { 274 return rewriter.notifyMatchFailure( 275 op, "missing tile size computation function"); 276 } 277 278 // Get destination tensors. 279 SmallVector<Value> destinationTensors; 280 if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, 281 destinationTensors))) 282 return rewriter.notifyMatchFailure(op, "failed to get destinations"); 283 284 // 1. Get the range of the loops that are represented by the operation. 285 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 286 size_t numLoops = iterationDomain.size(); 287 if (numLoops == 0) { 288 return rewriter.notifyMatchFailure( 289 op, "unable to tile op with no iteration domain"); 290 } 291 292 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 293 // skips tiling a particular dimension. This convention is significantly 294 // simpler to handle instead of adjusting affine maps to account for missing 295 // dimensions. 296 SmallVector<Value> tileSizeVector = 297 options.tileSizeComputationFunction(rewriter, op); 298 if (tileSizeVector.size() < iterationDomain.size()) { 299 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 300 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 301 } 302 303 scf::SCFTilingResult tilingResult; 304 SmallVector<OpFoldResult> offsets, sizes; 305 { 306 // If there is an interchange specified, permute the iteration domain and 307 // the tile sizes. 308 SmallVector<int64_t> interchangeVector; 309 if (!options.interchangeVector.empty()) { 310 interchangeVector = fillInterchangeVector(options.interchangeVector, 311 iterationDomain.size()); 312 } 313 if (!interchangeVector.empty()) { 314 if (!isPermutation(interchangeVector)) { 315 return rewriter.notifyMatchFailure( 316 op, "invalid intechange vector, not a permutation of the entire " 317 "iteration space"); 318 } 319 320 iterationDomain = 321 applyPermutationToVector(iterationDomain, interchangeVector); 322 tileSizeVector = 323 applyPermutationToVector(tileSizeVector, interchangeVector); 324 } 325 326 // 3. Materialize an empty loop nest that iterates over the tiles. These 327 // loops for now do not return any values even if the original operation has 328 // results. 329 tilingResult.loops = generateTileLoopNest( 330 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 331 332 if (!interchangeVector.empty()) { 333 auto inversePermutation = invertPermutationVector(interchangeVector); 334 offsets = applyPermutationToVector(offsets, inversePermutation); 335 sizes = applyPermutationToVector(sizes, inversePermutation); 336 } 337 } 338 339 LLVM_DEBUG({ 340 if (!tilingResult.loops.empty()) { 341 llvm::dbgs() << "LoopNest shell :\n"; 342 tilingResult.loops.front().dump(); 343 llvm::dbgs() << "\n"; 344 } 345 }); 346 347 // 4. Generate the tiled implementation within the inner most loop. 348 if (!tilingResult.loops.empty()) 349 rewriter.setInsertionPoint( 350 tilingResult.loops.back().getBody()->getTerminator()); 351 SmallVector<Operation *> tiledImplementation = 352 op.getTiledImplementation(rewriter, offsets, sizes); 353 if (tiledImplementation.size() != 1) { 354 return rewriter.notifyMatchFailure( 355 op, "expected tiled implementation to return a single op"); 356 } 357 tilingResult.tiledOp = tiledImplementation[0]; 358 if (op->getNumResults() == 0) { 359 // nothing more to do. 360 return tilingResult; 361 } 362 363 // If loops are empty, the tiled op is used as the replacement for the untiled 364 // op. 365 if (tilingResult.loops.empty()) { 366 tilingResult.replacements = llvm::to_vector( 367 llvm::map_range(tiledImplementation[0]->getResults(), 368 [](OpResult result) -> Value { return result; })); 369 return tilingResult; 370 } 371 372 // 5. Yield all the results of the tiled operation. The surrounding loop 373 // nest is modified to insert a destructive update pattern to yield 374 // from the loop nest values to replace the untiled op with. 375 int64_t numResults = op->getNumResults(); 376 SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), 377 resultSizesList(numResults); 378 for (const auto &result : llvm::enumerate(op->getResults())) { 379 if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, 380 sizes, 381 resultOffsetsList[result.index()], 382 resultSizesList[result.index()]))) { 383 return rewriter.notifyMatchFailure( 384 op, "failed to get slice of result produced"); 385 } 386 } 387 388 FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues( 389 rewriter, destinationTensors, tilingResult.tiledOp->getResults(), 390 resultOffsetsList, resultSizesList, tilingResult.loops); 391 if (failed(replacementOr)) 392 return rewriter.notifyMatchFailure(op, "failed to yield replacement"); 393 394 if (auto dstOp = 395 dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) { 396 auto innerMostLoop = tilingResult.loops.back(); 397 SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); 398 assert(destinationTensors.size() == 399 innerMostLoop.getRegionIterArgs().size() && 400 "unexpected number of outputs"); 401 updateDestinationOperandsForTiledOp(rewriter, destinationTensors, 402 innerMostLoop.getRegionIterArgs()); 403 } 404 405 tilingResult.replacements = replacementOr.value(); 406 407 LLVM_DEBUG({ 408 if (!tilingResult.loops.empty()) { 409 llvm::dbgs() << "After tiled implementation :\n"; 410 tilingResult.loops.front().dump(); 411 llvm::dbgs() << "\n"; 412 } 413 }); 414 return tilingResult; 415 } 416 417 //===----------------------------------------------------------------------===// 418 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 419 //===----------------------------------------------------------------------===// 420 421 /// Return the untiled producer whose slice is used in a tiled consumer. The 422 /// method traverses the tile loop nest (`loops`) if needed, and returns the 423 /// `iter_args` of the outer most that is encountered. Traversing the iter_args 424 /// indicates that this is a destination operand of the consumer. If there was 425 /// no loop traversal needed, the second value of the returned tuple is empty. 426 static std::tuple<OpResult, Optional<OpOperand *>> 427 getUntiledProducerFromSliceSource(OpOperand *source, 428 ArrayRef<scf::ForOp> loops) { 429 Optional<OpOperand *> destinationIterArg; 430 auto loopIt = loops.rbegin(); 431 while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { 432 scf::ForOp loop = *loopIt; 433 if (iterArg.getOwner()->getParentOp() != loop) 434 break; 435 source = &loop.getOpOperandForRegionIterArg(iterArg); 436 loopIt++; 437 } 438 if (loopIt == loops.rend()) 439 destinationIterArg = source; 440 return {source->get().dyn_cast<OpResult>(), destinationIterArg}; 441 } 442 443 /// Implementation of tile consumer and fuse producer greedily. 444 FailureOr<scf::SCFTileAndFuseResult> 445 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( 446 RewriterBase &rewriter, TilingInterface consumer, 447 const scf::SCFTileAndFuseOptions &options) { 448 // This transformation is only valid for ops that return values (i.e. not 449 // valid to use with operations that have memref operands). 450 if (!consumer->getNumResults()) { 451 return rewriter.notifyMatchFailure( 452 consumer, "invalid pattern for op with no results"); 453 } 454 455 // 1. First tile the consumer. 456 scf::SCFTileAndFuseResult tileAndFuseResult; 457 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; 458 { 459 FailureOr<scf::SCFTilingResult> tilingResult = 460 tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); 461 if (failed(tilingResult)) 462 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 463 tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); 464 tileAndFuseResult.loops = std::move(tilingResult->loops); 465 for (const auto &result : llvm::enumerate( 466 llvm::zip(consumer->getResults(), tilingResult->replacements))) { 467 tileAndFuseResult.replacements[std::get<0>(result.value())] = 468 std::get<1>(result.value()); 469 yieldedValueToResultNumber[tilingResult->tiledOp->getResult( 470 result.index())] = result.index(); 471 } 472 } 473 474 // If there are no loops generated, fusion is immaterial. 475 if (tileAndFuseResult.loops.empty()) 476 return tileAndFuseResult; 477 478 // 2. Typically, the operands of the tiled operation are slices of the 479 // operands of the untiled operation. These are expressed in IR using 480 // `tensor.extract_slice` operations with source being the operands of the 481 // untiled operation. Create a worklist of these `tensor.extract_slice` 482 // operations. If the producers of the source of the `tensor.extract_slice` 483 // can be tiled such that the tiled value is generated in-place, that 484 // effectively tiles + fuses the operations. 485 auto addCandidateSlices = [](Operation *fusedOp, 486 std::deque<tensor::ExtractSliceOp> &candidates) { 487 for (Value operand : fusedOp->getOperands()) 488 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 489 candidates.push_back(sliceOp); 490 }; 491 492 std::deque<tensor::ExtractSliceOp> candidates; 493 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 494 OpBuilder::InsertionGuard g(rewriter); 495 while (!candidates.empty()) { 496 // 2a. Traverse the slices in BFS fashion. 497 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 498 candidates.pop_front(); 499 500 // 2b. Get the producer of the source (potentially walking through 501 // `iter_args` of nested `scf.for`) 502 auto [fusableProducer, destinationIterArg] = 503 getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), 504 tileAndFuseResult.loops); 505 if (!fusableProducer) 506 continue; 507 508 // 2c. Generate the tiled implementation of the producer of the source 509 rewriter.setInsertionPoint(candidateSliceOp); 510 FailureOr<Value> fusedProducerValue = 511 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 512 fusableProducer); 513 if (failed(fusedProducerValue)) 514 continue; 515 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 516 517 // 2d. The operands of the fused producer might themselved be slices of 518 // values produced by operations that implement the `TilingInterface`. 519 // Add these operations to the worklist. 520 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 521 tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); 522 addCandidateSlices(fusedProducer, candidates); 523 524 // 2e. If the slice is for a destination operand, for example, 525 // 526 // ```mlir 527 // %0 = linalg.init 528 // %1 = linalg.fill .. outs(%0 : ) 529 // %2 = scf.for .. iter_args(%arg0 = %1) { 530 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 531 // %4 = tensor.extract_slice %arg1 [..] 532 // .. = linalg.matmul .. outs(%4 : ) 533 // } 534 // } 535 // ``` 536 // 537 // the IR is currently 538 // 539 // ``` 540 // %0 = linalg.init 541 // %1 = linalg.fill 542 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 543 // %3 = scf.for .. iter_args(%arg1 = %arg0) { 544 // %4 = tensor.extract_slice %0 /*incorrect value */ [..] 545 // %5 = linalg.fill .. outs(%4 : ) 546 // .. = linalg.matmul .. outs(%5 : ) 547 // } 548 // } 549 // ``` 550 // 551 // The untiled `linalg.fill` is still used as the `init_value` since it 552 // was originally a destination operand of the untiled `linalg.matmul`. 553 // When fusing an operand that is a destination operand. 554 // - Update the iter_arg of the outer most loop to use the destination 555 // of the untiled producer. 556 // - Update the destination of the slice of the tiled producer generated 557 // to use the same basic block argument as the slice that was used to 558 // generate inplace the tiled implementation of the producer. 559 // With this the IR will be. 560 // 561 // ``` 562 // %0 = linalg.init 563 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 564 // %2 = scf.for .. iter_args(%arg1 = %arg0) { 565 // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] 566 // %4 = linalg.fill .. outs(%3 : ) 567 // .. = linalg.matmul .. outs(%4 : ) 568 // } 569 // } 570 // ``` 571 // TODO: This can be modeled better if the `DestinationStyleOpInterface`. 572 // Update to use that when it does become available. 573 scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); 574 Optional<unsigned> iterArgNumber; 575 if (destinationIterArg) { 576 iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand( 577 *destinationIterArg.value()); 578 } 579 if (iterArgNumber) { 580 int64_t resultNumber = fusableProducer.getResultNumber(); 581 if (auto dstOp = dyn_cast<DestinationStyleOpInterface>( 582 fusableProducer.getOwner())) { 583 outerMostLoop.setIterArg( 584 iterArgNumber.value(), 585 dstOp.getTiedOpOperand(fusableProducer)->get()); 586 } 587 if (auto dstOp = fusedProducerValue.value() 588 .getDefiningOp<DestinationStyleOpInterface>()) { 589 scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); 590 updateDestinationOperandsForTiledOp( 591 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), 592 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); 593 } 594 } 595 } 596 return tileAndFuseResult; 597 } 598 599 //===----------------------------------------------------------------------===// 600 // lowerToLoopsUsingSCFForOp implementation. 601 //===----------------------------------------------------------------------===// 602 603 FailureOr<SmallVector<scf::ForOp>> 604 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 605 TilingInterface op) { 606 // TODO: Handle cases where the op has results if needed. 607 if (op->getNumResults() > 0) { 608 return rewriter.notifyMatchFailure( 609 op, "unable to lower to loops operations with return values"); 610 } 611 612 SmallVector<Range> domain = op.getIterationDomain(rewriter); 613 SmallVector<Value> ivs; 614 SmallVector<scf::ForOp> loops; 615 Location loc = op.getLoc(); 616 for (auto loopRange : domain) { 617 Value offsetVal = 618 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 619 Value sizeVal = 620 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 621 Value strideVal = 622 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 623 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 624 strideVal, ValueRange{}); 625 loops.push_back(loop); 626 ivs.push_back(loop.getInductionVar()); 627 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 628 } 629 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 630 return failure(); 631 } 632 return loops; 633 } 634