xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 93c42299bdb1ef094857ef2d065670af0695c26b)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
33 scf::SCFTilingOptions &
34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
35   assert(!tileSizeComputationFunction && "tile sizes already set");
36   auto tileSizes = llvm::to_vector(ts);
37   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38     return tileSizes;
39   };
40   return *this;
41 }
42 
43 /// Helper method to adjust the interchange vector to match the iteration
44 /// domain.
45 static SmallVector<int64_t>
46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
47                       size_t iterationDomainSize) {
48   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
49   if (filledVector.size() < iterationDomainSize) {
50     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
51     filledVector.append(range.begin(), range.end());
52   }
53   if (filledVector.size() > iterationDomainSize)
54     filledVector.resize(iterationDomainSize);
55   return filledVector;
56 }
57 
58 /// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59 template <typename SrcOpTy>
60 static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
61   return llvm::to_vector(
62       llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
63 }
64 template <typename SrcOpTy>
65 static SmallVector<Operation *>
66 getAsOperations(const SmallVector<SrcOpTy> &ops) {
67   return getAsOperations(ArrayRef<SrcOpTy>(ops));
68 }
69 
70 /// Convert a list of `Operation *` to a list of `DstOpTy.
71 template <typename DstOpTy>
72 static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
73   return llvm::to_vector(
74       llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75 }
76 template <typename DstOpTy>
77 static SmallVector<DstOpTy>
78 castToTypedOperations(const SmallVector<Operation *> &ops) {
79   return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // tileUsingSCFForOp implementation.
84 //===----------------------------------------------------------------------===//
85 
86 // Check if `stride` evenly divides the trip count `size - offset`.
87 static bool tileDividesIterationDomain(Range loopRange) {
88   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
89   if (!offsetAsInt)
90     return false;
91   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
92   if (!sizeAsInt)
93     return false;
94   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
95   if (!strideAsInt)
96     return false;
97   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
98 }
99 
100 /// Returns the bounded tile size given the current `iv`, `loopRange` and
101 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
102 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
103                                        Range loopRange, Value iv,
104                                        OpFoldResult tileSize) {
105   if (isConstantIntValue(tileSize, 1))
106     return tileSize;
107 
108   if (tileDividesIterationDomain(
109           Range{loopRange.offset, loopRange.size, tileSize}))
110     return tileSize;
111 
112   // The tile size to use (to avoid out of bounds access) is  minimum of
113   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
114   // loop.
115   AffineExpr s0, s1, d0;
116   bindDims(b.getContext(), d0);
117   bindSymbols(b.getContext(), s0, s1);
118   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
119   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
120   return affine::makeComposedFoldedAffineMin(
121       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
122 }
123 
124 /// Generate an empty loop nest that represents the tiled loop nest shell.
125 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
126 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
127 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
128 /// the
129 ///   tile processed within the inner most loop.
130 static SmallVector<scf::ForOp> generateTileLoopNest(
131     OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
132     ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
133     SmallVector<OpFoldResult> &sizes) {
134   assert(!loopRanges.empty() && "expected at least one loop range");
135   assert(loopRanges.size() == tileSizes.size() &&
136          "expected as many tile sizes as loop ranges");
137   OpBuilder::InsertionGuard guard(builder);
138   SmallVector<scf::ForOp> loops;
139   offsets.resize(loopRanges.size());
140   sizes.resize(loopRanges.size());
141 
142   for (auto loopRange : llvm::enumerate(loopRanges)) {
143     Value offset =
144         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
145     Value size =
146         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
147     Value tileSize = getValueOrCreateConstantIndexOp(
148         builder, loc, tileSizes[loopRange.index()]);
149     // No loops if tile size is zero. Set offset and size to the loop
150     // offset and size.
151     if (matchPattern(tileSize, m_Zero())) {
152       offsets[loopRange.index()] = offset;
153       sizes[loopRange.index()] = size;
154       continue;
155     }
156 
157     auto loop = builder.create<scf::ForOp>(
158         loc, offset, size, tileSize, ValueRange{},
159         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
160             ValueRange /*iterArgs*/) {
161           sizes[loopRange.index()] = getBoundedTileSize(
162               bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
163           builder.create<scf::YieldOp>(loc);
164         });
165     offsets[loopRange.index()] = loop.getInductionVar();
166     loops.push_back(loop);
167     builder.setInsertionPoint(loop.getBody()->getTerminator());
168   }
169   return loops;
170 }
171 
172 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
173 /// construct the destructive update pattern that inserts the yielded
174 /// value into a destination tensor provided by `initValue` at offset
175 /// `tileOffsets` and size `tileSizes`. For example,
176 ///
177 /// ```mlir
178 /// scf.for %iv0 = ... {
179 ///   %0 = tiled_op
180 /// }
181 /// ```
182 ///
183 /// is transformed to
184 ///
185 /// ```mlir
186 /// scf.for %iv0 = ... iter_args(%arg = %0) {
187 ///   %1 = tensor.extract_slice %arg
188 ///   %2 = tiled_op
189 ///   %3 = tensor.insert_slice %2 into %arg
190 ///   scf.yield %3
191 /// }
192 /// ```
193 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
194 static SmallVector<Value>
195 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
196                  ValueRange yieldedValues,
197                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
198                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
199                  MutableArrayRef<scf::ForOp> loops) {
200   NewYieldValueFn yieldValueFn =
201       [&](OpBuilder &b, Location loc,
202           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
203     SmallVector<Value> inserts;
204     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
205       ArrayRef<OpFoldResult> tileOffsets =
206           tileOffsetsList[yieldedValue.index()];
207       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
208       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
209                                             b.getIndexAttr(1));
210       Value insert = b.create<tensor::InsertSliceOp>(
211           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
212           tileOffsets, tileSizes, tileStrides);
213       inserts.push_back(insert);
214     }
215     return inserts;
216   };
217 
218   SmallVector<scf::ForOp> newLoops =
219       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
220                                    /*replaceIterOperandsUsesInLoop =*/false);
221   for (const auto &loop : llvm::enumerate(loops)) {
222     rewriter.eraseOp(loop.value());
223     loops[loop.index()] = newLoops[loop.index()];
224   }
225   return llvm::to_vector(llvm::map_range(
226       loops.front().getResults().take_back(yieldedValues.size()),
227       [](OpResult r) -> Value { return r; }));
228 }
229 
230 /// If the tiled operation is destination passing style, update the
231 /// slice of the destination used (which refers to the untiled destination)
232 /// to use the corresponding region argument of the innermost loop.
233 ///
234 /// ```mlir
235 /// %0 =
236 /// scf.for %iv0 = ... iter_args(%arg = %0) {
237 ///   %1 = tensor.extract_slice %0
238 ///   %2 = tiled_op
239 ///   %3 = tensor.insert_slice %2 into %arg
240 ///   scf.yield %3
241 /// }
242 /// ```
243 ///
244 /// is transformed to
245 ///
246 /// ```mlir
247 /// scf.for %iv0 = ... iter_args(%arg = %0) {
248 ///   %1 = tensor.extract_slice %arg
249 ///   %2 = tiled_op
250 ///   %3 = tensor.insert_slice %2 into %arg
251 ///   scf.yield %3
252 /// }
253 /// ```
254 static void
255 updateDestinationOperandsForTiledOp(OpBuilder &builder,
256                                     ValueRange tiledOpDestinationValues,
257                                     ValueRange bbArgsList) {
258   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
259     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
260     if (!sliceOp)
261       continue;
262     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
263   }
264 }
265 
266 /// Helper method to yield the values of the tiled op, as well as
267 /// update the destination operands of the tiled op, if it is
268 /// a destination passing style op.
269 static SmallVector<Value>
270 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
271                  TilingResult tilingResult,
272                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
273                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
274                  MutableArrayRef<scf::ForOp> loops) {
275   SmallVector<Value> replacements =
276       yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
277                        tileOffsetsList, tileSizesList, loops);
278   for (auto tiledOp : tilingResult.tiledOps) {
279     if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
280       auto innerMostLoop = loops.back();
281       SmallVector<Value> tiledOpDestinationTensors =
282           llvm::to_vector(dstOp.getDpsInits());
283       updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
284                                           innerMostLoop.getRegionIterArgs());
285     }
286   }
287   return replacements;
288 }
289 
290 /// Implementation of tiling transformation of `op` that implements the
291 /// `TilingInterface` using `scf.for` to iterate over the tiles.
292 FailureOr<scf::SCFTilingResult>
293 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
294                              const scf::SCFTilingOptions &options) {
295   OpBuilder::InsertionGuard guard(rewriter);
296   rewriter.setInsertionPointAfter(op);
297 
298   if (!options.tileSizeComputationFunction) {
299     return rewriter.notifyMatchFailure(
300         op, "missing tile size computation function");
301   }
302 
303   // 1. Get the range of the loops that are represented by the operation.
304   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
305   size_t numLoops = iterationDomain.size();
306   if (numLoops == 0) {
307     return rewriter.notifyMatchFailure(
308         op, "unable to tile op with no iteration domain");
309   }
310 
311   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
312   // skips tiling a particular dimension. This convention is significantly
313   // simpler to handle instead of adjusting affine maps to account for missing
314   // dimensions.
315   SmallVector<OpFoldResult> tileSizeVector =
316       options.tileSizeComputationFunction(rewriter, op);
317   if (tileSizeVector.size() < iterationDomain.size()) {
318     auto zero = rewriter.getIndexAttr(0);
319     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
320   }
321 
322   SmallVector<OpFoldResult> offsets, sizes;
323   SmallVector<scf::ForOp> forLoops;
324   {
325     // If there is an interchange specified, permute the iteration domain and
326     // the tile sizes.
327     SmallVector<int64_t> interchangeVector;
328     if (!options.interchangeVector.empty()) {
329       interchangeVector = fillInterchangeVector(options.interchangeVector,
330                                                 iterationDomain.size());
331     }
332     if (!interchangeVector.empty()) {
333       if (!isPermutationVector(interchangeVector)) {
334         return rewriter.notifyMatchFailure(
335             op, "invalid intechange vector, not a permutation of the entire "
336                 "iteration space");
337       }
338 
339       applyPermutationToVector(iterationDomain, interchangeVector);
340       applyPermutationToVector(tileSizeVector, interchangeVector);
341     }
342 
343     // 3. Materialize an empty loop nest that iterates over the tiles. These
344     // loops for now do not return any values even if the original operation has
345     // results.
346     forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
347                                     tileSizeVector, offsets, sizes);
348 
349     if (!interchangeVector.empty()) {
350       auto inversePermutation = invertPermutationVector(interchangeVector);
351       applyPermutationToVector(offsets, inversePermutation);
352       applyPermutationToVector(sizes, inversePermutation);
353     }
354   }
355 
356   LLVM_DEBUG({
357     if (!forLoops.empty()) {
358       llvm::dbgs() << "LoopNest shell :\n";
359       forLoops.front().dump();
360       llvm::dbgs() << "\n";
361     }
362   });
363 
364   // 4. Generate the tiled implementation within the inner most loop.
365   if (!forLoops.empty())
366     rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
367   FailureOr<TilingResult> tiledImplementation =
368       op.getTiledImplementation(rewriter, offsets, sizes);
369 
370   if (op->getNumResults() == 0) {
371     return scf::SCFTilingResult{
372         tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
373   }
374 
375   // If loops are empty, the tiled op is used as the replacement for the untiled
376   // op.
377   if (forLoops.empty()) {
378     return scf::SCFTilingResult{tiledImplementation->tiledOps,
379                                 getAsOperations(forLoops),
380                                 tiledImplementation->tiledValues};
381   }
382 
383   // 5. Yield all the results of the tiled operation. The surrounding loop
384   //    nest is modified to insert a destructive update pattern to yield
385   //    from the loop nest values to replace the untiled op with.
386   int64_t numResults = op->getNumResults();
387   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
388       resultSizesList(numResults);
389   for (const auto &result : llvm::enumerate(op->getResults())) {
390     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
391                                         sizes,
392                                         resultOffsetsList[result.index()],
393                                         resultSizesList[result.index()]))) {
394       return rewriter.notifyMatchFailure(
395           op, "failed to get slice of result produced");
396     }
397   }
398 
399   SmallVector<Value> destinationTensors;
400   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
401                                              destinationTensors)))
402     return rewriter.notifyMatchFailure(op, "failed to get destinations");
403 
404   SmallVector<Value> replacements = yieldTiledValues(
405       rewriter, destinationTensors, tiledImplementation.value(),
406       resultOffsetsList, resultSizesList, forLoops);
407   LLVM_DEBUG({
408     if (!forLoops.empty()) {
409       llvm::dbgs() << "After tiled implementation :\n";
410       forLoops.front().dump();
411       llvm::dbgs() << "\n";
412     }
413   });
414   return scf::SCFTilingResult{tiledImplementation->tiledOps,
415                               getAsOperations(forLoops), replacements};
416 }
417 
418 FailureOr<scf::SCFReductionTilingResult>
419 mlir::scf::tileReductionUsingScf(RewriterBase &b,
420                                  PartialReductionOpInterface op,
421                                  ArrayRef<OpFoldResult> tileSizes) {
422   Location loc = op.getLoc();
423   // Ops implementing PartialReductionOpInterface are expected to implement
424   // TilingInterface.
425   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
426   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
427   auto tileSizesVector = llvm::to_vector(tileSizes);
428   if (tileSizesVector.size() < iterationDomain.size()) {
429     auto zero = b.getIndexAttr(0);
430     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
431                            zero);
432   }
433   if (op->getNumResults() != 1)
434     return b.notifyMatchFailure(
435         op, "don't support ops with multiple results for now");
436   SmallVector<utils::IteratorType> iterators =
437       tilingInterfaceOp.getLoopIteratorTypes();
438 
439   SmallVector<int> reductionDims;
440   for (auto [idx, iteratorType] :
441        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
442     if (iteratorType == utils::IteratorType::reduction)
443       reductionDims.push_back(idx);
444   }
445 
446   // 1. create the inital tensor value.
447   FailureOr<Operation *> identityTensor =
448       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
449                                                   reductionDims);
450   if (failed(identityTensor))
451     return b.notifyMatchFailure(op,
452                                 "cannot create a tensor of identity value.");
453   // 2. Create the nested loops.
454   SmallVector<OpFoldResult> offsets, sizes;
455   SmallVector<scf::ForOp> loops = generateTileLoopNest(
456       b, loc, iterationDomain, tileSizesVector, offsets, sizes);
457 
458   // 3. Generate the tiled implementation within the inner most loop.
459   b.setInsertionPoint(loops.back().getBody()->getTerminator());
460   Operation *parallelOp = op.tileToPartialReduction(
461       b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
462 
463   SmallVector<OpFoldResult> resultSizesList;
464   for (size_t i = 0; i < offsets.size(); i++)
465     resultSizesList.push_back(
466         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
467   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
468   SmallVector<Value> replacements = yieldTiledValues(
469       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
470       resultSizesList, loops);
471 
472   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
473   auto innerMostLoop = loops.back();
474   SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
475   assert(destinationTensors.size() ==
476              innerMostLoop.getRegionIterArgs().size() &&
477          "unexpected number of outputs");
478   updateDestinationOperandsForTiledOp(b, destinationTensors,
479                                       innerMostLoop.getRegionIterArgs());
480 
481   // 4. Apply the merge reduction to combine all the partial values.
482   b.setInsertionPointAfter(*loops.begin());
483   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
484   b.replaceOp(op, mergeOp->getResults());
485 
486   SCFReductionTilingResult results;
487   results.initialOp = *identityTensor;
488   results.loops = std::move(loops);
489   results.parallelTiledOp = parallelOp;
490   results.mergeOp = mergeOp;
491   return results;
492 }
493 
494 //===----------------------------------------------------------------------===//
495 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
496 //===----------------------------------------------------------------------===//
497 
498 /// Return the untiled producer whose slice is used in a tiled consumer. The
499 /// method traverses the tile loop nest (`loops`) if needed, and returns the
500 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
501 /// indicates that this is a destination operand of the consumer. If there was
502 /// no loop traversal needed, the second value of the returned tuple is empty.
503 static std::tuple<OpResult, std::optional<OpOperand *>>
504 getUntiledProducerFromSliceSource(OpOperand *source,
505                                   ArrayRef<scf::ForOp> loops) {
506   std::optional<OpOperand *> destinationIterArg;
507   auto loopIt = loops.rbegin();
508   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
509     scf::ForOp loop = *loopIt;
510     if (iterArg.getOwner()->getParentOp() != loop)
511       break;
512     source = &loop.getOpOperandForRegionIterArg(iterArg);
513     loopIt++;
514   }
515   if (loopIt == loops.rend())
516     destinationIterArg = source;
517   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
518 }
519 
520 /// Implementation of fusing producer of a single slice by computing the
521 /// slice of the producer in-place.
522 std::optional<scf::SCFFuseProducerOfSliceResult>
523 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
524                                       tensor::ExtractSliceOp candidateSliceOp,
525                                       MutableArrayRef<scf::ForOp> loops) {
526   // 1. Get the producer of the source (potentially walking through
527   // `iter_args` of nested `scf.for`)
528   auto [fusableProducer, destinationInitArg] =
529       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
530                                         loops);
531   if (!fusableProducer)
532     return std::nullopt;
533 
534   // 2. Generate the tiled implementation of the producer of the source
535   OpBuilder::InsertionGuard g(rewriter);
536   rewriter.setInsertionPoint(candidateSliceOp);
537   FailureOr<TilingResult> tileAndFuseResult =
538       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
539                                                    fusableProducer);
540   if (failed(tileAndFuseResult))
541     return std::nullopt;
542   rewriter.replaceAllUsesWith(candidateSliceOp,
543                               tileAndFuseResult->tiledValues[0]);
544 
545   // 3. If the slice is for a destination operand, for example,
546   //
547   // ```mlir
548   // %0 = linalg.init
549   // %1 = linalg.fill .. outs(%0 : )
550   // %2 = scf.for .. iter_args(%arg0 = %1) {
551   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
552   //     %4 = tensor.extract_slice %arg1 [..]
553   //     .. = linalg.matmul .. outs(%4 : )
554   //   }
555   // }
556   // ```
557   //
558   // the IR is currently
559   //
560   // ```
561   // %0 = linalg.init
562   // %1 = linalg.fill
563   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
564   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
565   //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
566   //     %5 = linalg.fill .. outs(%4 : )
567   //     .. = linalg.matmul .. outs(%5 : )
568   //   }
569   // }
570   // ```
571   //
572   // The untiled `linalg.fill` is still used as the `init_value` since it
573   // was originally a destination operand of the untiled `linalg.matmul`.
574   // When fusing an operand that is a destination operand.
575   //   - Update the iter_arg of the outer most loop to use the destination
576   //     of the untiled producer.
577   //   - Update the destination of the slice of the tiled producer generated
578   //     to use the same basic block argument as the slice that was used to
579   //     generate inplace the tiled implementation of the producer.
580   // With this the IR will be.
581   //
582   // ```
583   // %0 = linalg.init
584   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
585   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
586   //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
587   //     %4 = linalg.fill .. outs(%3 : )
588   //     .. = linalg.matmul .. outs(%4 : )
589   //   }
590   // }
591   // ```
592   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
593   // Update to use that when it does become available.
594   scf::ForOp outerMostLoop = loops.front();
595   if (destinationInitArg &&
596       (*destinationInitArg)->getOwner() == outerMostLoop) {
597     unsigned iterArgNumber =
598         outerMostLoop.getResultForOpOperand(**destinationInitArg)
599             .getResultNumber();
600     int64_t resultNumber = fusableProducer.getResultNumber();
601     if (auto dstOp =
602             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
603       (*destinationInitArg)
604           ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
605     }
606     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
607       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
608       if (!dstOp)
609         continue;
610       scf::ForOp innerMostLoop = loops.back();
611       updateDestinationOperandsForTiledOp(
612           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
613           innerMostLoop.getRegionIterArgs()[iterArgNumber]);
614     }
615   }
616   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
617                                            tileAndFuseResult->tiledValues[0],
618                                            tileAndFuseResult->tiledOps};
619 }
620 
621 /// Reconstruct the fused producer from within the tiled-and-fused code.
622 void mlir::scf::yieldReplacementForFusedProducer(
623     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
624     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
625     MutableArrayRef<scf::ForOp> loops) {
626   auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
627       fusedProducerInfo;
628   SmallVector<Value> initValues;
629   FailureOr<Value> initValue = tensor::getOrCreateDestination(
630       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
631   if (succeeded(initValue)) {
632     SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
633     SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
634     SmallVector<Value> yieldedVals =
635         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
636                          resultOffsets, resultSizes, loops);
637   }
638   for (auto tileAndFusedOp : tileAndFusedOps) {
639     auto dstStyleProducer =
640         dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
641     if (!dstStyleProducer)
642       continue;
643     Value dstValue =
644         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
645             ->get();
646     updateDestinationOperandsForTiledOp(
647         rewriter, dstValue, loops.back().getRegionIterArgs().back());
648   }
649 }
650 
651 /// Implementation of tile consumer and fuse producer greedily.
652 FailureOr<scf::SCFTileAndFuseResult>
653 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
654     RewriterBase &rewriter, TilingInterface consumer,
655     const scf::SCFTileAndFuseOptions &options) {
656   // This transformation is only valid for ops that return values (i.e. not
657   // valid to use with operations that have memref operands).
658   if (!consumer->getNumResults()) {
659     return rewriter.notifyMatchFailure(
660         consumer, "invalid pattern for op with no results");
661   }
662 
663   // 1. First tile the consumer.
664   SmallVector<scf::ForOp> forLoops;
665   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
666   DenseMap<Value, Value> replacements;
667   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
668   {
669     FailureOr<scf::SCFTilingResult> tilingResult =
670         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
671     if (failed(tilingResult))
672       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
673     for (auto *tiledOp : tilingResult->tiledOps)
674       tiledAndFusedOps.insert(tiledOp);
675     forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
676     for (auto [index, origValue, replacement] :
677          llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
678       replacements[origValue] = replacement;
679       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
680           index)] = index;
681     }
682   }
683 
684   // If there are no loops generated, fusion is immaterial.
685   if (forLoops.empty()) {
686     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
687                                      getAsOperations(forLoops), replacements};
688   }
689 
690   // 2. Typically, the operands of the tiled operation are slices of the
691   //    operands of the untiled operation. These are expressed in IR using
692   //    `tensor.extract_slice` operations with source being the operands of the
693   //    untiled operation. Create a worklist of these `tensor.extract_slice`
694   //    operations. If the producers of the source of the `tensor.extract_slice`
695   //    can be tiled such that the tiled value is generated in-place, that
696   //    effectively tiles + fuses the operations.
697   auto addCandidateSlices = [](Operation *fusedOp,
698                                std::deque<tensor::ExtractSliceOp> &candidates) {
699     for (Value operand : fusedOp->getOperands())
700       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
701         candidates.push_back(sliceOp);
702   };
703 
704   std::deque<tensor::ExtractSliceOp> candidates;
705   addCandidateSlices(tiledAndFusedOps.back(), candidates);
706   OpBuilder::InsertionGuard g(rewriter);
707   while (!candidates.empty()) {
708     // Traverse the slices in BFS fashion.
709     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
710     candidates.pop_front();
711 
712     // The operands of the fused producer might themselved be slices of
713     // values produced by operations that implement the `TilingInterface`.
714     // Add these operations to the worklist.
715     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
716         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
717     if (!fusedResult)
718       continue;
719 
720     if (Operation *tiledAndFusedOp =
721             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
722       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
723       tiledAndFusedOps.insert(tiledAndFusedOp);
724       addCandidateSlices(tiledAndFusedOp, candidates);
725     }
726   }
727   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
728                                    getAsOperations(forLoops), replacements};
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // lowerToLoopsUsingSCFForOp implementation.
733 //===----------------------------------------------------------------------===//
734 
735 FailureOr<SmallVector<scf::ForOp>>
736 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
737                                      TilingInterface op) {
738   // TODO: Handle cases where the op has results if needed.
739   if (op->getNumResults() > 0) {
740     return rewriter.notifyMatchFailure(
741         op, "unable to lower to loops operations with return values");
742   }
743 
744   SmallVector<Range> domain = op.getIterationDomain(rewriter);
745   SmallVector<Value> ivs;
746   SmallVector<scf::ForOp> loops;
747   Location loc = op.getLoc();
748   for (auto loopRange : domain) {
749     Value offsetVal =
750         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
751     Value sizeVal =
752         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
753     Value strideVal =
754         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
755     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
756                                             strideVal, ValueRange{});
757     loops.push_back(loop);
758     ivs.push_back(loop.getInductionVar());
759     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
760   }
761   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
762     return failure();
763   }
764   return loops;
765 }
766