xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 5c9013e2664141d3b4071c37b3cbb8cd09a008d8)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
33 scf::SCFTilingOptions &
34 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
35   assert(!tileSizeComputationFunction && "tile sizes already set");
36   SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
37   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38     OpBuilder::InsertionGuard guard(b);
39     b.setInsertionPointToStart(
40         &op->getParentOfType<func::FuncOp>().getBody().front());
41     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
42       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
43       return v;
44     }));
45   };
46   return *this;
47 }
48 
49 /// Helper method to adjust the interchange vector to match the iteration
50 /// domain.
51 static SmallVector<int64_t>
52 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
53                       size_t iterationDomainSize) {
54   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
55   if (filledVector.size() < iterationDomainSize) {
56     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
57     filledVector.append(range.begin(), range.end());
58   }
59   if (filledVector.size() > iterationDomainSize)
60     filledVector.resize(iterationDomainSize);
61   return filledVector;
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // tileUsingSCFForOp implementation.
66 //===----------------------------------------------------------------------===//
67 
68 // Check if `stride` evenly divides the trip count `size - offset`.
69 static bool tileDividesIterationDomain(Range loopRange) {
70   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
71   if (!offsetAsInt)
72     return false;
73   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
74   if (!sizeAsInt)
75     return false;
76   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
77   if (!strideAsInt)
78     return false;
79   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
80 }
81 
82 /// Returns the bounded tile size given the current `iv`, `loopRange` and
83 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
84 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
85                                        Range loopRange, Value iv,
86                                        Value tileSize) {
87   std::optional<int64_t> ts = getConstantIntValue(tileSize);
88   if (ts && ts.value() == 1)
89     return getAsOpFoldResult(tileSize);
90 
91   if (tileDividesIterationDomain(
92           Range{loopRange.offset, loopRange.size, tileSize}))
93     return tileSize;
94 
95   // The tile size to use (to avoid out of bounds access) is  minimum of
96   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
97   // loop.
98   AffineExpr s0, s1, d0;
99   bindDims(b.getContext(), d0);
100   bindSymbols(b.getContext(), s0, s1);
101   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
102   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
103   return makeComposedFoldedAffineMin(
104       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
105 }
106 
107 /// Generate an empty loop nest that represents the tiled loop nest shell.
108 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
109 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
110 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
111 /// the
112 ///   tile processed within the inner most loop.
113 static SmallVector<scf::ForOp>
114 generateTileLoopNest(OpBuilder &builder, Location loc,
115                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
116                      SmallVector<OpFoldResult> &offsets,
117                      SmallVector<OpFoldResult> &sizes) {
118   assert(!loopRanges.empty() && "expected at least one loop range");
119   assert(loopRanges.size() == tileSizeVals.size() &&
120          "expected as many tile sizes as loop ranges");
121   OpBuilder::InsertionGuard guard(builder);
122   SmallVector<scf::ForOp> loops;
123   offsets.resize(loopRanges.size());
124   sizes.resize(loopRanges.size());
125 
126   for (auto loopRange : llvm::enumerate(loopRanges)) {
127     Value offset =
128         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
129     Value size =
130         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
131     Value tileSize = tileSizeVals[loopRange.index()];
132     // No loops if tile size is zero. Set offset and size to the loop
133     // offset and size.
134     if (matchPattern(tileSize, m_Zero())) {
135       offsets[loopRange.index()] = offset;
136       sizes[loopRange.index()] = size;
137       continue;
138     }
139 
140     auto loop = builder.create<scf::ForOp>(
141         loc, offset, size, tileSize, ValueRange{},
142         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
143             ValueRange /*iterArgs*/) {
144           sizes[loopRange.index()] = getBoundedTileSize(
145               bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
146           builder.create<scf::YieldOp>(loc);
147         });
148     offsets[loopRange.index()] = loop.getInductionVar();
149     loops.push_back(loop);
150     builder.setInsertionPoint(loop.getBody()->getTerminator());
151   }
152   return loops;
153 }
154 
155 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
156 /// construct the destructive update pattern that inserts the yielded
157 /// value into a destination tensor provided by `initValue` at offset
158 /// `tileOffsets` and size `tileSizes`. For example,
159 ///
160 /// ```mlir
161 /// scf.for %iv0 = ... {
162 ///   %0 = tiled_op
163 /// }
164 /// ```
165 ///
166 /// is transformed to
167 ///
168 /// ```mlir
169 /// scf.for %iv0 = ... iter_args(%arg = %0) {
170 ///   %1 = tensor.extract_slice %arg
171 ///   %2 = tiled_op
172 ///   %3 = tensor.insert_slice %2 into %arg
173 ///   scf.yield %3
174 /// }
175 /// ```
176 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
177 static SmallVector<Value>
178 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
179                  ValueRange yieldedValues,
180                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
181                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
182                  MutableArrayRef<scf::ForOp> loops) {
183   NewYieldValueFn yieldValueFn =
184       [&](OpBuilder &b, Location loc,
185           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
186     SmallVector<Value> inserts;
187     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
188       ArrayRef<OpFoldResult> tileOffsets =
189           tileOffsetsList[yieldedValue.index()];
190       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
191       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
192                                             b.getIndexAttr(1));
193       Value insert = b.create<tensor::InsertSliceOp>(
194           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
195           tileOffsets, tileSizes, tileStrides);
196       inserts.push_back(insert);
197     }
198     return inserts;
199   };
200 
201   SmallVector<scf::ForOp> newLoops =
202       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
203                                    /*replaceIterOperandsUsesInLoop =*/false);
204   for (const auto &loop : llvm::enumerate(loops)) {
205     rewriter.eraseOp(loop.value());
206     loops[loop.index()] = newLoops[loop.index()];
207   }
208   return llvm::to_vector(llvm::map_range(
209       loops.front().getResults().take_back(yieldedValues.size()),
210       [](OpResult r) -> Value { return r; }));
211 }
212 
213 /// If the tiled operation is destination passing style, update the
214 /// slice of the destination used (which refers to the untiled destination)
215 /// to use the corresponding region argument of the innermost loop.
216 ///
217 /// ```mlir
218 /// %0 =
219 /// scf.for %iv0 = ... iter_args(%arg = %0) {
220 ///   %1 = tensor.extract_slice %0
221 ///   %2 = tiled_op
222 ///   %3 = tensor.insert_slice %2 into %arg
223 ///   scf.yield %3
224 /// }
225 /// ```
226 ///
227 /// is transformed to
228 ///
229 /// ```mlir
230 /// scf.for %iv0 = ... iter_args(%arg = %0) {
231 ///   %1 = tensor.extract_slice %arg
232 ///   %2 = tiled_op
233 ///   %3 = tensor.insert_slice %2 into %arg
234 ///   scf.yield %3
235 /// }
236 /// ```
237 static void
238 updateDestinationOperandsForTiledOp(OpBuilder &builder,
239                                     ValueRange tiledOpDestinationValues,
240                                     ValueRange bbArgsList) {
241   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
242     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
243     if (!sliceOp)
244       continue;
245     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
246   }
247 }
248 
249 /// Helper method to yield the values of the tiled op, as well as
250 /// update the destination operands of the tiled op, if it is
251 /// a destination passing style op.
252 static SmallVector<Value>
253 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
254                  Operation *tiledOp,
255                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
256                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
257                  MutableArrayRef<scf::ForOp> loops) {
258   SmallVector<Value> replacements =
259       yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
260                        tileOffsetsList, tileSizesList, loops);
261   if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
262     auto innerMostLoop = loops.back();
263     SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
264     updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
265                                         innerMostLoop.getRegionIterArgs());
266   }
267   return replacements;
268 }
269 
270 /// Implementation of tiling transformation of `op` that implements the
271 /// `TilingInterface` using `scf.for` to iterate over the tiles.
272 FailureOr<scf::SCFTilingResult>
273 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
274                              const scf::SCFTilingOptions &options) {
275   OpBuilder::InsertionGuard guard(rewriter);
276   rewriter.setInsertionPointAfter(op);
277 
278   if (!options.tileSizeComputationFunction) {
279     return rewriter.notifyMatchFailure(
280         op, "missing tile size computation function");
281   }
282 
283   // 1. Get the range of the loops that are represented by the operation.
284   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
285   size_t numLoops = iterationDomain.size();
286   if (numLoops == 0) {
287     return rewriter.notifyMatchFailure(
288         op, "unable to tile op with no iteration domain");
289   }
290 
291   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
292   // skips tiling a particular dimension. This convention is significantly
293   // simpler to handle instead of adjusting affine maps to account for missing
294   // dimensions.
295   SmallVector<Value> tileSizeVector =
296       options.tileSizeComputationFunction(rewriter, op);
297   if (tileSizeVector.size() < iterationDomain.size()) {
298     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
299     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
300   }
301 
302   scf::SCFTilingResult tilingResult;
303   SmallVector<OpFoldResult> offsets, sizes;
304   {
305     // If there is an interchange specified, permute the iteration domain and
306     // the tile sizes.
307     SmallVector<int64_t> interchangeVector;
308     if (!options.interchangeVector.empty()) {
309       interchangeVector = fillInterchangeVector(options.interchangeVector,
310                                                 iterationDomain.size());
311     }
312     if (!interchangeVector.empty()) {
313       if (!isPermutationVector(interchangeVector)) {
314         return rewriter.notifyMatchFailure(
315             op, "invalid intechange vector, not a permutation of the entire "
316                 "iteration space");
317       }
318 
319       applyPermutationToVector(iterationDomain, interchangeVector);
320       applyPermutationToVector(tileSizeVector, interchangeVector);
321     }
322 
323     // 3. Materialize an empty loop nest that iterates over the tiles. These
324     // loops for now do not return any values even if the original operation has
325     // results.
326     tilingResult.loops = generateTileLoopNest(
327         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
328 
329     if (!interchangeVector.empty()) {
330       auto inversePermutation = invertPermutationVector(interchangeVector);
331       applyPermutationToVector(offsets, inversePermutation);
332       applyPermutationToVector(sizes, inversePermutation);
333     }
334   }
335 
336   LLVM_DEBUG({
337     if (!tilingResult.loops.empty()) {
338       llvm::dbgs() << "LoopNest shell :\n";
339       tilingResult.loops.front().dump();
340       llvm::dbgs() << "\n";
341     }
342   });
343 
344   // 4. Generate the tiled implementation within the inner most loop.
345   if (!tilingResult.loops.empty())
346     rewriter.setInsertionPoint(
347         tilingResult.loops.back().getBody()->getTerminator());
348   SmallVector<Operation *> tiledImplementation =
349       op.getTiledImplementation(rewriter, offsets, sizes);
350   tilingResult.tiledOps.append(tiledImplementation);
351   if (op->getNumResults() == 0) {
352     // nothing more to do.
353     return tilingResult;
354   }
355 
356   // If loops are empty, the tiled op is used as the replacement for the untiled
357   // op.
358   if (tilingResult.loops.empty()) {
359     tilingResult.replacements = llvm::to_vector(
360         llvm::map_range(tiledImplementation[0]->getResults(),
361                         [](OpResult result) -> Value { return result; }));
362     return tilingResult;
363   }
364 
365   // 5. Yield all the results of the tiled operation. The surrounding loop
366   //    nest is modified to insert a destructive update pattern to yield
367   //    from the loop nest values to replace the untiled op with.
368   int64_t numResults = op->getNumResults();
369   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
370       resultSizesList(numResults);
371   for (const auto &result : llvm::enumerate(op->getResults())) {
372     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
373                                         sizes,
374                                         resultOffsetsList[result.index()],
375                                         resultSizesList[result.index()]))) {
376       return rewriter.notifyMatchFailure(
377           op, "failed to get slice of result produced");
378     }
379   }
380 
381   SmallVector<Value> destinationTensors;
382   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
383                                              destinationTensors)))
384     return rewriter.notifyMatchFailure(op, "failed to get destinations");
385 
386   tilingResult.replacements = yieldTiledValues(
387       rewriter, destinationTensors, tilingResult.tiledOps.back(),
388       resultOffsetsList, resultSizesList, tilingResult.loops);
389 
390   LLVM_DEBUG({
391     if (!tilingResult.loops.empty()) {
392       llvm::dbgs() << "After tiled implementation :\n";
393       tilingResult.loops.front().dump();
394       llvm::dbgs() << "\n";
395     }
396   });
397   return tilingResult;
398 }
399 
400 FailureOr<scf::SCFReductionTilingResult>
401 mlir::scf::tileReductionUsingScf(PatternRewriter &b,
402                                  PartialReductionOpInterface op,
403                                  ArrayRef<OpFoldResult> tileSize) {
404   Location loc = op.getLoc();
405   // Ops implementing PartialReductionOpInterface are expected to implement
406   // TilingInterface.
407   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
408   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
409   SmallVector<Value> tileSizeVector =
410       getValueOrCreateConstantIndexOp(b, loc, tileSize);
411   if (tileSizeVector.size() < iterationDomain.size()) {
412     auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
413     tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
414   }
415   if (op->getNumResults() != 1)
416     return b.notifyMatchFailure(
417         op, "don't support ops with multiple results for now");
418   SmallVector<utils::IteratorType> iterators =
419       tilingInterfaceOp.getLoopIteratorTypes();
420   int64_t numReductionDims = llvm::count(
421       tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
422   if (numReductionDims != 1)
423     return b.notifyMatchFailure(
424         op, "only support ops with one reduction dimension.");
425   int reductionDim;
426   for (auto &[idx, iteratorType] :
427        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
428     if (iteratorType == utils::IteratorType::reduction) {
429       reductionDim = idx;
430       break;
431     }
432   }
433   if (static_cast<size_t>(reductionDim) >= tileSize.size())
434     return b.notifyMatchFailure(op, "reduction dimension must be tiled");
435 
436   // 1. create the inital tensor value.
437   FailureOr<Operation *> identityTensor =
438       op.generateInitialTensorForPartialReduction(b, loc, tileSize,
439                                                   reductionDim);
440   if (failed(identityTensor))
441     return b.notifyMatchFailure(op,
442                                 "cannot create a tensor of identity value.");
443   // 2. Create the nested loops.
444   SmallVector<OpFoldResult> offsets, sizes;
445   SmallVector<scf::ForOp> loops = generateTileLoopNest(
446       b, loc, iterationDomain, tileSizeVector, offsets, sizes);
447 
448   // 3. Generate the tiled implementation within the inner most loop.
449   b.setInsertionPoint(loops.back().getBody()->getTerminator());
450   Operation *parallelOp = op.tileToPartialReduction(
451       b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
452 
453   SmallVector<OpFoldResult> resultSizesList;
454   for (size_t i = 0; i < offsets.size(); i++)
455     resultSizesList.push_back(
456         b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
457   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
458   SmallVector<Value> replacements = yieldTiledValues(
459       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
460       resultSizesList, loops);
461 
462   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
463   auto innerMostLoop = loops.back();
464   SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
465   assert(destinationTensors.size() ==
466              innerMostLoop.getRegionIterArgs().size() &&
467          "unexpected number of outputs");
468   updateDestinationOperandsForTiledOp(b, destinationTensors,
469                                       innerMostLoop.getRegionIterArgs());
470 
471   // 4. Apply the merge reduction to combine all the partial values.
472   b.setInsertionPointAfter(*loops.begin());
473   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
474   b.replaceOp(op, mergeOp->getResults());
475 
476   SCFReductionTilingResult results;
477   results.initialOp = *identityTensor;
478   results.loops = std::move(loops);
479   results.parallelTiledOp = parallelOp;
480   results.mergeOp = mergeOp;
481   return results;
482 }
483 //===----------------------------------------------------------------------===//
484 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
485 //===----------------------------------------------------------------------===//
486 
487 /// Return the untiled producer whose slice is used in a tiled consumer. The
488 /// method traverses the tile loop nest (`loops`) if needed, and returns the
489 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
490 /// indicates that this is a destination operand of the consumer. If there was
491 /// no loop traversal needed, the second value of the returned tuple is empty.
492 static std::tuple<OpResult, std::optional<OpOperand *>>
493 getUntiledProducerFromSliceSource(OpOperand *source,
494                                   ArrayRef<scf::ForOp> loops) {
495   std::optional<OpOperand *> destinationIterArg;
496   auto loopIt = loops.rbegin();
497   while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
498     scf::ForOp loop = *loopIt;
499     if (iterArg.getOwner()->getParentOp() != loop)
500       break;
501     source = &loop.getOpOperandForRegionIterArg(iterArg);
502     loopIt++;
503   }
504   if (loopIt == loops.rend())
505     destinationIterArg = source;
506   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
507 }
508 
509 /// Implementation of fusing producer of a single slice by computing the
510 /// slice of the producer in-place.
511 std::optional<scf::SCFFuseProducerOfSliceResult>
512 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
513                                       tensor::ExtractSliceOp candidateSliceOp,
514                                       MutableArrayRef<scf::ForOp> loops) {
515   // 1. Get the producer of the source (potentially walking through
516   // `iter_args` of nested `scf.for`)
517   auto [fusableProducer, destinationIterArg] =
518       getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
519                                         loops);
520   if (!fusableProducer)
521     return std::nullopt;
522 
523   // 2. Generate the tiled implementation of the producer of the source
524   OpBuilder::InsertionGuard g(rewriter);
525   rewriter.setInsertionPoint(candidateSliceOp);
526   FailureOr<Value> fusedProducerValue =
527       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
528                                                    fusableProducer);
529   if (failed(fusedProducerValue))
530     return std::nullopt;
531   rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value());
532 
533   // 3. If the slice is for a destination operand, for example,
534   //
535   // ```mlir
536   // %0 = linalg.init
537   // %1 = linalg.fill .. outs(%0 : )
538   // %2 = scf.for .. iter_args(%arg0 = %1) {
539   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
540   //     %4 = tensor.extract_slice %arg1 [..]
541   //     .. = linalg.matmul .. outs(%4 : )
542   //   }
543   // }
544   // ```
545   //
546   // the IR is currently
547   //
548   // ```
549   // %0 = linalg.init
550   // %1 = linalg.fill
551   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
552   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
553   //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
554   //     %5 = linalg.fill .. outs(%4 : )
555   //     .. = linalg.matmul .. outs(%5 : )
556   //   }
557   // }
558   // ```
559   //
560   // The untiled `linalg.fill` is still used as the `init_value` since it
561   // was originally a destination operand of the untiled `linalg.matmul`.
562   // When fusing an operand that is a destination operand.
563   //   - Update the iter_arg of the outer most loop to use the destination
564   //     of the untiled producer.
565   //   - Update the destination of the slice of the tiled producer generated
566   //     to use the same basic block argument as the slice that was used to
567   //     generate inplace the tiled implementation of the producer.
568   // With this the IR will be.
569   //
570   // ```
571   // %0 = linalg.init
572   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
573   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
574   //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
575   //     %4 = linalg.fill .. outs(%3 : )
576   //     .. = linalg.matmul .. outs(%4 : )
577   //   }
578   // }
579   // ```
580   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
581   // Update to use that when it does become available.
582   scf::ForOp outerMostLoop = loops.front();
583   std::optional<unsigned> iterArgNumber;
584   if (destinationIterArg) {
585     iterArgNumber =
586         outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
587   }
588   if (iterArgNumber) {
589     int64_t resultNumber = fusableProducer.getResultNumber();
590     if (auto dstOp =
591             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
592       outerMostLoop.setIterArg(iterArgNumber.value(),
593                                dstOp.getTiedOpOperand(fusableProducer)->get());
594     }
595     if (auto dstOp = fusedProducerValue.value()
596                          .getDefiningOp<DestinationStyleOpInterface>()) {
597       scf::ForOp innerMostLoop = loops.back();
598       updateDestinationOperandsForTiledOp(
599           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
600           innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
601     }
602   }
603   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
604                                            fusedProducerValue.value()};
605 }
606 
607 /// Reconstruct the fused producer from within the tiled-and-fused code.
608 void mlir::scf::yieldReplacementForFusedProducer(
609     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
610     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
611     MutableArrayRef<scf::ForOp> loops) {
612   auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
613   SmallVector<Value> initValues;
614   FailureOr<Value> initValue = tensor::getOrCreateDestination(
615       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
616   if (succeeded(initValue)) {
617     SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
618     SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
619     SmallVector<Value> yieldedVals =
620         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
621                          resultOffsets, resultSizes, loops);
622   }
623   if (auto dstStyleProducer =
624           fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
625     Value dstValue =
626         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
627             ->get();
628     updateDestinationOperandsForTiledOp(
629         rewriter, dstValue, loops.back().getRegionIterArgs().back());
630   }
631 }
632 
633 /// Implementation of tile consumer and fuse producer greedily.
634 FailureOr<scf::SCFTileAndFuseResult>
635 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
636     RewriterBase &rewriter, TilingInterface consumer,
637     const scf::SCFTileAndFuseOptions &options) {
638   // This transformation is only valid for ops that return values (i.e. not
639   // valid to use with operations that have memref operands).
640   if (!consumer->getNumResults()) {
641     return rewriter.notifyMatchFailure(
642         consumer, "invalid pattern for op with no results");
643   }
644 
645   // 1. First tile the consumer.
646   scf::SCFTileAndFuseResult tileAndFuseResult;
647   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
648   {
649     FailureOr<scf::SCFTilingResult> tilingResult =
650         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
651     if (failed(tilingResult))
652       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
653     for (auto *tiledOp : tilingResult->tiledOps)
654       tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
655     tileAndFuseResult.loops = std::move(tilingResult->loops);
656     for (const auto &result : llvm::enumerate(
657              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
658       tileAndFuseResult.replacements[std::get<0>(result.value())] =
659           std::get<1>(result.value());
660       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
661           result.index())] = result.index();
662     }
663   }
664 
665   // If there are no loops generated, fusion is immaterial.
666   if (tileAndFuseResult.loops.empty())
667     return tileAndFuseResult;
668 
669   // 2. Typically, the operands of the tiled operation are slices of the
670   //    operands of the untiled operation. These are expressed in IR using
671   //    `tensor.extract_slice` operations with source being the operands of the
672   //    untiled operation. Create a worklist of these `tensor.extract_slice`
673   //    operations. If the producers of the source of the `tensor.extract_slice`
674   //    can be tiled such that the tiled value is generated in-place, that
675   //    effectively tiles + fuses the operations.
676   auto addCandidateSlices = [](Operation *fusedOp,
677                                std::deque<tensor::ExtractSliceOp> &candidates) {
678     for (Value operand : fusedOp->getOperands())
679       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
680         candidates.push_back(sliceOp);
681   };
682 
683   std::deque<tensor::ExtractSliceOp> candidates;
684   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
685   OpBuilder::InsertionGuard g(rewriter);
686   while (!candidates.empty()) {
687     // Traverse the slices in BFS fashion.
688     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
689     candidates.pop_front();
690 
691     // The operands of the fused producer might themselved be slices of
692     // values produced by operations that implement the `TilingInterface`.
693     // Add these operations to the worklist.
694     std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
695         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
696                                    tileAndFuseResult.loops);
697     if (!fusedProducer)
698       continue;
699 
700     if (Operation *tiledAndFusedOp =
701             fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
702       tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
703       addCandidateSlices(tiledAndFusedOp, candidates);
704     }
705   }
706   return tileAndFuseResult;
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // lowerToLoopsUsingSCFForOp implementation.
711 //===----------------------------------------------------------------------===//
712 
713 FailureOr<SmallVector<scf::ForOp>>
714 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
715                                      TilingInterface op) {
716   // TODO: Handle cases where the op has results if needed.
717   if (op->getNumResults() > 0) {
718     return rewriter.notifyMatchFailure(
719         op, "unable to lower to loops operations with return values");
720   }
721 
722   SmallVector<Range> domain = op.getIterationDomain(rewriter);
723   SmallVector<Value> ivs;
724   SmallVector<scf::ForOp> loops;
725   Location loc = op.getLoc();
726   for (auto loopRange : domain) {
727     Value offsetVal =
728         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
729     Value sizeVal =
730         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
731     Value strideVal =
732         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
733     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
734                                             strideVal, ValueRange{});
735     loops.push_back(loop);
736     ivs.push_back(loop.getInductionVar());
737     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
738   }
739   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
740     return failure();
741   }
742   return loops;
743 }
744