xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision d69293c1c80f0f0e3eb012bc006573d4d5cb820f)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
33 scf::SCFTilingOptions &
34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
35   assert(!tileSizeComputationFunction && "tile sizes already set");
36   auto tileSizes = llvm::to_vector(ts);
37   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38     return tileSizes;
39   };
40   return *this;
41 }
42 
43 /// Helper method to adjust the interchange vector to match the iteration
44 /// domain.
45 static SmallVector<int64_t>
46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
47                       size_t iterationDomainSize) {
48   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
49   if (filledVector.size() < iterationDomainSize) {
50     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
51     filledVector.append(range.begin(), range.end());
52   }
53   if (filledVector.size() > iterationDomainSize)
54     filledVector.resize(iterationDomainSize);
55   return filledVector;
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // tileUsingSCFForOp implementation.
60 //===----------------------------------------------------------------------===//
61 
62 // Check if `stride` evenly divides the trip count `size - offset`.
63 static bool tileDividesIterationDomain(Range loopRange) {
64   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
65   if (!offsetAsInt)
66     return false;
67   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
68   if (!sizeAsInt)
69     return false;
70   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
71   if (!strideAsInt)
72     return false;
73   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
74 }
75 
76 /// Returns the bounded tile size given the current `iv`, `loopRange` and
77 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
79                                        Range loopRange, Value iv,
80                                        Value tileSize) {
81   std::optional<int64_t> ts = getConstantIntValue(tileSize);
82   if (ts && ts.value() == 1)
83     return getAsOpFoldResult(tileSize);
84 
85   if (tileDividesIterationDomain(
86           Range{loopRange.offset, loopRange.size, tileSize}))
87     return tileSize;
88 
89   // The tile size to use (to avoid out of bounds access) is  minimum of
90   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
91   // loop.
92   AffineExpr s0, s1, d0;
93   bindDims(b.getContext(), d0);
94   bindSymbols(b.getContext(), s0, s1);
95   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
96   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
97   return affine::makeComposedFoldedAffineMin(
98       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
99 }
100 
101 /// Generate an empty loop nest that represents the tiled loop nest shell.
102 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
103 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
104 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
105 /// the
106 ///   tile processed within the inner most loop.
107 static SmallVector<scf::ForOp> generateTileLoopNest(
108     OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
109     ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
110     SmallVector<OpFoldResult> &sizes) {
111   assert(!loopRanges.empty() && "expected at least one loop range");
112   assert(loopRanges.size() == tileSizes.size() &&
113          "expected as many tile sizes as loop ranges");
114   OpBuilder::InsertionGuard guard(builder);
115   SmallVector<scf::ForOp> loops;
116   offsets.resize(loopRanges.size());
117   sizes.resize(loopRanges.size());
118 
119   for (auto loopRange : llvm::enumerate(loopRanges)) {
120     Value offset =
121         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
122     Value size =
123         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
124     Value tileSize = getValueOrCreateConstantIndexOp(
125         builder, loc, tileSizes[loopRange.index()]);
126     // No loops if tile size is zero. Set offset and size to the loop
127     // offset and size.
128     if (matchPattern(tileSize, m_Zero())) {
129       offsets[loopRange.index()] = offset;
130       sizes[loopRange.index()] = size;
131       continue;
132     }
133 
134     auto loop = builder.create<scf::ForOp>(
135         loc, offset, size, tileSize, ValueRange{},
136         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
137             ValueRange /*iterArgs*/) {
138           sizes[loopRange.index()] = getBoundedTileSize(
139               bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
140           builder.create<scf::YieldOp>(loc);
141         });
142     offsets[loopRange.index()] = loop.getInductionVar();
143     loops.push_back(loop);
144     builder.setInsertionPoint(loop.getBody()->getTerminator());
145   }
146   return loops;
147 }
148 
149 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
150 /// construct the destructive update pattern that inserts the yielded
151 /// value into a destination tensor provided by `initValue` at offset
152 /// `tileOffsets` and size `tileSizes`. For example,
153 ///
154 /// ```mlir
155 /// scf.for %iv0 = ... {
156 ///   %0 = tiled_op
157 /// }
158 /// ```
159 ///
160 /// is transformed to
161 ///
162 /// ```mlir
163 /// scf.for %iv0 = ... iter_args(%arg = %0) {
164 ///   %1 = tensor.extract_slice %arg
165 ///   %2 = tiled_op
166 ///   %3 = tensor.insert_slice %2 into %arg
167 ///   scf.yield %3
168 /// }
169 /// ```
170 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
171 static SmallVector<Value>
172 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
173                  ValueRange yieldedValues,
174                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
175                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
176                  MutableArrayRef<scf::ForOp> loops) {
177   NewYieldValueFn yieldValueFn =
178       [&](OpBuilder &b, Location loc,
179           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
180     SmallVector<Value> inserts;
181     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
182       ArrayRef<OpFoldResult> tileOffsets =
183           tileOffsetsList[yieldedValue.index()];
184       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
185       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
186                                             b.getIndexAttr(1));
187       Value insert = b.create<tensor::InsertSliceOp>(
188           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
189           tileOffsets, tileSizes, tileStrides);
190       inserts.push_back(insert);
191     }
192     return inserts;
193   };
194 
195   SmallVector<scf::ForOp> newLoops =
196       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
197                                    /*replaceIterOperandsUsesInLoop =*/false);
198   for (const auto &loop : llvm::enumerate(loops)) {
199     rewriter.eraseOp(loop.value());
200     loops[loop.index()] = newLoops[loop.index()];
201   }
202   return llvm::to_vector(llvm::map_range(
203       loops.front().getResults().take_back(yieldedValues.size()),
204       [](OpResult r) -> Value { return r; }));
205 }
206 
207 /// If the tiled operation is destination passing style, update the
208 /// slice of the destination used (which refers to the untiled destination)
209 /// to use the corresponding region argument of the innermost loop.
210 ///
211 /// ```mlir
212 /// %0 =
213 /// scf.for %iv0 = ... iter_args(%arg = %0) {
214 ///   %1 = tensor.extract_slice %0
215 ///   %2 = tiled_op
216 ///   %3 = tensor.insert_slice %2 into %arg
217 ///   scf.yield %3
218 /// }
219 /// ```
220 ///
221 /// is transformed to
222 ///
223 /// ```mlir
224 /// scf.for %iv0 = ... iter_args(%arg = %0) {
225 ///   %1 = tensor.extract_slice %arg
226 ///   %2 = tiled_op
227 ///   %3 = tensor.insert_slice %2 into %arg
228 ///   scf.yield %3
229 /// }
230 /// ```
231 static void
232 updateDestinationOperandsForTiledOp(OpBuilder &builder,
233                                     ValueRange tiledOpDestinationValues,
234                                     ValueRange bbArgsList) {
235   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
236     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
237     if (!sliceOp)
238       continue;
239     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
240   }
241 }
242 
243 /// Helper method to yield the values of the tiled op, as well as
244 /// update the destination operands of the tiled op, if it is
245 /// a destination passing style op.
246 static SmallVector<Value>
247 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
248                  TilingResult tilingResult,
249                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
250                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
251                  MutableArrayRef<scf::ForOp> loops) {
252   SmallVector<Value> replacements =
253       yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
254                        tileOffsetsList, tileSizesList, loops);
255   for (auto tiledOp : tilingResult.tiledOps) {
256     if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
257       auto innerMostLoop = loops.back();
258       SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
259       updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
260                                           innerMostLoop.getRegionIterArgs());
261     }
262   }
263   return replacements;
264 }
265 
266 /// Implementation of tiling transformation of `op` that implements the
267 /// `TilingInterface` using `scf.for` to iterate over the tiles.
268 FailureOr<scf::SCFTilingResult>
269 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
270                              const scf::SCFTilingOptions &options) {
271   OpBuilder::InsertionGuard guard(rewriter);
272   rewriter.setInsertionPointAfter(op);
273 
274   if (!options.tileSizeComputationFunction) {
275     return rewriter.notifyMatchFailure(
276         op, "missing tile size computation function");
277   }
278 
279   // 1. Get the range of the loops that are represented by the operation.
280   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
281   size_t numLoops = iterationDomain.size();
282   if (numLoops == 0) {
283     return rewriter.notifyMatchFailure(
284         op, "unable to tile op with no iteration domain");
285   }
286 
287   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
288   // skips tiling a particular dimension. This convention is significantly
289   // simpler to handle instead of adjusting affine maps to account for missing
290   // dimensions.
291   SmallVector<OpFoldResult> tileSizeVector =
292       options.tileSizeComputationFunction(rewriter, op);
293   if (tileSizeVector.size() < iterationDomain.size()) {
294     auto zero = rewriter.getIndexAttr(0);
295     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
296   }
297 
298   scf::SCFTilingResult tilingResult;
299   SmallVector<OpFoldResult> offsets, sizes;
300   {
301     // If there is an interchange specified, permute the iteration domain and
302     // the tile sizes.
303     SmallVector<int64_t> interchangeVector;
304     if (!options.interchangeVector.empty()) {
305       interchangeVector = fillInterchangeVector(options.interchangeVector,
306                                                 iterationDomain.size());
307     }
308     if (!interchangeVector.empty()) {
309       if (!isPermutationVector(interchangeVector)) {
310         return rewriter.notifyMatchFailure(
311             op, "invalid intechange vector, not a permutation of the entire "
312                 "iteration space");
313       }
314 
315       applyPermutationToVector(iterationDomain, interchangeVector);
316       applyPermutationToVector(tileSizeVector, interchangeVector);
317     }
318 
319     // 3. Materialize an empty loop nest that iterates over the tiles. These
320     // loops for now do not return any values even if the original operation has
321     // results.
322     tilingResult.loops = generateTileLoopNest(
323         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
324 
325     if (!interchangeVector.empty()) {
326       auto inversePermutation = invertPermutationVector(interchangeVector);
327       applyPermutationToVector(offsets, inversePermutation);
328       applyPermutationToVector(sizes, inversePermutation);
329     }
330   }
331 
332   LLVM_DEBUG({
333     if (!tilingResult.loops.empty()) {
334       llvm::dbgs() << "LoopNest shell :\n";
335       tilingResult.loops.front().dump();
336       llvm::dbgs() << "\n";
337     }
338   });
339 
340   // 4. Generate the tiled implementation within the inner most loop.
341   if (!tilingResult.loops.empty())
342     rewriter.setInsertionPoint(
343         tilingResult.loops.back().getBody()->getTerminator());
344   FailureOr<TilingResult> tiledImplementation =
345       op.getTiledImplementation(rewriter, offsets, sizes);
346   tilingResult.tiledOps.append(tiledImplementation->tiledOps);
347   if (op->getNumResults() == 0) {
348     // nothing more to do.
349     return tilingResult;
350   }
351 
352   // If loops are empty, the tiled op is used as the replacement for the untiled
353   // op.
354   if (tilingResult.loops.empty()) {
355     tilingResult.replacements = tiledImplementation->tiledValues;
356     return tilingResult;
357   }
358 
359   // 5. Yield all the results of the tiled operation. The surrounding loop
360   //    nest is modified to insert a destructive update pattern to yield
361   //    from the loop nest values to replace the untiled op with.
362   int64_t numResults = op->getNumResults();
363   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
364       resultSizesList(numResults);
365   for (const auto &result : llvm::enumerate(op->getResults())) {
366     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
367                                         sizes,
368                                         resultOffsetsList[result.index()],
369                                         resultSizesList[result.index()]))) {
370       return rewriter.notifyMatchFailure(
371           op, "failed to get slice of result produced");
372     }
373   }
374 
375   SmallVector<Value> destinationTensors;
376   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
377                                              destinationTensors)))
378     return rewriter.notifyMatchFailure(op, "failed to get destinations");
379 
380   tilingResult.replacements = yieldTiledValues(
381       rewriter, destinationTensors, tiledImplementation.value(),
382       resultOffsetsList, resultSizesList, tilingResult.loops);
383 
384   LLVM_DEBUG({
385     if (!tilingResult.loops.empty()) {
386       llvm::dbgs() << "After tiled implementation :\n";
387       tilingResult.loops.front().dump();
388       llvm::dbgs() << "\n";
389     }
390   });
391   return tilingResult;
392 }
393 
394 FailureOr<scf::SCFReductionTilingResult>
395 mlir::scf::tileReductionUsingScf(RewriterBase &b,
396                                  PartialReductionOpInterface op,
397                                  ArrayRef<OpFoldResult> tileSizes) {
398   Location loc = op.getLoc();
399   // Ops implementing PartialReductionOpInterface are expected to implement
400   // TilingInterface.
401   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
402   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
403   auto tileSizesVector = llvm::to_vector(tileSizes);
404   if (tileSizesVector.size() < iterationDomain.size()) {
405     auto zero = b.getIndexAttr(0);
406     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
407                            zero);
408   }
409   if (op->getNumResults() != 1)
410     return b.notifyMatchFailure(
411         op, "don't support ops with multiple results for now");
412   SmallVector<utils::IteratorType> iterators =
413       tilingInterfaceOp.getLoopIteratorTypes();
414 
415   SmallVector<int> reductionDims;
416   for (auto [idx, iteratorType] :
417        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
418     if (iteratorType == utils::IteratorType::reduction)
419       reductionDims.push_back(idx);
420   }
421 
422   // 1. create the inital tensor value.
423   FailureOr<Operation *> identityTensor =
424       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
425                                                   reductionDims);
426   if (failed(identityTensor))
427     return b.notifyMatchFailure(op,
428                                 "cannot create a tensor of identity value.");
429   // 2. Create the nested loops.
430   SmallVector<OpFoldResult> offsets, sizes;
431   SmallVector<scf::ForOp> loops = generateTileLoopNest(
432       b, loc, iterationDomain, tileSizesVector, offsets, sizes);
433 
434   // 3. Generate the tiled implementation within the inner most loop.
435   b.setInsertionPoint(loops.back().getBody()->getTerminator());
436   Operation *parallelOp = op.tileToPartialReduction(
437       b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
438 
439   SmallVector<OpFoldResult> resultSizesList;
440   for (size_t i = 0; i < offsets.size(); i++)
441     resultSizesList.push_back(
442         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
443   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
444   SmallVector<Value> replacements = yieldTiledValues(
445       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
446       resultSizesList, loops);
447 
448   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
449   auto innerMostLoop = loops.back();
450   SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
451   assert(destinationTensors.size() ==
452              innerMostLoop.getRegionIterArgs().size() &&
453          "unexpected number of outputs");
454   updateDestinationOperandsForTiledOp(b, destinationTensors,
455                                       innerMostLoop.getRegionIterArgs());
456 
457   // 4. Apply the merge reduction to combine all the partial values.
458   b.setInsertionPointAfter(*loops.begin());
459   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
460   b.replaceOp(op, mergeOp->getResults());
461 
462   SCFReductionTilingResult results;
463   results.initialOp = *identityTensor;
464   results.loops = std::move(loops);
465   results.parallelTiledOp = parallelOp;
466   results.mergeOp = mergeOp;
467   return results;
468 }
469 //===----------------------------------------------------------------------===//
470 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
471 //===----------------------------------------------------------------------===//
472 
473 /// Return the untiled producer whose slice is used in a tiled consumer. The
474 /// method traverses the tile loop nest (`loops`) if needed, and returns the
475 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
476 /// indicates that this is a destination operand of the consumer. If there was
477 /// no loop traversal needed, the second value of the returned tuple is empty.
478 static std::tuple<OpResult, std::optional<OpOperand *>>
479 getUntiledProducerFromSliceSource(OpOperand *source,
480                                   ArrayRef<scf::ForOp> loops) {
481   std::optional<OpOperand *> destinationIterArg;
482   auto loopIt = loops.rbegin();
483   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
484     scf::ForOp loop = *loopIt;
485     if (iterArg.getOwner()->getParentOp() != loop)
486       break;
487     source = &loop.getOpOperandForRegionIterArg(iterArg);
488     loopIt++;
489   }
490   if (loopIt == loops.rend())
491     destinationIterArg = source;
492   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
493 }
494 
495 /// Implementation of fusing producer of a single slice by computing the
496 /// slice of the producer in-place.
497 std::optional<scf::SCFFuseProducerOfSliceResult>
498 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
499                                       tensor::ExtractSliceOp candidateSliceOp,
500                                       MutableArrayRef<scf::ForOp> loops) {
501   // 1. Get the producer of the source (potentially walking through
502   // `iter_args` of nested `scf.for`)
503   auto [fusableProducer, destinationInitArg] =
504       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
505                                         loops);
506   if (!fusableProducer)
507     return std::nullopt;
508 
509   // 2. Generate the tiled implementation of the producer of the source
510   OpBuilder::InsertionGuard g(rewriter);
511   rewriter.setInsertionPoint(candidateSliceOp);
512   FailureOr<TilingResult> tileAndFuseResult =
513       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
514                                                    fusableProducer);
515   if (failed(tileAndFuseResult))
516     return std::nullopt;
517   rewriter.replaceAllUsesWith(candidateSliceOp,
518                               tileAndFuseResult->tiledValues[0]);
519 
520   // 3. If the slice is for a destination operand, for example,
521   //
522   // ```mlir
523   // %0 = linalg.init
524   // %1 = linalg.fill .. outs(%0 : )
525   // %2 = scf.for .. iter_args(%arg0 = %1) {
526   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
527   //     %4 = tensor.extract_slice %arg1 [..]
528   //     .. = linalg.matmul .. outs(%4 : )
529   //   }
530   // }
531   // ```
532   //
533   // the IR is currently
534   //
535   // ```
536   // %0 = linalg.init
537   // %1 = linalg.fill
538   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
539   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
540   //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
541   //     %5 = linalg.fill .. outs(%4 : )
542   //     .. = linalg.matmul .. outs(%5 : )
543   //   }
544   // }
545   // ```
546   //
547   // The untiled `linalg.fill` is still used as the `init_value` since it
548   // was originally a destination operand of the untiled `linalg.matmul`.
549   // When fusing an operand that is a destination operand.
550   //   - Update the iter_arg of the outer most loop to use the destination
551   //     of the untiled producer.
552   //   - Update the destination of the slice of the tiled producer generated
553   //     to use the same basic block argument as the slice that was used to
554   //     generate inplace the tiled implementation of the producer.
555   // With this the IR will be.
556   //
557   // ```
558   // %0 = linalg.init
559   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
560   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
561   //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
562   //     %4 = linalg.fill .. outs(%3 : )
563   //     .. = linalg.matmul .. outs(%4 : )
564   //   }
565   // }
566   // ```
567   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
568   // Update to use that when it does become available.
569   scf::ForOp outerMostLoop = loops.front();
570   if (destinationInitArg &&
571       (*destinationInitArg)->getOwner() == outerMostLoop) {
572     unsigned iterArgNumber =
573         outerMostLoop.getResultForOpOperand(**destinationInitArg)
574             .getResultNumber();
575     int64_t resultNumber = fusableProducer.getResultNumber();
576     if (auto dstOp =
577             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
578       (*destinationInitArg)
579           ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
580     }
581     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
582       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
583       if (!dstOp)
584         continue;
585       scf::ForOp innerMostLoop = loops.back();
586       updateDestinationOperandsForTiledOp(
587           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
588           innerMostLoop.getRegionIterArgs()[iterArgNumber]);
589     }
590   }
591   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
592                                            tileAndFuseResult->tiledValues[0],
593                                            tileAndFuseResult->tiledOps};
594 }
595 
596 /// Reconstruct the fused producer from within the tiled-and-fused code.
597 void mlir::scf::yieldReplacementForFusedProducer(
598     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
599     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
600     MutableArrayRef<scf::ForOp> loops) {
601   auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
602       fusedProducerInfo;
603   SmallVector<Value> initValues;
604   FailureOr<Value> initValue = tensor::getOrCreateDestination(
605       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
606   if (succeeded(initValue)) {
607     SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
608     SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
609     SmallVector<Value> yieldedVals =
610         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
611                          resultOffsets, resultSizes, loops);
612   }
613   for (auto tileAndFusedOp : tileAndFusedOps) {
614     auto dstStyleProducer =
615         dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
616     if (!dstStyleProducer)
617       continue;
618     Value dstValue =
619         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
620             ->get();
621     updateDestinationOperandsForTiledOp(
622         rewriter, dstValue, loops.back().getRegionIterArgs().back());
623   }
624 }
625 
626 /// Implementation of tile consumer and fuse producer greedily.
627 FailureOr<scf::SCFTileAndFuseResult>
628 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
629     RewriterBase &rewriter, TilingInterface consumer,
630     const scf::SCFTileAndFuseOptions &options) {
631   // This transformation is only valid for ops that return values (i.e. not
632   // valid to use with operations that have memref operands).
633   if (!consumer->getNumResults()) {
634     return rewriter.notifyMatchFailure(
635         consumer, "invalid pattern for op with no results");
636   }
637 
638   // 1. First tile the consumer.
639   scf::SCFTileAndFuseResult tileAndFuseResult;
640   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
641   {
642     FailureOr<scf::SCFTilingResult> tilingResult =
643         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
644     if (failed(tilingResult))
645       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
646     for (auto *tiledOp : tilingResult->tiledOps)
647       tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
648     tileAndFuseResult.loops = std::move(tilingResult->loops);
649     for (const auto &result : llvm::enumerate(
650              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
651       tileAndFuseResult.replacements[std::get<0>(result.value())] =
652           std::get<1>(result.value());
653       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
654           result.index())] = result.index();
655     }
656   }
657 
658   // If there are no loops generated, fusion is immaterial.
659   if (tileAndFuseResult.loops.empty())
660     return tileAndFuseResult;
661 
662   // 2. Typically, the operands of the tiled operation are slices of the
663   //    operands of the untiled operation. These are expressed in IR using
664   //    `tensor.extract_slice` operations with source being the operands of the
665   //    untiled operation. Create a worklist of these `tensor.extract_slice`
666   //    operations. If the producers of the source of the `tensor.extract_slice`
667   //    can be tiled such that the tiled value is generated in-place, that
668   //    effectively tiles + fuses the operations.
669   auto addCandidateSlices = [](Operation *fusedOp,
670                                std::deque<tensor::ExtractSliceOp> &candidates) {
671     for (Value operand : fusedOp->getOperands())
672       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
673         candidates.push_back(sliceOp);
674   };
675 
676   std::deque<tensor::ExtractSliceOp> candidates;
677   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
678   OpBuilder::InsertionGuard g(rewriter);
679   while (!candidates.empty()) {
680     // Traverse the slices in BFS fashion.
681     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
682     candidates.pop_front();
683 
684     // The operands of the fused producer might themselved be slices of
685     // values produced by operations that implement the `TilingInterface`.
686     // Add these operations to the worklist.
687     std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
688         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
689                                    tileAndFuseResult.loops);
690     if (!fusedProducer)
691       continue;
692 
693     if (Operation *tiledAndFusedOp =
694             fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
695       tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
696       addCandidateSlices(tiledAndFusedOp, candidates);
697     }
698   }
699   return tileAndFuseResult;
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // lowerToLoopsUsingSCFForOp implementation.
704 //===----------------------------------------------------------------------===//
705 
706 FailureOr<SmallVector<scf::ForOp>>
707 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
708                                      TilingInterface op) {
709   // TODO: Handle cases where the op has results if needed.
710   if (op->getNumResults() > 0) {
711     return rewriter.notifyMatchFailure(
712         op, "unable to lower to loops operations with return values");
713   }
714 
715   SmallVector<Range> domain = op.getIterationDomain(rewriter);
716   SmallVector<Value> ivs;
717   SmallVector<scf::ForOp> loops;
718   Location loc = op.getLoc();
719   for (auto loopRange : domain) {
720     Value offsetVal =
721         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
722     Value sizeVal =
723         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
724     Value strideVal =
725         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
726     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
727                                             strideVal, ValueRange{});
728     loops.push_back(loop);
729     ivs.push_back(loop.getInductionVar());
730     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
731   }
732   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
733     return failure();
734   }
735   return loops;
736 }
737