//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements tiling of operations using the TilingInterface.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "tile-using-interface"

using namespace mlir;

scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
             ->getRegion(0)
             .front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Helper method to adjust the interchange vector to match the size of the
/// iteration domain: missing trailing dimensions are appended in order, and
/// excess entries are dropped.
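///
/// For example (illustrative), an interchange vector of `[1, 0]` for a 3-d
/// iteration domain is completed to `[1, 0, 2]`.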
static SmallVector<int64_t>
fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
                      size_t iterationDomainSize) {
  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
  if (filledVector.size() < iterationDomainSize) {
    auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
    filledVector.append(range.begin(), range.end());
  }
  if (filledVector.size() > iterationDomainSize)
    filledVector.resize(iterationDomainSize);
  return filledVector;
}

//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
  if (!offsetAsInt)
    return false;
  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
  if (!sizeAsInt)
    return false;
  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
  if (!strideAsInt)
    return false;
  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
}

/// Returns the bounded tile size given the current `iv`, `loopRange` and
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
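///
/// A sketch of what this materializes when the bound is needed (`%iv`,
/// `%tileSize` and `%ub` are placeholder names; the actual IR may be folded
/// further by `makeComposedFoldedAffineMin`):
///
/// ```mlir
/// %bounded = affine.min affine_map<(d0)[s0, s1] -> (s0, s1 - d0)>
///              (%iv)[%tileSize, %ub]
/// ```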
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, Value iv,
                                       Value tileSize) {
  std::optional<int64_t> ts = getConstantIntValue(tileSize);
  if (ts && ts.value() == 1)
    return getAsOpFoldResult(tileSize);

  if (tileDividesIterationDomain(
          Range{loopRange.offset, loopRange.size, tileSize}))
    return tileSize;

  // The tile size to use (to avoid out of bounds access) is the minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
  // tiled loop.
  AffineExpr s0, s1, d0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0, s1);
  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  return makeComposedFoldedAffineMin(
      b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` are the tile sizes to use. A tile size of zero means the
///   loop is untiled.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
///   the tile processed within the innermost loop.
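///
/// As a rough sketch (names are placeholders), for a 2-d iteration domain with
/// tile sizes `[10, 20]` the generated shell is along the lines of:
///
/// ```mlir
/// scf.for %iv0 = %lb0 to %ub0 step %c10 {
///   scf.for %iv1 = %lb1 to %ub1 step %c20 {
///     // (bounded tile sizes are computed here; the caller then inserts the
///     // tiled implementation before the terminator)
///   }
/// }
/// ```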
static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder &builder, Location loc,
                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
                     SmallVector<OpFoldResult> &offsets,
                     SmallVector<OpFoldResult> &sizes) {
  assert(!loopRanges.empty() && "expected at least one loop range");
  assert(loopRanges.size() == tileSizeVals.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  for (auto loopRange : llvm::enumerate(loopRanges)) {
    Value offset =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
    Value size =
        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
    Value tileSize = tileSizeVals[loopRange.index()];
    // No loops if tile size is zero. Set offset and size to the loop
    // offset and size.
    if (matchPattern(tileSize, m_Zero())) {
      offsets[loopRange.index()] = offset;
      sizes[loopRange.index()] = size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, offset, size, tileSize, ValueRange{},
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          sizes[loopRange.index()] = getBoundedTileSize(
              bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
          bodyBuilder.create<scf::YieldOp>(bodyLoc);
        });
    offsets[loopRange.index()] = loop.getInductionVar();
    loops.push_back(loop);
    builder.setInsertionPoint(loop.getBody()->getTerminator());
  }
  return loops;
}

/// For each value to be yielded (`yieldedValues`) from within the loop nest
/// `loops`, construct the destructive update pattern that inserts the yielded
/// value into the corresponding destination tensor provided by `initValues`,
/// at the offsets `tileOffsetsList` and sizes `tileSizesList`. For example,
///
/// ```mlir
/// scf.for %iv0 = ... {
///   %0 = tiled_op
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                 ValueRange yieldedValues,
                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                 MutableArrayRef<scf::ForOp> loops) {
  NewYieldValueFn yieldValueFn =
      [&](OpBuilder &b, Location loc,
          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
    SmallVector<Value> inserts;
    for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
      ArrayRef<OpFoldResult> tileOffsets =
          tileOffsetsList[yieldedValue.index()];
      ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
      SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
                                            b.getIndexAttr(1));
      Value insert = b.create<tensor::InsertSliceOp>(
          loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
          tileOffsets, tileSizes, tileStrides);
      inserts.push_back(insert);
    }
    return inserts;
  };

  SmallVector<scf::ForOp> newLoops =
      replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
                                   /*replaceIterOperandsUsesInLoop=*/false);
  for (const auto &loop : llvm::enumerate(loops)) {
    rewriter.eraseOp(loop.value());
    loops[loop.index()] = newLoops[loop.index()];
  }
  return llvm::to_vector(llvm::map_range(
      loops.front().getResults().take_back(yieldedValues.size()),
      [](OpResult r) -> Value { return r; }));
}

/// If the tiled operation is destination passing style, update the
/// slice of the destination used (which refers to the untiled destination)
/// to use the corresponding region argument of the innermost loop.
///
/// ```mlir
/// %0 = ...
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %0
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
///
/// is transformed to
///
/// ```mlir
/// scf.for %iv0 = ... iter_args(%arg = %0) {
///   %1 = tensor.extract_slice %arg
///   %2 = tiled_op
///   %3 = tensor.insert_slice %2 into %arg
///   scf.yield %3
/// }
/// ```
static void
updateDestinationOperandsForTiledOp(OpBuilder &builder,
                                    ValueRange tiledOpDestinationValues,
                                    ValueRange bbArgsList) {
  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
    auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      continue;
    sliceOp.setOperand(0, bbArgsList[destValue.index()]);
  }
}

/// Helper method to yield the values of the tiled op, as well as
/// update the destination operands of the tiled op, if it is
/// a destination passing style op.
static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
                 TilingResult tilingResult,
                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                 MutableArrayRef<scf::ForOp> loops) {
  SmallVector<Value> replacements =
      yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
                       tileOffsetsList, tileSizesList, loops);
  for (auto tiledOp : tilingResult.tiledOps) {
    if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
      auto innerMostLoop = loops.back();
      SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
      updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
                                          innerMostLoop.getRegionIterArgs());
    }
  }
  return replacements;
}

/// Implementation of tiling transformation of `op` that implements the
/// `TilingInterface` using `scf.for` to iterate over the tiles.
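///
/// A rough, hand-written sketch of the result (value names are illustrative,
/// and types and tile-size/bound computations are elided): tiling a matmul
/// with tile sizes `[10, 20, 0]` produces IR along the lines of
///
/// ```mlir
/// %res = scf.for %i = %c0 to %M step %c10 iter_args(%arg0 = %C) {
///   %0 = scf.for %j = %c0 to %N step %c20 iter_args(%arg1 = %arg0) {
///     %sA = tensor.extract_slice %A[%i, 0] [10, %K] [1, 1]
///     %sB = tensor.extract_slice %B[0, %j] [%K, 20] [1, 1]
///     %sC = tensor.extract_slice %arg1[%i, %j] [10, 20] [1, 1]
///     %t = linalg.matmul ins(%sA, %sB : ...) outs(%sC : ...)
///     %r = tensor.insert_slice %t into %arg1[%i, %j] [10, 20] [1, 1]
///     scf.yield %r
///   }
///   scf.yield %0
/// }
/// ```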
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                             const scf::SCFTilingOptions &options) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // 1. Get the range of the loops that are represented by the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  size_t numLoops = iterationDomain.size();
  if (numLoops == 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to tile op with no iteration domain");
  }

  // 2. Materialize the tile sizes. Enforce the convention that "tiling by
  // zero" skips tiling a particular dimension. This convention is
  // significantly simpler to handle than adjusting affine maps to account for
  // missing dimensions.
  SmallVector<Value> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < iterationDomain.size()) {
    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
  }

  scf::SCFTilingResult tilingResult;
  SmallVector<OpFoldResult> offsets, sizes;
  {
    // If there is an interchange specified, permute the iteration domain and
    // the tile sizes.
    SmallVector<int64_t> interchangeVector;
    if (!options.interchangeVector.empty()) {
      interchangeVector = fillInterchangeVector(options.interchangeVector,
                                                iterationDomain.size());
    }
    if (!interchangeVector.empty()) {
      if (!isPermutationVector(interchangeVector)) {
        return rewriter.notifyMatchFailure(
            op, "invalid interchange vector, not a permutation of the entire "
                "iteration space");
      }

      applyPermutationToVector(iterationDomain, interchangeVector);
      applyPermutationToVector(tileSizeVector, interchangeVector);
    }

    // 3. Materialize an empty loop nest that iterates over the tiles. These
    // loops for now do not return any values even if the original operation
    // has results.
    tilingResult.loops = generateTileLoopNest(
        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);

    if (!interchangeVector.empty()) {
      auto inversePermutation = invertPermutationVector(interchangeVector);
      applyPermutationToVector(offsets, inversePermutation);
      applyPermutationToVector(sizes, inversePermutation);
    }
  }

  LLVM_DEBUG({
    if (!tilingResult.loops.empty()) {
      llvm::dbgs() << "LoopNest shell :\n";
      tilingResult.loops.front().dump();
      llvm::dbgs() << "\n";
    }
  });

  // 4. Generate the tiled implementation within the innermost loop.
  if (!tilingResult.loops.empty())
    rewriter.setInsertionPoint(
        tilingResult.loops.back().getBody()->getTerminator());
  FailureOr<TilingResult> tiledImplementation =
      op.getTiledImplementation(rewriter, offsets, sizes);
  if (failed(tiledImplementation)) {
    return rewriter.notifyMatchFailure(
        op, "failed to generate the tiled implementation");
  }
  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
  if (op->getNumResults() == 0) {
    // Nothing more to do.
    return tilingResult;
  }

  // If loops are empty, the tiled op is used as the replacement for the
  // untiled op.
  if (tilingResult.loops.empty()) {
    tilingResult.replacements = tiledImplementation->tiledValues;
    return tilingResult;
  }

  // 5. Yield all the results of the tiled operation. The surrounding loop
  //    nest is modified to insert a destructive update pattern that yields,
  //    from the loop nest, the values with which to replace the untiled op.
  int64_t numResults = op->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
      resultSizesList(numResults);
  for (const auto &result : llvm::enumerate(op->getResults())) {
    if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
                                        sizes,
                                        resultOffsetsList[result.index()],
                                        resultSizesList[result.index()]))) {
      return rewriter.notifyMatchFailure(
          op, "failed to get slice of result produced");
    }
  }

  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
                                             destinationTensors)))
    return rewriter.notifyMatchFailure(op, "failed to get destinations");

  tilingResult.replacements = yieldTiledValues(
      rewriter, destinationTensors, tiledImplementation.value(),
      resultOffsetsList, resultSizesList, tilingResult.loops);

  LLVM_DEBUG({
    if (!tilingResult.loops.empty()) {
      llvm::dbgs() << "After tiled implementation :\n";
      tilingResult.loops.front().dump();
      llvm::dbgs() << "\n";
    }
  });
  return tilingResult;
}

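/// A rough sketch of the transformation below (ops and shapes are illustrative
/// and depend on the `PartialReductionOpInterface` implementation): a 1-d sum
/// reduction with the reduction dimension tiled by 5 becomes a partial
/// reduction computed in a loop, followed by a merge of the partial results.
///
/// ```mlir
/// %identity = linalg.fill ins(%cst : f32) outs(%empty : tensor<5xf32>)
/// %partial = scf.for %iv = %c0 to %ub step %c5 iter_args(%arg = %identity) {
///   %slice = tensor.extract_slice %input[%iv] [5] [1]
///   // Combine the tile element-wise into the running partial result.
///   %p = <partial reduction of %slice into a slice of %arg>
///   %y = tensor.insert_slice %p into %arg[0] [5] [1]
///   scf.yield %y
/// }
/// // Merge the 5 partial values into the final reduction result.
/// %result = <merge of %partial>
/// ```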
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSize) {
  Location loc = op.getLoc();
  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  SmallVector<Value> tileSizeVector =
      getValueOrCreateConstantIndexOp(b, loc, tileSize);
  if (tileSizeVector.size() < iterationDomain.size()) {
    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
  }
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");
  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  int64_t numReductionDims =
      llvm::count(iterators, utils::IteratorType::reduction);
  if (numReductionDims != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  int reductionDim;
  for (auto [idx, iteratorType] : llvm::enumerate(iterators)) {
    if (iteratorType == utils::IteratorType::reduction) {
      reductionDim = idx;
      break;
    }
  }
  if (static_cast<size_t>(reductionDim) >= tileSize.size())
    return b.notifyMatchFailure(op, "reduction dimension must be tiled");

  // 1. Create the initial tensor value.
  FailureOr<Operation *> identityTensor =
      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
                                                  reductionDim);
  if (failed(identityTensor))
    return b.notifyMatchFailure(op,
                                "cannot create a tensor of identity value.");
  // 2. Create the nested loops.
  SmallVector<OpFoldResult> offsets, sizes;
  SmallVector<scf::ForOp> loops = generateTileLoopNest(
      b, loc, iterationDomain, tileSizeVector, offsets, sizes);

  // 3. Generate the tiled implementation within the innermost loop.
  b.setInsertionPoint(loops.back().getBody()->getTerminator());
  Operation *parallelOp = op.tileToPartialReduction(
      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);

  SmallVector<OpFoldResult> resultSizesList;
  for (size_t i = 0; i < offsets.size(); i++)
    resultSizesList.push_back(
        b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
  SmallVector<Value> replacements = yieldTiledValues(
      b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
      resultSizesList, loops);

  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
  auto innerMostLoop = loops.back();
  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
  assert(destinationTensors.size() ==
             innerMostLoop.getRegionIterArgs().size() &&
         "unexpected number of outputs");
  updateDestinationOperandsForTiledOp(b, destinationTensors,
                                      innerMostLoop.getRegionIterArgs());

  // 4. Apply the merge reduction to combine all the partial values.
  b.setInsertionPointAfter(*loops.begin());
  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
  b.replaceOp(op, mergeOp->getResults());

  SCFReductionTilingResult results;
  results.initialOp = *identityTensor;
  results.loops = std::move(loops);
  results.parallelTiledOp = parallelOp;
  results.mergeOp = mergeOp;
  return results;
}
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outermost loop that is encountered. Traversing the
/// iter_args indicates that this is a destination operand of the consumer. If
/// there was no loop traversal needed, the second value of the returned tuple
/// is empty.
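///
/// For example (a sketch), given a slice in the innermost loop below, its
/// source is traced through the `iter_args` back to the untiled value `%0`:
///
/// ```mlir
/// %1 = scf.for ... iter_args(%arg0 = %0) {
///   %2 = scf.for ... iter_args(%arg1 = %arg0) {
///     %3 = tensor.extract_slice %arg1 [..]
///     ...
///   }
/// }
/// ```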
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<scf::ForOp> loops) {
  std::optional<OpOperand *> destinationIterArg;
  auto loopIt = loops.rbegin();
  while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
    scf::ForOp loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
    source = &loop.getOpOperandForRegionIterArg(iterArg);
    loopIt++;
  }
  if (loopIt == loops.rend())
    destinationIterArg = source;
  return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}

/// Implementation of fusing the producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                      tensor::ExtractSliceOp candidateSliceOp,
                                      MutableArrayRef<scf::ForOp> loops) {
  // 1. Get the producer of the source (potentially walking through
  // `iter_args` of nested `scf.for`).
  auto [fusableProducer, destinationIterArg] =
      getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
                                        loops);
  if (!fusableProducer)
    return std::nullopt;

  // 2. Generate the tiled implementation of the producer of the source.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
                                                   fusableProducer);
  if (failed(tileAndFuseResult))
    return std::nullopt;
  rewriter.replaceAllUsesWith(candidateSliceOp,
                              tileAndFuseResult->tiledValues[0]);

  // 3. If the slice is for a destination operand, for example,
  //
  // ```mlir
  // %0 = linalg.init
  // %1 = linalg.fill .. outs(%0 : )
  // %2 = scf.for .. iter_args(%arg0 = %1) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1 [..]
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  //
  // the IR is currently
  //
  // ```
  // %0 = linalg.init
  // %1 = linalg.fill
  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %0 /* incorrect value */ [..]
  //     %5 = linalg.fill .. outs(%4 : )
  //     .. = linalg.matmul .. outs(%5 : )
  //   }
  // }
  // ```
  //
  // The untiled `linalg.fill` is still used as the `init_value` since it
  // was originally a destination operand of the untiled `linalg.matmul`.
  // When fusing an operand that is a destination operand:
  //   - Update the iter_arg of the outermost loop to use the destination
  //     of the untiled producer.
  //   - Update the destination of the slice of the tiled producer generated
  //     to use the same basic block argument as the slice that was used to
  //     generate the tiled implementation of the producer in place.
  // With this, the IR will be:
  //
  // ```
  // %0 = linalg.init
  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
  //     %4 = linalg.fill .. outs(%3 : )
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  // TODO: This can be modeled better if the `DestinationStyleOpInterface`
  // could be used here; update to use it when that becomes available.
  scf::ForOp outerMostLoop = loops.front();
  std::optional<unsigned> iterArgNumber;
  if (destinationIterArg) {
    iterArgNumber =
        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
  }
  if (iterArgNumber) {
    int64_t resultNumber = fusableProducer.getResultNumber();
    if (auto dstOp =
            dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
      outerMostLoop.setIterArg(iterArgNumber.value(),
                               dstOp.getTiedOpOperand(fusableProducer)->get());
    }
    for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
      auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
      if (!dstOp)
        continue;
      scf::ForOp innerMostLoop = loops.back();
      updateDestinationOperandsForTiledOp(
          rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
    }
  }
  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
                                           tileAndFuseResult->tiledValues[0],
                                           tileAndFuseResult->tiledOps};
}

/// Reconstruct the fused producer from within the tiled-and-fused code.
void mlir::scf::yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<scf::ForOp> loops) {
  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
      fusedProducerInfo;
  SmallVector<Value> initValues;
  FailureOr<Value> initValue = tensor::getOrCreateDestination(
      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
  if (succeeded(initValue)) {
    SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
    SmallVector<Value> yieldedVals =
        yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
                         resultOffsets, resultSizes, loops);
  }
  for (auto tileAndFusedOp : tileAndFusedOps) {
    auto dstStyleProducer =
        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
    if (!dstStyleProducer)
      continue;
    Value dstValue =
        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
            ->get();
    updateDestinationOperandsForTiledOp(
        rewriter, dstValue, loops.back().getRegionIterArgs().back());
  }
}

/// Implementation of tiling the consumer and greedily fusing its producers.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  // This transformation is only valid for ops that return values (i.e., it is
  // not valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  scf::SCFTileAndFuseResult tileAndFuseResult;
  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
  {
    FailureOr<scf::SCFTilingResult> tilingResult =
        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
    if (failed(tilingResult))
      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
    for (auto *tiledOp : tilingResult->tiledOps)
      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
    tileAndFuseResult.loops = std::move(tilingResult->loops);
    for (const auto &result : llvm::enumerate(
             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
      tileAndFuseResult.replacements[std::get<0>(result.value())] =
          std::get<1>(result.value());
      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
          result.index())] = result.index();
    }
  }

  // If there are no loops generated, fusion is immaterial.
  if (tileAndFuseResult.loops.empty())
    return tileAndFuseResult;

  // 2. Typically, the operands of the tiled operation are slices of the
  //    operands of the untiled operation. These are expressed in IR using
  //    `tensor.extract_slice` operations with source being the operands of the
  //    untiled operation. Create a worklist of these `tensor.extract_slice`
  //    operations. If the producers of the source of the `tensor.extract_slice`
  //    can be tiled such that the tiled value is generated in-place, that
  //    effectively tiles + fuses the operations.
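  //
  //    For example (an illustrative sketch), after tiling the consumer the
  //    loop body may contain
  //      %s = tensor.extract_slice %producer_result [..]
  //      %t = <tiled consumer op> .. ins(%s ..)
  //    and fusing the producer of %s replaces the slice with an in-place
  //    computation of just that tile of the producer.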
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
                                   tileAndFuseResult.loops);
    if (!fusedProducer)
      continue;

    if (Operation *tiledAndFusedOp =
            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
      addCandidateSlices(tiledAndFusedOp, candidates);
    }
  }
  return tileAndFuseResult;
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

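/// Lower `op` to an explicit loop nest by creating one `scf.for` per dimension
/// of its iteration domain and emitting the scalar implementation in the
/// innermost loop. As a sketch (names are placeholders), an op with a 2-d
/// iteration domain becomes:
///
/// ```mlir
/// scf.for %i = %lb0 to %ub0 step %step0 {
///   scf.for %j = %lb1 to %ub1 step %step1 {
///     // Scalar implementation of the op at (%i, %j).
///   }
/// }
/// ```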
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower operations with return values to loops");
  }

  SmallVector<Range> domain = op.getIterationDomain(rewriter);
  SmallVector<Value> ivs;
  SmallVector<scf::ForOp> loops;
  Location loc = op.getLoc();
  for (auto loopRange : domain) {
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
    Value sizeVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
    Value strideVal =
        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
                                            strideVal, ValueRange{});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
  }
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
    return failure();
  }
  return loops;
}