xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 1c27899e24381a0231003b1aae9c436e78174109)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
33 scf::SCFTilingOptions &
34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
35   assert(!tileSizeComputationFunction && "tile sizes already set");
36   auto tileSizes = llvm::to_vector(ts);
37   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38     return tileSizes;
39   };
40   return *this;
41 }
42 
43 /// Helper method to adjust the interchange vector to match the iteration
44 /// domain.
45 static SmallVector<int64_t>
46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
47                       size_t iterationDomainSize) {
48   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
49   if (filledVector.size() < iterationDomainSize) {
50     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
51     filledVector.append(range.begin(), range.end());
52   }
53   if (filledVector.size() > iterationDomainSize)
54     filledVector.resize(iterationDomainSize);
55   return filledVector;
56 }
57 
58 /// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59 template <typename SrcOpTy>
60 static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
61   return llvm::to_vector(
62       llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
63 }
64 template <typename SrcOpTy>
65 static SmallVector<Operation *>
66 getAsOperations(const SmallVector<SrcOpTy> &ops) {
67   return getAsOperations(ArrayRef<SrcOpTy>(ops));
68 }
69 
70 /// Convert a list of `Operation *` to a list of `DstOpTy.
71 template <typename DstOpTy>
72 static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
73   return llvm::to_vector(
74       llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75 }
76 template <typename DstOpTy>
77 static SmallVector<DstOpTy>
78 castToTypedOperations(const SmallVector<Operation *> &ops) {
79   return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // tileUsingSCFForOp implementation.
84 //===----------------------------------------------------------------------===//
85 
86 // Check if `stride` evenly divides the trip count `size - offset`.
87 static bool tileDividesIterationDomain(Range loopRange) {
88   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
89   if (!offsetAsInt)
90     return false;
91   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
92   if (!sizeAsInt)
93     return false;
94   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
95   if (!strideAsInt)
96     return false;
97   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
98 }
99 
100 /// Returns the bounded tile size given the current `iv`, `loopRange` and
101 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
102 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
103                                        Range loopRange, Value iv,
104                                        OpFoldResult tileSize) {
105   std::optional<int64_t> ts = getConstantIntValue(tileSize);
106   if (ts && ts.value() == 1)
107     return tileSize;
108 
109   if (tileDividesIterationDomain(
110           Range{loopRange.offset, loopRange.size, tileSize}))
111     return tileSize;
112 
113   // The tile size to use (to avoid out of bounds access) is  minimum of
114   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
115   // loop.
116   AffineExpr s0, s1, d0;
117   bindDims(b.getContext(), d0);
118   bindSymbols(b.getContext(), s0, s1);
119   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
120   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
121   return affine::makeComposedFoldedAffineMin(
122       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
123 }
124 
125 /// Clones the operation and updates the destination if the operation
126 /// implements the `DestinationStyleOpInterface`.
127 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
128                                                   Operation *op,
129                                                   ValueRange newDestArgs) {
130   Operation *clonedOp = rewriter.clone(*op);
131   if (auto destinationStyleOp =
132           dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
133     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
134   }
135   return clonedOp;
136 }
137 
138 /// Generate an empty loop nest that represents the tiled loop nest shell.
139 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
140 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
141 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
142 /// the
143 ///   tile processed within the inner most loop.
144 static SmallVector<scf::ForOp> generateTileLoopNest(
145     OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
146     ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
147     SmallVector<OpFoldResult> &sizes) {
148   assert(!loopRanges.empty() && "expected at least one loop range");
149   assert(loopRanges.size() == tileSizes.size() &&
150          "expected as many tile sizes as loop ranges");
151   OpBuilder::InsertionGuard guard(builder);
152   SmallVector<scf::ForOp> loops;
153   offsets.resize(loopRanges.size());
154   sizes.resize(loopRanges.size());
155 
156   for (auto loopRange : llvm::enumerate(loopRanges)) {
157     Value offset =
158         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
159     Value size =
160         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
161     Value tileSize = getValueOrCreateConstantIndexOp(
162         builder, loc, tileSizes[loopRange.index()]);
163     // No loops if tile size is zero. Set offset and size to the loop
164     // offset and size.
165     if (matchPattern(tileSize, m_Zero())) {
166       offsets[loopRange.index()] = offset;
167       sizes[loopRange.index()] = size;
168       continue;
169     }
170 
171     auto loop = builder.create<scf::ForOp>(
172         loc, offset, size, tileSize, ValueRange{},
173         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
174             ValueRange /*iterArgs*/) {
175           sizes[loopRange.index()] =
176               getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
177                                  getAsOpFoldResult(tileSize));
178           builder.create<scf::YieldOp>(loc);
179         });
180     offsets[loopRange.index()] = loop.getInductionVar();
181     loops.push_back(loop);
182     builder.setInsertionPoint(loop.getBody()->getTerminator());
183   }
184   return loops;
185 }
186 
187 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
188 /// construct the destructive update pattern that inserts the yielded
189 /// value into a destination tensor provided by `initValue` at offset
190 /// `tileOffsets` and size `tileSizes`. For example,
191 ///
192 /// ```mlir
193 /// scf.for %iv0 = ... {
194 ///   %0 = tiled_op
195 /// }
196 /// ```
197 ///
198 /// is transformed to
199 ///
200 /// ```mlir
201 /// scf.for %iv0 = ... iter_args(%arg = %0) {
202 ///   %1 = tensor.extract_slice %arg
203 ///   %2 = tiled_op
204 ///   %3 = tensor.insert_slice %2 into %arg
205 ///   scf.yield %3
206 /// }
207 /// ```
208 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
209 static SmallVector<Value>
210 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
211                  ValueRange yieldedValues,
212                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
213                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
214                  MutableArrayRef<scf::ForOp> loops) {
215   NewYieldValuesFn yieldValueFn =
216       [&](OpBuilder &b, Location loc,
217           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
218     SmallVector<Value> inserts;
219     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
220       ArrayRef<OpFoldResult> tileOffsets =
221           tileOffsetsList[yieldedValue.index()];
222       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
223       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
224                                             b.getIndexAttr(1));
225       Value insert = b.create<tensor::InsertSliceOp>(
226           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
227           tileOffsets, tileSizes, tileStrides);
228       inserts.push_back(insert);
229     }
230     return inserts;
231   };
232 
233   SmallVector<scf::ForOp> newLoops =
234       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
235                                    /*replaceIterOperandsUsesInLoop =*/false);
236   for (const auto &loop : llvm::enumerate(loops)) {
237     loops[loop.index()] = newLoops[loop.index()];
238   }
239   return llvm::to_vector(llvm::map_range(
240       loops.front().getResults().take_back(yieldedValues.size()),
241       [](OpResult r) -> Value { return r; }));
242 }
243 
244 /// If the tiled operation is destination passing style, update the
245 /// slice of the destination used (which refers to the untiled destination)
246 /// to use the corresponding region argument of the innermost loop.
247 ///
248 /// ```mlir
249 /// %0 =
250 /// scf.for %iv0 = ... iter_args(%arg = %0) {
251 ///   %1 = tensor.extract_slice %0
252 ///   %2 = tiled_op
253 ///   %3 = tensor.insert_slice %2 into %arg
254 ///   scf.yield %3
255 /// }
256 /// ```
257 ///
258 /// is transformed to
259 ///
260 /// ```mlir
261 /// scf.for %iv0 = ... iter_args(%arg = %0) {
262 ///   %1 = tensor.extract_slice %arg
263 ///   %2 = tiled_op
264 ///   %3 = tensor.insert_slice %2 into %arg
265 ///   scf.yield %3
266 /// }
267 /// ```
268 static void
269 updateDestinationOperandsForTiledOp(OpBuilder &builder,
270                                     ValueRange tiledOpDestinationValues,
271                                     ValueRange bbArgsList) {
272   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
273     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
274     if (!sliceOp)
275       continue;
276     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
277   }
278 }
279 
280 /// Helper method to yield the values of the tiled op, as well as
281 /// update the destination operands of the tiled op, if it is
282 /// a destination passing style op.
283 static SmallVector<Value>
284 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
285                  TilingResult tilingResult,
286                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
287                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
288                  MutableArrayRef<scf::ForOp> loops) {
289   SmallVector<Value> replacements =
290       yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
291                        tileOffsetsList, tileSizesList, loops);
292   for (auto tiledOp : tilingResult.tiledOps) {
293     if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
294       auto innerMostLoop = loops.back();
295       SmallVector<Value> tiledOpDestinationTensors =
296           llvm::to_vector(dstOp.getDpsInits());
297       updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
298                                           innerMostLoop.getRegionIterArgs());
299     }
300   }
301   return replacements;
302 }
303 
304 /// Implementation of tiling transformation of `op` that implements the
305 /// `TilingInterface` using `scf.for` to iterate over the tiles.
306 FailureOr<scf::SCFTilingResult>
307 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
308                              const scf::SCFTilingOptions &options) {
309   OpBuilder::InsertionGuard guard(rewriter);
310   rewriter.setInsertionPointAfter(op);
311 
312   if (!options.tileSizeComputationFunction) {
313     return rewriter.notifyMatchFailure(
314         op, "missing tile size computation function");
315   }
316 
317   // 1. Get the range of the loops that are represented by the operation.
318   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
319   size_t numLoops = iterationDomain.size();
320   if (numLoops == 0) {
321     return rewriter.notifyMatchFailure(
322         op, "unable to tile op with no iteration domain");
323   }
324 
325   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
326   // skips tiling a particular dimension. This convention is significantly
327   // simpler to handle instead of adjusting affine maps to account for missing
328   // dimensions.
329   SmallVector<OpFoldResult> tileSizeVector =
330       options.tileSizeComputationFunction(rewriter, op);
331   if (tileSizeVector.size() < iterationDomain.size()) {
332     auto zero = rewriter.getIndexAttr(0);
333     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
334   }
335 
336   SmallVector<OpFoldResult> offsets, sizes;
337   SmallVector<scf::ForOp> forLoops;
338   {
339     // If there is an interchange specified, permute the iteration domain and
340     // the tile sizes.
341     SmallVector<int64_t> interchangeVector;
342     if (!options.interchangeVector.empty()) {
343       interchangeVector = fillInterchangeVector(options.interchangeVector,
344                                                 iterationDomain.size());
345     }
346     if (!interchangeVector.empty()) {
347       if (!isPermutationVector(interchangeVector)) {
348         return rewriter.notifyMatchFailure(
349             op, "invalid intechange vector, not a permutation of the entire "
350                 "iteration space");
351       }
352 
353       applyPermutationToVector(iterationDomain, interchangeVector);
354       applyPermutationToVector(tileSizeVector, interchangeVector);
355     }
356 
357     // 3. Materialize an empty loop nest that iterates over the tiles. These
358     // loops for now do not return any values even if the original operation has
359     // results.
360     forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
361                                     tileSizeVector, offsets, sizes);
362 
363     if (!interchangeVector.empty()) {
364       auto inversePermutation = invertPermutationVector(interchangeVector);
365       applyPermutationToVector(offsets, inversePermutation);
366       applyPermutationToVector(sizes, inversePermutation);
367     }
368   }
369 
370   LLVM_DEBUG({
371     if (!forLoops.empty()) {
372       llvm::dbgs() << "LoopNest shell :\n";
373       forLoops.front().dump();
374       llvm::dbgs() << "\n";
375     }
376   });
377 
378   // 4. Generate the tiled implementation within the inner most loop.
379   if (!forLoops.empty())
380     rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
381   FailureOr<TilingResult> tiledImplementation =
382       op.getTiledImplementation(rewriter, offsets, sizes);
383 
384   if (op->getNumResults() == 0) {
385     return scf::SCFTilingResult{
386         tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
387   }
388 
389   // If loops are empty, the tiled op is used as the replacement for the untiled
390   // op.
391   if (forLoops.empty()) {
392     return scf::SCFTilingResult{tiledImplementation->tiledOps,
393                                 getAsOperations(forLoops),
394                                 tiledImplementation->tiledValues};
395   }
396 
397   // 5. Yield all the results of the tiled operation. The surrounding loop
398   //    nest is modified to insert a destructive update pattern to yield
399   //    from the loop nest values to replace the untiled op with.
400   int64_t numResults = op->getNumResults();
401   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
402       resultSizesList(numResults);
403   for (const auto &result : llvm::enumerate(op->getResults())) {
404     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
405                                         sizes,
406                                         resultOffsetsList[result.index()],
407                                         resultSizesList[result.index()]))) {
408       return rewriter.notifyMatchFailure(
409           op, "failed to get slice of result produced");
410     }
411   }
412 
413   SmallVector<Value> destinationTensors;
414   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
415                                              destinationTensors)))
416     return rewriter.notifyMatchFailure(op, "failed to get destinations");
417 
418   SmallVector<Value> replacements = yieldTiledValues(
419       rewriter, destinationTensors, tiledImplementation.value(),
420       resultOffsetsList, resultSizesList, forLoops);
421   LLVM_DEBUG({
422     if (!forLoops.empty()) {
423       llvm::dbgs() << "After tiled implementation :\n";
424       forLoops.front().dump();
425       llvm::dbgs() << "\n";
426     }
427   });
428   return scf::SCFTilingResult{tiledImplementation->tiledOps,
429                               getAsOperations(forLoops), replacements};
430 }
431 
432 FailureOr<scf::SCFReductionTilingResult>
433 mlir::scf::tileReductionUsingScf(RewriterBase &b,
434                                  PartialReductionOpInterface op,
435                                  ArrayRef<OpFoldResult> tileSizes) {
436   Location loc = op.getLoc();
437   // Ops implementing PartialReductionOpInterface are expected to implement
438   // TilingInterface.
439   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
440   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
441   auto tileSizesVector = llvm::to_vector(tileSizes);
442   if (tileSizesVector.size() < iterationDomain.size()) {
443     auto zero = b.getIndexAttr(0);
444     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
445                            zero);
446   }
447   if (op->getNumResults() != 1)
448     return b.notifyMatchFailure(
449         op, "don't support ops with multiple results for now");
450   SmallVector<utils::IteratorType> iterators =
451       tilingInterfaceOp.getLoopIteratorTypes();
452 
453   SmallVector<int> reductionDims;
454   for (auto [idx, iteratorType] :
455        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
456     if (iteratorType == utils::IteratorType::reduction)
457       reductionDims.push_back(idx);
458   }
459 
460   // 1. create the inital tensor value.
461   FailureOr<Operation *> identityTensor =
462       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
463                                                   reductionDims);
464   if (failed(identityTensor))
465     return b.notifyMatchFailure(op,
466                                 "cannot create a tensor of identity value.");
467   // 2. Create the nested loops.
468   SmallVector<OpFoldResult> offsets, sizes;
469   SmallVector<scf::ForOp> loops = generateTileLoopNest(
470       b, loc, iterationDomain, tileSizesVector, offsets, sizes);
471 
472   // 3. Generate the tiled implementation within the inner most loop.
473   b.setInsertionPoint(loops.back().getBody()->getTerminator());
474   Operation *parallelOp = op.tileToPartialReduction(
475       b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
476 
477   SmallVector<OpFoldResult> resultSizesList;
478   for (size_t i = 0; i < offsets.size(); i++)
479     resultSizesList.push_back(
480         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
481   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
482   SmallVector<Value> replacements = yieldTiledValues(
483       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
484       resultSizesList, loops);
485 
486   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
487   auto innerMostLoop = loops.back();
488   SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
489   assert(destinationTensors.size() ==
490              innerMostLoop.getRegionIterArgs().size() &&
491          "unexpected number of outputs");
492   updateDestinationOperandsForTiledOp(b, destinationTensors,
493                                       innerMostLoop.getRegionIterArgs());
494 
495   // 4. Apply the merge reduction to combine all the partial values.
496   b.setInsertionPointAfter(*loops.begin());
497   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
498   b.replaceOp(op, mergeOp->getResults());
499 
500   SCFReductionTilingResult results;
501   results.initialOp = *identityTensor;
502   results.loops = std::move(loops);
503   results.parallelTiledOp = parallelOp;
504   results.mergeOp = mergeOp;
505   return results;
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
510 //===----------------------------------------------------------------------===//
511 
512 /// Return the untiled producer whose slice is used in a tiled consumer. The
513 /// method traverses the tile loop nest (`loops`) if needed, and returns the
514 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
515 /// indicates that this is a destination operand of the consumer. If there was
516 /// no loop traversal needed, the second value of the returned tuple is empty.
517 static std::tuple<OpResult, std::optional<OpOperand *>>
518 getUntiledProducerFromSliceSource(OpOperand *source,
519                                   ArrayRef<scf::ForOp> loops) {
520   std::optional<OpOperand *> destinationIterArg;
521   auto loopIt = loops.rbegin();
522   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
523     scf::ForOp loop = *loopIt;
524     if (iterArg.getOwner()->getParentOp() != loop)
525       break;
526     source = &loop.getOpOperandForRegionIterArg(iterArg);
527     loopIt++;
528   }
529   if (loopIt == loops.rend())
530     destinationIterArg = source;
531   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
532 }
533 
534 /// Implementation of fusing producer of a single slice by computing the
535 /// slice of the producer in-place.
536 std::optional<scf::SCFFuseProducerOfSliceResult>
537 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
538                                       tensor::ExtractSliceOp candidateSliceOp,
539                                       MutableArrayRef<scf::ForOp> loops) {
540   // 1. Get the producer of the source (potentially walking through
541   // `iter_args` of nested `scf.for`)
542   auto [fusableProducer, destinationInitArg] =
543       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
544                                         loops);
545   if (!fusableProducer)
546     return std::nullopt;
547 
548   // 2. Generate the tiled implementation of the producer of the source
549   OpBuilder::InsertionGuard g(rewriter);
550   rewriter.setInsertionPoint(candidateSliceOp);
551   FailureOr<TilingResult> tileAndFuseResult =
552       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
553                                                    fusableProducer);
554   if (failed(tileAndFuseResult))
555     return std::nullopt;
556   rewriter.replaceAllUsesWith(candidateSliceOp,
557                               tileAndFuseResult->tiledValues[0]);
558 
559   // 3. If the slice is for a destination operand, for example,
560   //
561   // ```mlir
562   // %0 = linalg.init
563   // %1 = linalg.fill .. outs(%0 : )
564   // %2 = scf.for .. iter_args(%arg0 = %1) {
565   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
566   //     %4 = tensor.extract_slice %arg1 [..]
567   //     .. = linalg.matmul .. outs(%4 : )
568   //   }
569   // }
570   // ```
571   //
572   // the IR is currently
573   //
574   // ```
575   // %0 = linalg.init
576   // %1 = linalg.fill
577   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
578   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
579   //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
580   //     %5 = linalg.fill .. outs(%4 : )
581   //     .. = linalg.matmul .. outs(%5 : )
582   //   }
583   // }
584   // ```
585   //
586   // The untiled `linalg.fill` is still used as the `init_value` since it
587   // was originally a destination operand of the untiled `linalg.matmul`.
588   // When fusing an operand that is a destination operand.
589   //   - Update the iter_arg of the outer most loop to use the destination
590   //     of the untiled producer.
591   //   - Update the destination of the slice of the tiled producer generated
592   //     to use the same basic block argument as the slice that was used to
593   //     generate inplace the tiled implementation of the producer.
594   // With this the IR will be.
595   //
596   // ```
597   // %0 = linalg.init
598   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
599   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
600   //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
601   //     %4 = linalg.fill .. outs(%3 : )
602   //     .. = linalg.matmul .. outs(%4 : )
603   //   }
604   // }
605   // ```
606   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
607   // Update to use that when it does become available.
608   scf::ForOp outerMostLoop = loops.front();
609   if (destinationInitArg &&
610       (*destinationInitArg)->getOwner() == outerMostLoop) {
611     unsigned iterArgNumber =
612         outerMostLoop.getResultForOpOperand(**destinationInitArg)
613             .getResultNumber();
614     int64_t resultNumber = fusableProducer.getResultNumber();
615     if (auto dstOp =
616             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
617       (*destinationInitArg)
618           ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
619     }
620     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
621       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
622       if (!dstOp)
623         continue;
624       scf::ForOp innerMostLoop = loops.back();
625       updateDestinationOperandsForTiledOp(
626           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
627           innerMostLoop.getRegionIterArgs()[iterArgNumber]);
628     }
629   }
630   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
631                                            tileAndFuseResult->tiledValues[0],
632                                            tileAndFuseResult->tiledOps};
633 }
634 
635 /// Reconstruct the fused producer from within the tiled-and-fused code.
636 void mlir::scf::yieldReplacementForFusedProducer(
637     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
638     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
639     MutableArrayRef<scf::ForOp> loops) {
640   auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
641       fusedProducerInfo;
642   SmallVector<Value> initValues;
643   FailureOr<Value> initValue = tensor::getOrCreateDestination(
644       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
645   if (succeeded(initValue)) {
646     SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
647     SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
648     SmallVector<Value> yieldedVals =
649         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
650                          resultOffsets, resultSizes, loops);
651   }
652   for (auto tileAndFusedOp : tileAndFusedOps) {
653     auto dstStyleProducer =
654         dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
655     if (!dstStyleProducer)
656       continue;
657     Value dstValue =
658         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
659             ->get();
660     updateDestinationOperandsForTiledOp(
661         rewriter, dstValue, loops.back().getRegionIterArgs().back());
662   }
663 }
664 
665 /// Implementation of tile consumer and fuse producer greedily.
666 FailureOr<scf::SCFTileAndFuseResult>
667 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
668     RewriterBase &rewriter, TilingInterface consumer,
669     const scf::SCFTileAndFuseOptions &options) {
670   // This transformation is only valid for ops that return values (i.e. not
671   // valid to use with operations that have memref operands).
672   if (!consumer->getNumResults()) {
673     return rewriter.notifyMatchFailure(
674         consumer, "invalid pattern for op with no results");
675   }
676 
677   // 1. First tile the consumer.
678   SmallVector<scf::ForOp> forLoops;
679   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
680   DenseMap<Value, Value> replacements;
681   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
682   {
683     FailureOr<scf::SCFTilingResult> tilingResult =
684         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
685     if (failed(tilingResult))
686       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
687     for (auto *tiledOp : tilingResult->tiledOps)
688       tiledAndFusedOps.insert(tiledOp);
689     forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
690     for (auto [index, origValue, replacement] :
691          llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
692       replacements[origValue] = replacement;
693       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
694           index)] = index;
695     }
696   }
697 
698   // If there are no loops generated, fusion is immaterial.
699   if (forLoops.empty()) {
700     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
701                                      getAsOperations(forLoops), replacements};
702   }
703 
704   // 2. Typically, the operands of the tiled operation are slices of the
705   //    operands of the untiled operation. These are expressed in IR using
706   //    `tensor.extract_slice` operations with source being the operands of the
707   //    untiled operation. Create a worklist of these `tensor.extract_slice`
708   //    operations. If the producers of the source of the `tensor.extract_slice`
709   //    can be tiled such that the tiled value is generated in-place, that
710   //    effectively tiles + fuses the operations.
711   auto addCandidateSlices = [](Operation *fusedOp,
712                                std::deque<tensor::ExtractSliceOp> &candidates) {
713     for (Value operand : fusedOp->getOperands())
714       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
715         candidates.push_back(sliceOp);
716   };
717 
718   std::deque<tensor::ExtractSliceOp> candidates;
719   addCandidateSlices(tiledAndFusedOps.back(), candidates);
720   OpBuilder::InsertionGuard g(rewriter);
721   while (!candidates.empty()) {
722     // Traverse the slices in BFS fashion.
723     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
724     candidates.pop_front();
725 
726     // The operands of the fused producer might themselved be slices of
727     // values produced by operations that implement the `TilingInterface`.
728     // Add these operations to the worklist.
729     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
730         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
731     if (!fusedResult)
732       continue;
733 
734     if (Operation *tiledAndFusedOp =
735             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
736       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
737       tiledAndFusedOps.insert(tiledAndFusedOp);
738       addCandidateSlices(tiledAndFusedOp, candidates);
739     }
740   }
741   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
742                                    getAsOperations(forLoops), replacements};
743 }
744 
745 //===----------------------------------------------------------------------===//
// tileUsingSCFForallOp implementation.
747 //===----------------------------------------------------------------------===//
748 
749 FailureOr<scf::SCFTilingResult>
750 mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
751                                 const scf::SCFTilingOptions &options) {
752   Location loc = op->getLoc();
753   OpBuilder::InsertionGuard g(rewriter);
754 
755   // 1. Get the range of loops that are represented by the operation.
756   SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
757   if (loopRanges.empty())
758     return op->emitOpError("expected non-empty loop ranges");
759   auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
760   if (llvm::any_of(loopRanges, hasStrideOne))
761     return op->emitOpError("only stride-1 supported atm");
762 
763   // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
764   // To make it easier, pad the tile sizes to loopRanges.size with value 0.
765   SmallVector<OpFoldResult> tileSizeVector =
766       options.tileSizeComputationFunction(rewriter, op);
767   tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
768 
769   // 3. Build the offsets, sizes and steps for the tile and distributed loops.
770   SmallVector<OpFoldResult> lbs, ubs, steps;
771   for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
772     if (isConstantIntValue(tileSize, 0))
773       continue;
774     lbs.push_back(loopRange.offset);
775     ubs.push_back(loopRange.size);
776     steps.push_back(tileSize);
777   }
778 
779   // 4. Gather destination tensors.
780   SmallVector<Value> dest;
781   if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
782     return op->emitOpError("failed to get destination tensors");
783 
784   // 5. Build the device mapping attribute.
785   std::optional<ArrayAttr> mappingAttr;
786   if (!options.mappingVector.empty()) {
787     mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
788   }
789 
790   // 6. Create the ForallOp. We don't use the lambda body-builder
791   // version because we require the use of RewriterBase in the body, so we
792   // manually move the insertion point to the body below.
793   auto forallOp =
794       rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
795 
796   // 7. Get the tile offset and sizes.
797   rewriter.setInsertionPoint(forallOp.getTerminator());
798   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
799   ValueRange ivs = forallOp.getInductionVars();
800   {
801     int materializedLoopNum = 0;
802     for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
803       if (isConstantIntValue(tileSize, 0)) {
804         tiledOffsets.push_back(loopRange.offset);
805         tiledSizes.push_back(loopRange.size);
806         continue;
807       }
808       Value iv = ivs[materializedLoopNum++];
809       tiledOffsets.push_back(iv);
810       tiledSizes.push_back(
811           getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
812     }
813   }
814 
815   // 8. Tile the operation. Clone the operation to allow fix up of destination
816   // operands.
817   ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
818   Operation *clonedOp =
819       cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
820   FailureOr<TilingResult> tilingResult =
821       cast<TilingInterface>(clonedOp).getTiledImplementation(
822           rewriter, tiledOffsets, tiledSizes);
823   if (failed(tilingResult))
824     return clonedOp->emitError("failed to tile op: ");
825   rewriter.eraseOp(clonedOp);
826 
827   // 9. Parallel insert back into the result tensor.
828   for (auto [index, tiledValue, destBBArg] :
829        llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
830     // 9.a. Partial subset information is inserted just before the terminator.
831     rewriter.setInsertionPoint(forallOp.getTerminator());
832 
833     SmallVector<OpFoldResult> resultOffsets, resultSizes;
834     if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
835                                         tiledSizes, resultOffsets,
836                                         resultSizes))) {
837       return op->emitOpError("output offsets couldn't be calculated");
838     }
839 
840     SmallVector<OpFoldResult> strides(resultSizes.size(),
841                                       rewriter.getIndexAttr(1));
842     // 9.b. Parallel insertions are inserted at the end of the combining
843     // terminator.
844     rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
845     rewriter.create<tensor::ParallelInsertSliceOp>(
846         loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
847   }
848 
849   // 10. Return the tiling result.
850   return scf::SCFTilingResult{
851       tilingResult->tiledOps,
852       {forallOp.getOperation()},
853       llvm::map_to_vector(forallOp.getResults(),
854                           [](auto val) -> Value { return val; })};
855 }
856 
857 //===----------------------------------------------------------------------===//
858 // lowerToLoopsUsingSCFForOp implementation.
859 //===----------------------------------------------------------------------===//
860 
861 FailureOr<SmallVector<scf::ForOp>>
862 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
863                                      TilingInterface op) {
864   // TODO: Handle cases where the op has results if needed.
865   if (op->getNumResults() > 0) {
866     return rewriter.notifyMatchFailure(
867         op, "unable to lower to loops operations with return values");
868   }
869 
870   SmallVector<Range> domain = op.getIterationDomain(rewriter);
871   SmallVector<Value> ivs;
872   SmallVector<scf::ForOp> loops;
873   Location loc = op.getLoc();
874   for (auto loopRange : domain) {
875     Value offsetVal =
876         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
877     Value sizeVal =
878         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
879     Value strideVal =
880         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
881     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
882                                             strideVal, ValueRange{});
883     loops.push_back(loop);
884     ivs.push_back(loop.getInductionVar());
885     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
886   }
887   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
888     return failure();
889   }
890   return loops;
891 }
892