xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 8cc616bc71dfe0648de3843a006ac8827c5fe59d)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Analysis/TopologicalSortUtils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Arith/Utils/Utils.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/SCF/Utils/Utils.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
28 #include "mlir/Interfaces/TilingInterface.h"
29 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include <optional>
34 
35 #define DEBUG_TYPE "tile-using-interface"
36 
37 using namespace mlir;
38 
39 scf::SCFTilingOptions &
40 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
41   assert(!tileSizeComputationFunction && "tile sizes already set");
42   auto tileSizes = llvm::to_vector(ts);
43   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
44     return tileSizes;
45   };
46   return *this;
47 }
48 
49 scf::SCFTilingOptions &
50 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
51   assert(!numThreadsComputationFunction && "num tiles already set");
52   auto numThreads = llvm::to_vector(nt);
53   numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
54     return numThreads;
55   };
56   return *this;
57 }
58 
59 /// Helper method to adjust the interchange vector to match the iteration
60 /// domain.
61 static SmallVector<int64_t>
62 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
63                       size_t iterationDomainSize) {
64   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
65   if (filledVector.size() < iterationDomainSize) {
66     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
67     filledVector.append(range.begin(), range.end());
68   }
69   if (filledVector.size() > iterationDomainSize)
70     filledVector.resize(iterationDomainSize);
71   return filledVector;
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // tileUsingSCF implementation.
76 //===----------------------------------------------------------------------===//
77 
78 /// Verify the tile size options are set in a consistent manner.
79 static LogicalResult
80 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
81                       const scf::SCFTilingOptions &options) {
82   // Specifying number of threads is only supported on `scf.forall` op.
83   if (options.numThreadsComputationFunction &&
84       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
85     return rewriter.notifyMatchFailure(
86         loc, "number of threads can only by specified when loop type is "
87              "set to use `scf.forall`");
88   }
89 
90   // If specified, check that the interchange vector is a permutation.
91   if (!options.interchangeVector.empty()) {
92     if (!isPermutationVector(options.interchangeVector)) {
93       return rewriter.notifyMatchFailure(
94           loc, "invalid interchange vector, not a permutation of the entire "
95                "iteration space");
96     }
97   }
98   return success();
99 }
100 
101 /// Method to instantiate the tile sizes and/or number of threads specified
102 /// by the user.
103 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
104 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
105                               ArrayRef<Range> iterationDomain,
106                               const scf::SCFTilingOptions &options) {
107   OpFoldResult zero = rewriter.getIndexAttr(0);
108   SmallVector<OpFoldResult> tileSizes, numThreads;
109   size_t numLoops = iterationDomain.size();
110 
111   // Check whether the number of tiles to use is specified.
112   if (options.numThreadsComputationFunction) {
113     numThreads = options.numThreadsComputationFunction(rewriter, op);
114     numThreads.resize(numLoops, zero);
115 
116     // If the number of tiles is also specified, use that.
117     if (options.tileSizeComputationFunction) {
118       tileSizes = options.tileSizeComputationFunction(rewriter, op);
119       tileSizes.resize(numLoops, zero);
120       return {tileSizes, numThreads};
121     }
122 
123     // Compute the tile sizes from the iteration domain and number
124     // of tiles as follows
125     // - niters = ceilDiv(ub - lb, step)
126     // - tileSize = ceilDiv(niters, numThreads)
127     AffineExpr s0, s1, s2;
128     bindSymbols(rewriter.getContext(), s0, s1, s2);
129     // TODO: The step here is assumed to be 1.
130     AffineExpr numItersExpr = (s1 - s0);
131     AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
132     tileSizes.resize(numLoops, zero);
133     for (auto [index, range, nt] :
134          llvm::enumerate(iterationDomain, numThreads)) {
135       if (isConstantIntValue(nt, 0))
136         continue;
137 
138       tileSizes[index] = affine::makeComposedFoldedAffineApply(
139           rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
140     }
141     tileSizes.resize(numLoops, zero);
142     return {tileSizes, numThreads};
143   }
144 
145   // Enforce the convention that "tiling by zero"
146   // skips tiling a particular dimension. This convention is significantly
147   // simpler to handle instead of adjusting affine maps to account for missing
148   // dimensions.
149   assert(options.tileSizeComputationFunction &&
150          "expected tile sizes to be specified");
151   tileSizes = options.tileSizeComputationFunction(rewriter, op);
152   tileSizes.resize(numLoops, zero);
153 
154   return {tileSizes, numThreads};
155 }
156 
157 /// Checks if any of the tiled loops are not parallel.
158 static void checkSafeToTileToForall(TilingInterface op,
159                                     ArrayRef<OpFoldResult> tileSizes,
160                                     ArrayRef<OpFoldResult> numThreads) {
161   auto iterators = op.getLoopIteratorTypes();
162   assert(iterators.size() == tileSizes.size() &&
163          "expected as many tile size values as number of loops");
164   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
165          "when specified, expected number of threads to use for each loop");
166 
167   for (auto [index, iterator, tileSize] :
168        llvm::enumerate(iterators, tileSizes)) {
169     // If num threads is specified, check that it is greater than one only for
170     // parallel dimensions.
171     if (!numThreads.empty()) {
172       if (std::optional<int64_t> constNumThreads =
173               getConstantIntValue(numThreads[index])) {
174         if (constNumThreads.value() > 1 &&
175             iterator != utils::IteratorType::parallel) {
176           op.emitWarning() << "tiling is not thread safe at axis #" << index;
177         }
178       }
179       continue;
180     }
181 
182     if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
183       if (constTileSize.value() > 0 &&
184           iterator != utils::IteratorType::parallel) {
185         op.emitWarning() << "tiling is not thread safe at axis #" << index;
186       }
187     }
188   }
189 }
190 
191 /// Check if `stride` evenly divides the trip count `size - offset`.
192 static bool tileDividesIterationDomain(Range loopRange) {
193   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
194   if (!offsetAsInt)
195     return false;
196   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
197   if (!sizeAsInt)
198     return false;
199   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
200   if (!strideAsInt)
201     return false;
202   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
203 }
204 
205 /// Returns the bounded tile size given the current `offset`, `loopRange` and
206 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
207 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
208                                        Range loopRange, OpFoldResult offset,
209                                        OpFoldResult tileSize) {
210   std::optional<int64_t> ts = getConstantIntValue(tileSize);
211   if (ts && ts.value() == 1)
212     return tileSize;
213 
214   if (tileDividesIterationDomain(
215           Range{loopRange.offset, loopRange.size, tileSize}))
216     return tileSize;
217 
218   // The tile size to use (to avoid out of bounds access) is  minimum of
219   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
220   // loop.
221   AffineExpr s0, s1, d0;
222   bindDims(b.getContext(), d0);
223   bindSymbols(b.getContext(), s0, s1);
224   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
225   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
226   return affine::makeComposedFoldedAffineMin(
227       b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
228 }
229 
230 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
231 /// than `iterationSize`.
232 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
233                                            OpFoldResult numThreads,
234                                            OpFoldResult iterationSize) {
235   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
236   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
237   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
238   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
239     return false;
240   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
241 }
242 
243 /// Compute the `OpFoldResult`s that represents the multi-dimensional
244 /// `offset`s and `size`s of the tile of the iteration space that the
245 /// innermost loop body of the generated tiled loops corresponds to.
246 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
247 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
248                       ArrayRef<Range> iterationDomain,
249                       ArrayRef<OpFoldResult> tileSizes,
250                       ArrayRef<OpFoldResult> numThreads) {
251   SmallVector<OpFoldResult> offsets, sizes;
252   int materializedLoopNum = 0;
253 
254   if (!numThreads.empty()) {
255     AffineExpr d0, d1, s0, s1;
256     AffineExpr offsetExpr, residualTileSizeExpr;
257     bindDims(rewriter.getContext(), d0, d1);
258     bindSymbols(rewriter.getContext(), s0, s1);
259     offsetExpr = d0 + d1 * s0;
260     residualTileSizeExpr = s1 - (d0 + d1 * s0);
261 
262     for (auto [nt, tileSize, loopRange] :
263          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
264 
265       // Non-tiled cases, set the offset and size to the
266       // `loopRange.offset/size`.
267       if (isConstantIntValue(nt, 0)) {
268         offsets.push_back(loopRange.offset);
269         sizes.push_back(loopRange.size);
270         continue;
271       }
272 
273       Value iv = ivs[materializedLoopNum++];
274       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
275           rewriter, loc, offsetExpr,
276           ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
277       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
278           rewriter, loc, residualTileSizeExpr,
279           {loopRange.offset, nt, tileSize, loopRange.size});
280 
281       OpFoldResult size = tileSize;
282       if (!isConstantIntValue(residualTileSize, 0)) {
283         OpFoldResult sizeMinusOffsetPerThread =
284             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
285                                                   {offset, loopRange.size});
286         size = affine::makeComposedFoldedAffineMin(
287             rewriter, loc,
288             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
289             {sizeMinusOffsetPerThread, tileSize});
290       }
291 
292       // Consider the case where the original loop was `[0, 100)`.
293       // If number of threads are `7`, the tile size would be computed as
294       // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
295       // - `offset = 0 + 6 * 15 = 105`
296       // - `tileSize = min(15, 100 - 105) = -5`
297       // To avoid negative tile sizes, we need to do a further
298       // `nonNegativeTileSize = affine.max(0, tileSize)`.
299       // This `max` can be avoided if
300       //  `offset + tileSize * (numThreads - 1) < (ub - lb)`
301       if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
302         AffineMap maxMap =
303             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
304         size = affine::makeComposedFoldedAffineMax(
305             rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
306       }
307 
308       offsets.push_back(offset);
309       sizes.push_back(size);
310     }
311     return {offsets, sizes};
312   } else {
313     for (auto [tileSize, loopRange] :
314          llvm::zip_equal(tileSizes, iterationDomain)) {
315 
316       // Non-tiled cases, set the offset and size to the
317       // `loopRange.offset/size`.
318       if (isConstantIntValue(tileSize, 0)) {
319         offsets.push_back(loopRange.offset);
320         sizes.push_back(loopRange.size);
321         continue;
322       }
323 
324       Value iv = ivs[materializedLoopNum++];
325       OpFoldResult offset = getAsOpFoldResult(iv);
326       offsets.push_back(offset);
327       OpFoldResult size =
328           getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
329       sizes.push_back(size);
330     }
331     return {offsets, sizes};
332   }
333 }
334 
335 /// Function to return the bounds of the loops to be generated.
336 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
337                   SmallVector<OpFoldResult>>
338 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
339               ArrayRef<OpFoldResult> tileSizes) {
340   SmallVector<OpFoldResult> lbs, ubs, steps;
341   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
342     // No loop if the tile size is 0.
343     if (isConstantIntValue(tileSize, 0))
344       continue;
345     lbs.push_back(loopRange.offset);
346     ubs.push_back(loopRange.size);
347     steps.push_back(tileSize);
348   }
349   return {lbs, ubs, steps};
350 }
351 
352 /// A function that allows returning additional yielded values during
353 /// `yieldTiledValuesAndReplace`.
354 /// - `ivs` induction variable for the loop.
355 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
356 /// - `tiledValues` the tiled values to return. Must be of same size as
357 ///   `newbbArgs`, each element of this array is inserted into the corresponding
358 ///   element in `newbbArgs`.
359 /// - `resultOffsets` is of the same size as `tiledValues` and represents
360 ///   the offsets to use when inserting corresponding element from `tiledValues`
361 ///   into the element from `newBbArgs`.
362 /// - `resultSizes` is of the same size as `tiledValues` and represents
363 ///   the size of the corresponding element from `tiledValues` inserted into
364 ///   the element from `newBbArgs`.
365 /// In case the method needs to return `failure()` the method is expected
366 /// to clean up any inserted operations.
367 using YieldTiledValuesFn = std::function<LogicalResult(
368     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
369     SmallVector<Value> &tiledValues,
370     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
371     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
372 
373 /// Clones the operation and updates the destination if the operation
374 /// implements the `DestinationStyleOpInterface`.
375 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
376                                                   Operation *op,
377                                                   ValueRange newDestArgs) {
378   Operation *clonedOp = rewriter.clone(*op);
379   if (newDestArgs.empty())
380     return clonedOp;
381   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
382     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
383   return clonedOp;
384 }
385 
386 /// Generate the tile-loop nest using `scf.for` operation.
387 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
388 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
389 /// - `destinationTensors` are the init values to use for the outer most loop.
390 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
391 /// most
392 ///    loop.
393 /// - `loops` is an in-out parameter into which the generated loops are
394 ///    populated.
395 static LogicalResult generateLoopNestUsingForOp(
396     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
397     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
398     YieldTiledValuesFn yieldTiledValuesFn,
399     SmallVector<LoopLikeOpInterface> &loops) {
400   assert(!loopRanges.empty() && "unexpected empty loop ranges");
401   assert(loopRanges.size() == tileSizes.size() &&
402          "expected as many tile sizes as loop ranges");
403   OpBuilder::InsertionGuard guard(rewriter);
404 
405   SmallVector<OpFoldResult> lbs, ubs, steps;
406   std::tie(lbs, ubs, steps) =
407       getLoopBounds(rewriter, loc, loopRanges, tileSizes);
408   SmallVector<Value> lbVals =
409       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
410   SmallVector<Value> ubVals =
411       getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
412   SmallVector<Value> stepVals =
413       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
414 
415   SmallVector<Value> ivs;
416   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
417     auto loop =
418         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
419                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
420                                        Value iv, ValueRange /*iterArgs*/) {});
421     loops.push_back(loop);
422     ivs.push_back(loop.getInductionVar());
423     rewriter.setInsertionPointToEnd(loop.getBody());
424     destinationTensors = loop.getRegionIterArgs();
425   }
426 
427   SmallVector<Value> tiledResults;
428   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
429   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
430                                 tiledResults, resultOffsets, resultSizes))) {
431     return rewriter.notifyMatchFailure(
432         loc, "failed to generate inner tile loop body");
433   }
434   if (loops.empty())
435     return success();
436 
437   assert(tiledResults.size() == destinationTensors.size() &&
438          "Number of results of body should be equal to number of iter args");
439 
440   // 6. Yield all the results of the tiled operation.
441   SmallVector<Value> yieldedValues;
442   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
443        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
444                        resultSizes)) {
445     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
446                                            rewriter.getIndexAttr(1));
447     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
448         loc, tiledValue, destinationTensor, resultOffset, resultSize,
449         resultStride);
450     yieldedValues.push_back(insertSlice);
451   }
452   rewriter.create<scf::YieldOp>(loc, yieldedValues);
453 
454   // Add the scf.yield operations for all the outer loops.
455   for (auto [outerLoop, innerLoop] :
456        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
457                        MutableArrayRef(loops).drop_front())) {
458     rewriter.setInsertionPointToEnd(
459         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
460     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
461   }
462   return success();
463 }
464 
465 /// Generate the tile-loop nest using `scf.forall` operation.
466 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
467 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
468 /// - `destinationTensors` are the init values to use for the outer most loop.
469 /// - `mappingVector` is the mapping attributes to use for loop construction.
470 ///   Can be empty.
471 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
472 /// most
473 ///    loop.
474 /// - `loops` is an in-out parameter into which the generated loops are
475 ///    populated.
476 static LogicalResult generateLoopNestUsingForallOp(
477     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
478     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
479     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
480     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
481   assert(!loopRanges.empty() && "unexpected empty loop ranges");
482   assert(loopRanges.size() == tileSizes.size() &&
483          "expected as many tile sizes as loop ranges");
484   OpBuilder::InsertionGuard guard(rewriter);
485   SmallVector<OpFoldResult> offsets(loopRanges.size()),
486       sizes(loopRanges.size());
487 
488   std::optional<ArrayAttr> mappingAttr;
489   if (!mappingVector.empty())
490     mappingAttr = rewriter.getArrayAttr(mappingVector);
491 
492   scf::ForallOp forallOp;
493   bool useNumThreads = !numThreads.empty();
494 
495   if (useNumThreads) {
496     // Prune the zero numthreads.
497     SmallVector<OpFoldResult> nonZeroNumThreads;
498     for (auto nt : numThreads) {
499       if (isConstantIntValue(nt, 0))
500         continue;
501       nonZeroNumThreads.push_back(nt);
502     }
503     forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
504                                               destinationTensors, mappingAttr);
505   } else {
506     SmallVector<OpFoldResult> lbs, ubs, steps;
507     std::tie(lbs, ubs, steps) =
508         getLoopBounds(rewriter, loc, loopRanges, tileSizes);
509     forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
510                                               destinationTensors, mappingAttr);
511   }
512   loops.push_back(forallOp);
513 
514   rewriter.setInsertionPoint(forallOp.getTerminator());
515   destinationTensors = forallOp.getRegionOutArgs();
516 
517   SmallVector<Value> tiledResults;
518   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
519   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
520                          destinationTensors, tiledResults, resultOffsets,
521                          resultSizes)))
522     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
523 
524   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
525   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
526        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
527                        resultSizes)) {
528     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
529                                            rewriter.getIndexAttr(1));
530 
531     rewriter.create<tensor::ParallelInsertSliceOp>(
532         loc, tiledValue, destinationTensor, resultOffset, resultSize,
533         resultStride);
534   }
535   return success();
536 }
537 
538 /// Generate the tile-loop nest using the loop construct specifed in `options`.
539 /// - `options`: Tiling options specified.
540 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
541 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
542 /// - `destinationTensors` are the init values to use for the outer most loop.
543 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
544 /// most
545 ///    loop.
546 /// - `loops` is an in-out parameter into which the generated loops are
547 ///    populated.
548 static LogicalResult generateLoopNest(
549     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
550     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
551     ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
552     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
553   // If the tile sizes are all zero, no loops are generated. Just call the
554   // callback function to handle untiled case.
555   if (llvm::all_of(tileSizes, isZeroIndex)) {
556     SmallVector<Value> tiledResults;
557     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
558     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
559                        tiledResults, resultOffsets, resultSizes);
560   }
561   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
562     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
563                                       destinationTensors, tiledBodyFn, loops);
564   }
565   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
566     return generateLoopNestUsingForallOp(
567         rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
568         destinationTensors, tiledBodyFn, loops);
569   }
570   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
571 }
572 
573 /// Append the specified additional `newInitOperands` operands to the
574 /// loops existing `init` operands (or similar), and replace `loopOp` with
575 /// the new loop that has the additional init operands. The loop body of
576 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
577 /// is called to get the new tiled values returned, and the offset
578 /// and sizes at which the tiled value is inserted into the
579 /// new region iter_args that correspond to the newly added init operands.
580 template <typename LoopType>
581 FailureOr<LoopLikeOpInterface>
582 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
583                                ValueRange newInitOperands,
584                                YieldTiledValuesFn yieldTiledValuesFn) {
585   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
586 }
587 
588 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
589 template <>
590 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
591     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
592     YieldTiledValuesFn yieldTiledValuesFn) {
593   OpBuilder::InsertionGuard g(rewriter);
594   Location loc = loopOp.getLoc();
595   rewriter.setInsertionPoint(loopOp);
596 
597   auto inits = llvm::to_vector(loopOp.getInitArgs());
598   inits.append(newInitOperands.begin(), newInitOperands.end());
599   auto newLoop = rewriter.create<scf::ForOp>(
600       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
601       inits, [](OpBuilder &, Location, Value, ValueRange) {});
602 
603   // Move the loop body to the new op.
604   Block *loopBody = loopOp.getBody();
605   Block *newLoopBody = newLoop.getBody();
606   rewriter.mergeBlocks(
607       loopBody, newLoopBody,
608       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
609 
610   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
611   rewriter.setInsertionPoint(yieldOp);
612 
613   SmallVector<Value> tiledValues;
614   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
615   ValueRange newRegionIterArgs =
616       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
617   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
618                                 newRegionIterArgs, tiledValues, resultOffsets,
619                                 resultSizes))) {
620     rewriter.eraseOp(newLoop);
621     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
622   }
623 
624   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
625   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
626        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
627                        resultSizes)) {
628     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
629                                            rewriter.getIndexAttr(1));
630     Value insert = rewriter.create<tensor::InsertSliceOp>(
631         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
632         resultStride);
633     newYieldValues.push_back(insert);
634   }
635 
636   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
637   rewriter.replaceOp(loopOp,
638                      newLoop->getResults().take_front(loopOp.getNumResults()));
639   return cast<LoopLikeOpInterface>(newLoop.getOperation());
640 }
641 
642 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
643 template <>
644 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
645     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
646     YieldTiledValuesFn yieldTiledValuesFn) {
647   OpBuilder::InsertionGuard g(rewriter);
648   Location loc = loopOp.getLoc();
649   rewriter.setInsertionPoint(loopOp);
650   auto inits = llvm::to_vector(loopOp.getOutputs());
651   inits.append(newInitOperands.begin(), newInitOperands.end());
652   auto newLoop = rewriter.create<scf::ForallOp>(
653       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
654       loopOp.getMixedStep(), inits, loopOp.getMapping(),
655       [](OpBuilder &, Location, ValueRange) {});
656 
657   // Move the region of the current block to the newly created op.
658   Block *loopBody = loopOp.getBody();
659   Block *newLoopBody = newLoop.getBody();
660   rewriter.mergeBlocks(
661       loopBody, newLoopBody,
662       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
663 
664   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
665   rewriter.setInsertionPoint(terminator);
666   SmallVector<Value> tiledValues;
667   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
668   ValueRange regionIterArgs =
669       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
670   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
671                                 regionIterArgs, tiledValues, resultOffsets,
672                                 resultSizes))) {
673     rewriter.eraseOp(newLoop);
674     return rewriter.notifyMatchFailure(loopOp,
675                                        "failed to get yielded tiled values");
676   }
677 
678   // Update the terminator.
679   rewriter.setInsertionPointToEnd(terminator.getBody());
680 
681   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
682            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
683     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
684                                            rewriter.getIndexAttr(1));
685     rewriter.create<tensor::ParallelInsertSliceOp>(
686         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
687         resultStride);
688   }
689 
690   rewriter.replaceOp(loopOp,
691                      newLoop->getResults().take_front(loopOp.getNumResults()));
692   return cast<LoopLikeOpInterface>(newLoop.getOperation());
693 }
694 
695 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
696 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
697 /// supported loop type.
698 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
699     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
700     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
701   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
702              loopLikeOp.getOperation())
703       .Case<scf::ForOp, scf::ForallOp>(
704           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
705             return yieldTiledValuesAndReplaceLoop(
706                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
707           })
708       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
709         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
710       });
711 }
712 
713 /// Method to add new init values to a loop nest. Updates `loops` in-place with
714 /// new loops that use the `newInitValues`.
715 /// The outer-loops are updated to yield the new result values of the inner
716 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get
717 /// the additional values to yield form the innermost loop.
718 static LogicalResult addInitOperandsToLoopNest(
719     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
720     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
721   SmallVector<scf::ForOp> newLoops;
722   if (loops.empty())
723     return success();
724   OpBuilder::InsertionGuard g(rewriter);
725   rewriter.setInsertionPoint(loops.front());
726 
727   SmallVector<Value> ivs;
728   for (auto &loop : loops.drop_back()) {
729     rewriter.setInsertionPoint(loop);
730 
731     // if loops.size() > 1 we assume that scf.for is used for the loops.
732     auto forLoop = cast<scf::ForOp>(loop.getOperation());
733 
734     // Create a new loop with the new init values for this loop.
735     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
736     newInits.append(newInitValues.begin(), newInitValues.end());
737     auto newLoop = rewriter.create<scf::ForOp>(
738         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
739         forLoop.getStep(), newInits,
740         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
741 
742     // Merge the body of the new loop with the body of the old loops.
743     SmallVector<Value> sourceBlockArgs;
744     sourceBlockArgs.push_back(newLoop.getInductionVar());
745     auto newRegionIterArgs = newLoop.getRegionIterArgs();
746     sourceBlockArgs.append(
747         newRegionIterArgs.begin(),
748         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
749     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
750     rewriter.replaceOp(
751         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
752     loop = newLoop;
753     ivs.push_back(newLoop.getInductionVar());
754     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
755   }
756 
757   // Update the loop body of the innermost loop to get new yield values.
758   LoopLikeOpInterface innerMostLoop = loops.back();
759   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
760       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
761                                      getNewTiledYieldsFn);
762 
763   if (failed(newInnerMostLoop))
764     return innerMostLoop.emitOpError("failed to return additional yields");
765   loops.back() = newInnerMostLoop.value();
766 
767   // Make all other loops except the innermost loops yield the values returned
768   // by the inner loop.
769   for (auto [outerLoop, innerLoop] :
770        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
771     // Again assume that all the outer loops are scf.for operations.
772     auto outerForLoop = cast<scf::ForOp>(outerLoop);
773     auto outerLoopYield =
774         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
775     SmallVector<Value> newYields =
776         llvm::to_vector(outerLoopYield.getOperands());
777     ValueRange additionalYields =
778         innerLoop->getResults().take_back(newInitValues.size());
779     newYields.append(additionalYields.begin(), additionalYields.end());
780     rewriter.setInsertionPoint(outerLoopYield);
781     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
782   }
783   return success();
784 }
785 
786 /// Implementation of tiling transformation of `op` that implements the
787 /// `TilingInterface` using `scf.for` to iterate over the tiles.
788 FailureOr<scf::SCFTilingResult>
789 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
790                         const scf::SCFTilingOptions &options) {
791   if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
792     return failure();
793   }
794 
795   OpBuilder::InsertionGuard guard(rewriter);
796   rewriter.setInsertionPointAfter(op);
797 
798   // 1. Get the range of the loops that are represented by the operation.
799   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
800 
801   // 2. Materialize the tile sizes and/or number of threads;
802   SmallVector<OpFoldResult> tileSizes, numThreads;
803   std::tie(tileSizes, numThreads) =
804       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
805 
806   // Check if it is safe to tile. This is hold over from previous iterations
807   // of tile to for-all. Consider dropping it.
808   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
809     checkSafeToTileToForall(op, tileSizes, numThreads);
810   }
811 
812   // 3. If there is an interchange specified, permute the iteration domain and
813   // the tile sizes.
814   SmallVector<int64_t> interchangeVector;
815   if (!options.interchangeVector.empty()) {
816     interchangeVector = fillInterchangeVector(options.interchangeVector,
817                                               iterationDomain.size());
818     assert(isPermutationVector(interchangeVector) &&
819            "expected interchange vector to be a permutation");
820 
821     applyPermutationToVector(iterationDomain, interchangeVector);
822     applyPermutationToVector(tileSizes, interchangeVector);
823     if (!numThreads.empty())
824       applyPermutationToVector(numThreads, interchangeVector);
825   }
826 
827   FailureOr<TilingResult> tilingResult;
828   // 4. Define the lambda function used later to generate the body of the
829   // innermost tiled loop.
830   YieldTiledValuesFn innerYieldTiledValuesFn =
831       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
832           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
833           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
834           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
835       -> LogicalResult {
836     // 4a. Compute the `offsets` and `sizes` to use for tiling.
837     SmallVector<OpFoldResult> offsets, sizes;
838     std::tie(offsets, sizes) = getTileOffsetAndSizes(
839         rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
840 
841     // 4b. If interchange was provided, apply inverse of the interchange
842     //     to get back the offsets/sizes in the order to be specified.
843     if (!interchangeVector.empty()) {
844       auto inversePermutation = invertPermutationVector(interchangeVector);
845       applyPermutationToVector(offsets, inversePermutation);
846       applyPermutationToVector(sizes, inversePermutation);
847     }
848 
849     // 5. Generate the tiled implementation within the inner most loop.
850 
851     // 5a. Clone the operation within the loop body.
852     auto clonedOp = cast<TilingInterface>(
853         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
854 
855     // 5b. Early return cloned op if tiling is not happening. We can not return
856     // the original op because it could lead to
857     // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
858     if (llvm::all_of(tileSizes, isZeroIndex)) {
859       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
860       tilingResult =
861           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
862                        /*generatedSlices=*/{}};
863       return success();
864     }
865 
866     // 5c. Tile the cloned operation.
867     tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
868     if (failed(tilingResult)) {
869       rewriter.eraseOp(clonedOp);
870       return op.emitOpError("faild to tile operation");
871     }
872 
873     // 5d. Delete the cloned operation.
874     rewriter.eraseOp(clonedOp);
875 
876     // 5e. Compute the offsets at which the result values are to be inserted
877     //     back into its destinations.
878     for (auto [index, tiledValue] :
879          llvm::enumerate(tilingResult->tiledValues)) {
880       tiledResults.push_back(tiledValue);
881       SmallVector<OpFoldResult> resultOffset, resultSize;
882       if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
883                                           resultOffset, resultSize))) {
884         for (auto op : tilingResult->tiledOps) {
885           rewriter.eraseOp(op);
886         }
887         return rewriter.notifyMatchFailure(
888             op, "failed to get slice of result produced");
889       }
890       resultOffsets.emplace_back(std::move(resultOffset));
891       resultSizes.emplace_back(std::move(resultSize));
892     }
893 
894     return success();
895   };
896 
897   // 6. Find the destination tensors to use for the operation.
898   SmallVector<Value> destinationTensors;
899   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
900                                              destinationTensors))) {
901     return rewriter.notifyMatchFailure(op,
902                                        "unable to create destination tensors");
903   }
904 
905   // 7. Generate the tiled loops nest using the callback defined above.
906   SmallVector<LoopLikeOpInterface> loops;
907   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
908                               tileSizes, numThreads, destinationTensors,
909                               innerYieldTiledValuesFn, loops)))
910     return op.emitOpError("failed to generate tiling loops");
911   assert(succeeded(tilingResult) &&
912          "expected tiling result to be computed after loop generation");
913 
914   // If loops are empty, the tiled op is used as the replacement for the untiled
915   // op.
916   if (loops.empty()) {
917     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
918                                 tilingResult->tiledValues,
919                                 tilingResult->generatedSlices};
920   }
921 
922   SmallVector<Value> replacements = llvm::map_to_vector(
923       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
924   return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
925                               tilingResult->generatedSlices};
926 }
927 
928 FailureOr<scf::SCFReductionTilingResult>
929 mlir::scf::tileReductionUsingScf(RewriterBase &b,
930                                  PartialReductionOpInterface op,
931                                  ArrayRef<OpFoldResult> tileSizes) {
932   Location loc = op.getLoc();
933   // Ops implementing PartialReductionOpInterface are expected to implement
934   // TilingInterface.
935   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
936   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
937   auto tileSizesVector = llvm::to_vector(tileSizes);
938   if (tileSizesVector.size() < iterationDomain.size()) {
939     auto zero = b.getIndexAttr(0);
940     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
941                            zero);
942   }
943   SmallVector<utils::IteratorType> iterators =
944       tilingInterfaceOp.getLoopIteratorTypes();
945 
946   SmallVector<int> reductionDims;
947   for (auto [idx, iteratorType] :
948        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
949     if (iteratorType == utils::IteratorType::reduction)
950       reductionDims.push_back(idx);
951   }
952 
953   // 2. create the inital tensor value.
954   FailureOr<SmallVector<Value>> maybeInitTensors =
955       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
956                                                   reductionDims);
957   if (failed(maybeInitTensors)) {
958     return b.notifyMatchFailure(op, "Failed to create initial tensors.");
959   }
960   SmallVector<Value> &initTensors = maybeInitTensors.value();
961 
962   // 3. Define the callback to use for generating the inner most tile loop body.
963   SmallVector<Operation *> parallelTiledOps;
964   auto innerYieldTiledValuesFn =
965       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
966           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
967           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
968           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
969       -> LogicalResult {
970     SmallVector<OpFoldResult> offsets, sizes;
971     {
972       int materializedLoopNum = 0;
973       for (auto [tileSize, loopRange] :
974            llvm::zip_equal(tileSizesVector, iterationDomain)) {
975         if (isConstantIntValue(tileSize, 0)) {
976           offsets.push_back(loopRange.offset);
977           sizes.push_back(loopRange.size);
978           continue;
979         }
980         Value iv = ivs[materializedLoopNum++];
981         offsets.push_back(iv);
982         sizes.push_back(
983             getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
984       }
985     }
986 
987     // 4a. Clone the operation.
988     {
989       auto clonedOp = cast<PartialReductionOpInterface>(
990           cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
991 
992       // 4b. Tile the cloned operation.
993       FailureOr<TilingResult> partialTilingResult =
994           clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
995                                           sizes, reductionDims);
996       if (failed(partialTilingResult)) {
997         return failure();
998       }
999       std::swap(parallelTiledOps, partialTilingResult->tiledOps);
1000       std::swap(tiledResult, partialTilingResult->tiledValues);
1001 
1002       // 4c. Delete the cloned operation.
1003       b.eraseOp(clonedOp);
1004     }
1005 
1006     // 4d. Compute the offsets and sizes needed to insert the result of the
1007     // tiled value back into destination before yielding the destination.
1008     for (auto result : tiledResult) {
1009       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
1010       resultOffsets.emplace_back(std::move(outOffsets));
1011 
1012       SmallVector<OpFoldResult> outSizes;
1013       for (size_t i = 0; i < offsets.size(); i++) {
1014         outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
1015       }
1016       resultSizes.emplace_back(std::move(outSizes));
1017     }
1018     return success();
1019   };
1020 
1021   // 5. Generate the tiled implementation using the destination tensors.
1022   SmallVector<LoopLikeOpInterface> loops;
1023   scf::SCFTilingOptions options;
1024   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1025   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
1026                               /*numThreads=*/ArrayRef<OpFoldResult>{},
1027                               initTensors, innerYieldTiledValuesFn, loops)))
1028     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
1029 
1030   SmallVector<Value> replacements = llvm::map_to_vector(
1031       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
1032 
1033   // 5. Apply the merge reduction to combine all the partial values.
1034   b.setInsertionPointAfter(*loops.begin());
1035   FailureOr<MergeResult> mergeResult =
1036       op.mergeReductions(b, loc, replacements, reductionDims);
1037   if (failed(mergeResult)) {
1038     return failure();
1039   }
1040   b.replaceOp(op, mergeResult->replacements);
1041 
1042   SCFReductionTilingResult reductionTilingResult;
1043   std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
1044   std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
1045   std::swap(reductionTilingResult.initialValues, initTensors);
1046   std::swap(reductionTilingResult.loops, loops);
1047   std::swap(reductionTilingResult.replacements, mergeResult->replacements);
1048 
1049   return reductionTilingResult;
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // tileConsumerAndFuseProducersUsingSCF implementation.
1054 //===----------------------------------------------------------------------===//
1055 
1056 /// Return the untiled producer whose slice is used in a tiled consumer. The
1057 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1058 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
1059 /// indicates that this is a destination operand of the consumer. If there was
1060 /// no loop traversal needed, the second value of the returned tuple is empty.
1061 static std::tuple<OpResult, std::optional<OpOperand *>>
1062 getUntiledProducerFromSliceSource(OpOperand *source,
1063                                   ArrayRef<LoopLikeOpInterface> loops) {
1064   std::optional<OpOperand *> destinationIterArg;
1065   auto loopIt = loops.rbegin();
1066   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1067     auto loop = *loopIt;
1068     if (iterArg.getOwner()->getParentOp() != loop)
1069       break;
1070     source = loop.getTiedLoopInit(iterArg);
1071     loopIt++;
1072   }
1073   if (loopIt == loops.rend())
1074     destinationIterArg = source;
1075   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1076 }
1077 
1078 /// Implementation of fusing producer of a single slice by computing the
1079 /// slice of the producer in-place.
1080 std::optional<scf::SCFFuseProducerOfSliceResult>
1081 mlir::scf::tileAndFuseProducerOfSlice(
1082     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1083     MutableArrayRef<LoopLikeOpInterface> loops) {
1084   // 1. Get the producer of the source (potentially walking through
1085   // `iter_args` of nested `scf.for`)
1086   auto [fusableProducer, destinationInitArg] =
1087       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1088                                         loops);
1089   if (!fusableProducer)
1090     return std::nullopt;
1091   unsigned resultNumber = fusableProducer.getResultNumber();
1092 
1093   OpBuilder::InsertionGuard g(rewriter);
1094   rewriter.setInsertionPoint(candidateSliceOp);
1095 
1096   // 2. Clone the fused producer
1097   // 2a. Compute the destination operands to use for the cloned operation.
1098   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1099   Operation *fusableProducerOp = fusableProducer.getOwner();
1100   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1101       failed(tensor::getOrCreateDestinations(
1102           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1103           origDestinationTensors)))
1104     return std::nullopt;
1105 
1106   clonedOpDestinationTensors = origDestinationTensors;
1107   if (destinationInitArg &&
1108       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1109     // 2b. If the producer is also destination style, then to maintain the
1110     // destination passing style, update the destination of the producer to be
1111     // the source of the slice.
1112     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1113   }
1114   // 2c. Clone the fused producer.
1115   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1116       rewriter, fusableProducerOp, clonedOpDestinationTensors);
1117   // 2d. Update the source of the candidateSlice to be the cloned producer.
1118   //     Easier to just clone the slice with different source since replacements
1119   //     and DCE of cloned ops becomes easier
1120   SmallVector<Value> candidateSliceOpOperands =
1121       llvm::to_vector(candidateSliceOp->getOperands());
1122   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1123   tensor::ExtractSliceOp clonedCandidateSliceOp =
1124       mlir::clone(rewriter, candidateSliceOp,
1125                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1126 
1127   // 3. Generate the tiled implementation of the producer of the source
1128   FailureOr<TilingResult> tileAndFuseResult =
1129       tensor::replaceExtractSliceWithTiledProducer(
1130           rewriter, clonedCandidateSliceOp,
1131           clonedProducerOp->getResult(resultNumber));
1132   if (failed(tileAndFuseResult))
1133     return std::nullopt;
1134   // Note: Do not delete the candidateSliceOp, since its passed in from the
1135   // caller.
1136   rewriter.replaceAllUsesWith(candidateSliceOp,
1137                               tileAndFuseResult->tiledValues[0]);
1138   rewriter.eraseOp(clonedCandidateSliceOp);
1139   rewriter.eraseOp(clonedProducerOp);
1140 
1141   // 3. If the slice is for a destination operand, for example,
1142   //
1143   // ```mlir
1144   // %0 = linalg.init
1145   // %1 = linalg.fill .. outs(%0 : )
1146   // %2 = scf.for .. iter_args(%arg0 = %1) {
1147   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1148   //     %4 = tensor.extract_slice %arg1 [..]
1149   //     .. = linalg.matmul .. outs(%4 : )
1150   //   }
1151   // }
1152   // ```
1153   //
1154   // the IR is currently
1155   //
1156   // ```
1157   // %0 = linalg.init
1158   // %1 = linalg.fill
1159   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1160   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1161   //     %4 = tensor.extract_slice %arg1[..]
1162   //     %5 = linalg.fill .. outs(%4 : )
1163   //     .. = linalg.matmul .. outs(%5 : )
1164   //   }
1165   // }
1166   // ```
1167   //
1168   // The untiled `linalg.fill` is still used as the `init_value` since it
1169   // was originally a destination operand of the untiled `linalg.matmul`.
1170   // When fusing an operand that is a destination operand, the iter_arg of
1171   // the outer most loop should be changed to use the destination of the
1172   // fused operation. With this the IR will be.
1173   //
1174   // ```
1175   // %0 = linalg.init
1176   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1177   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
1178   //     %3 = tensor.extract_slice %arg1[..]
1179   //     %4 = linalg.fill .. outs(%3 : )
1180   //     .. = linalg.matmul .. outs(%4 : )
1181   //   }
1182   // }
1183   // ```
1184   if (destinationInitArg &&
1185       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1186     loops.front()
1187         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1188         .set(origDestinationTensors[resultNumber]);
1189   }
1190   return scf::SCFFuseProducerOfSliceResult{
1191       fusableProducer, tileAndFuseResult->tiledValues[0],
1192       tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1193 }
1194 
1195 /// Reconstruct the fused producer from within the tiled-and-fused code.
1196 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1197     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1198     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1199     MutableArrayRef<LoopLikeOpInterface> loops,
1200     ArrayRef<unsigned> yieldResultNumber) {
1201   if (loops.empty())
1202     return success();
1203 
1204   Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1205             *tiledOwner = fusedProducerInfo.tiledOps[0];
1206 
1207   Location loc = originalOwner->getLoc();
1208   // a. collect all init Value to be appended
1209   SmallVector<unsigned> initNumberList =
1210       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1211                                       0, originalOwner->getNumResults()))
1212                                 : llvm::to_vector(yieldResultNumber);
1213   SmallVector<Value> initValueList;
1214   for (const auto &resultNumber : initNumberList) {
1215     FailureOr<Value> initValue = tensor::getOrCreateDestination(
1216         rewriter, loc, originalOwner->getResult(resultNumber));
1217     if (succeeded(initValue)) {
1218       initValueList.push_back(initValue.value());
1219     } else {
1220       return failure();
1221     }
1222   }
1223 
1224   SmallVector<Operation *> generatedSlices;
1225   YieldTiledValuesFn newYieldValuesFn =
1226       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1227           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1228           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1229           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1230     OpBuilder::InsertionGuard g(innerRewriter);
1231 
1232     // get sliceOp tile information
1233     SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1234                               sliceSizes = sliceOp.getMixedSizes();
1235 
1236     // expect all strides of sliceOp being 1
1237     if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1238           return !isConstantIntValue(ofr, 1);
1239         }))
1240       return failure();
1241 
1242     unsigned sliceResultNumber =
1243         fusedProducerInfo.origProducer.getResultNumber();
1244 
1245     auto tilableOp = cast<TilingInterface>(originalOwner);
1246     // b. get iterDomain Offset and Sizes based on sliceOp tile
1247     SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1248     // skip tensor.pack/unpack/pad, which expects single opResult
1249     if (tilableOp->getNumResults() > 1 &&
1250         failed(tilableOp.getIterationDomainTileFromResultTile(
1251             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1252             iterDomainOffset, iterDomainSizes))) {
1253       // In theory, it is unnecessary to raise an error here. Actually although
1254       // it fails to reconstruct the result tensor, it should not broke current
1255       // fusion anyway. The reason why we must return failure currently is that
1256       // the callback function `newYieldValuesFn` will be called after new init
1257       // operand(s) has already been appended. It will take more refactoring to
1258       // make sure the init operands are added consistently in the future. For
1259       // more details, please refer to:
1260       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1261       return failure();
1262     }
1263 
1264     // c. calculate offsets and sizes info of all OpResults respectively based
1265     // on iteration Domain Tile
1266     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1267     for (const auto &resultNumber : initNumberList) {
1268       if (resultNumber == sliceResultNumber) {
1269         offsetList.push_back(sliceOffset);
1270         sizesList.push_back(sliceSizes);
1271       } else {
1272         assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1273         // infer result tile according to the iteration domain tile
1274         SmallVector<OpFoldResult> offset, sizes;
1275         if (failed(tilableOp.getResultTilePosition(
1276                 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1277                 offset, sizes))) {
1278           return failure();
1279         }
1280         offsetList.push_back(offset);
1281         sizesList.push_back(sizes);
1282       }
1283     }
1284 
1285     // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1286     if (auto tiledDestStyleOp =
1287             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1288       rewriter.setInsertionPoint(tiledDestStyleOp);
1289       for (const auto &&[index, newRegionArg] :
1290            llvm::enumerate(newRegionIterArgs)) {
1291         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1292             loc, newRegionArg, offsetList[index], sizesList[index],
1293             SmallVector<OpFoldResult>(offsetList[index].size(),
1294                                       rewriter.getIndexAttr(1)));
1295         generatedSlices.push_back(destSlice);
1296         unsigned resultNumber = initNumberList[index];
1297         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1298           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1299         });
1300       }
1301     }
1302 
1303     // e. prepare tiled offset and sizes for later `insert_slice` creation by
1304     // caller
1305     Block *block = rewriter.getInsertionPoint()->getBlock();
1306     rewriter.setInsertionPoint(block->getTerminator());
1307     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1308       tiledResult.push_back(tiledOwner->getResult(resultNumber));
1309       tiledOffset.emplace_back(offsetList[index]);
1310       tiledSizes.emplace_back(sizesList[index]);
1311     }
1312     return success();
1313   };
1314 
1315   if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1316                                        newYieldValuesFn))) {
1317     return failure();
1318   }
1319   return generatedSlices;
1320 }
1321 
1322 namespace {
1323 
1324 //===----------------------------------------------------------------------===//
1325 // SliceTrackingListener
1326 //===----------------------------------------------------------------------===//
1327 
1328 /// This class is a listener for tracking the insertion and removal of
1329 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1330 /// fusion algorithm to apply cleanup patterns in between fusion steps.
1331 class SliceTrackingListener : public RewriterBase::Listener {
1332 public:
1333   explicit SliceTrackingListener(
1334       std::optional<FrozenRewritePatternSet> patterns);
1335   SliceTrackingListener() = default;
1336 
1337   /// Adds the given list of operations to the worklist, and if present, applies
1338   /// the list of `patterns` to the newly added operations. This only processes
1339   /// the given operations and any newly inserted ones by the pattern set.
1340   LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1341 
1342   /// Add to the new operation worklist if it is an extract_slice.
1343   void notifyOperationInserted(Operation *op,
1344                                OpBuilder::InsertPoint previous) override;
1345 
1346   /// Shared helper for operation removal from the worklist.
1347   void removeOp(Operation *op);
1348 
1349   /// Remove the operation from the worklist.
1350   void notifyOperationErased(Operation *op) override;
1351 
1352   /// Remove the operation from the worklist.
1353   void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1354 
1355   /// The worklist for this transformation keeps track of the slices to visit
1356   /// next for fusion.
1357   std::deque<tensor::ExtractSliceOp> worklist;
1358 
1359 private:
1360   /// Optional pattern set to apply when adding new operations to the worklist.
1361   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1362 };
1363 
1364 SliceTrackingListener::SliceTrackingListener(
1365     std::optional<FrozenRewritePatternSet> p) {
1366   patterns = std::move(p);
1367 }
1368 
1369 LogicalResult
1370 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1371   for (Operation *op : ops) {
1372     if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1373       worklist.push_back(slice);
1374   }
1375 
1376   if (!patterns)
1377     return success();
1378 
1379   GreedyRewriteConfig config;
1380   config.listener = this;
1381   config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1382   return applyOpPatternsAndFold(ops, patterns.value(), config);
1383 }
1384 
1385 void SliceTrackingListener::notifyOperationInserted(
1386     Operation *op, OpBuilder::InsertPoint previous) {
1387   auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1388   if (!slice)
1389     return;
1390   worklist.push_back(slice);
1391 }
1392 
1393 // Scan the worklist for the given op and remove it if present. The expectation
1394 // is for the worklist to be small and for removal to be relatively rare.
1395 void SliceTrackingListener::removeOp(Operation *op) {
1396   if (!isa<tensor::ExtractSliceOp>(op))
1397     return;
1398   auto iter = worklist.begin();
1399   while (iter != worklist.end()) {
1400     if (*iter == op)
1401       break;
1402     iter++;
1403   }
1404   if (iter == worklist.end())
1405     return;
1406 
1407   worklist.erase(iter);
1408 }
1409 
/// Listener hook for erased operations: forwards to `removeOp` so that the
/// erased operation is no longer tracked in the worklist.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1413 
/// Listener hook for replaced operations: forwards to `removeOp` so that the
/// replaced operation is no longer tracked in the worklist. The `replacement`
/// values are not inspected.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1418 } // namespace
1419 
1420 /// Implementation of tile consumer and fuse producer greedily.
1421 FailureOr<scf::SCFTileAndFuseResult>
1422 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1423     RewriterBase &rewriter, TilingInterface consumer,
1424     const scf::SCFTileAndFuseOptions &options) {
1425   // This transformation is only valid for ops that return values (i.e. not
1426   // valid to use with operations that have memref operands).
1427   if (!consumer->getNumResults()) {
1428     return rewriter.notifyMatchFailure(
1429         consumer, "invalid pattern for op with no results");
1430   }
1431 
1432   // 1. First tile the consumer.
1433   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1434   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1435 
1436   FailureOr<scf::SCFTilingResult> tilingResult =
1437       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1438 
1439   if (failed(tilingResult))
1440     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1441   for (auto *tiledOp : tilingResult->tiledOps)
1442     tiledAndFusedOps.insert(tiledOp);
1443 
1444   // If there are no loops generated, fusion is immaterial.
1445   auto &loops = tilingResult->loops;
1446   if (loops.empty()) {
1447     DenseMap<Value, Value> replacements;
1448     for (auto [origVal, replacement] :
1449          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1450       replacements[origVal] = replacement;
1451     }
1452     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1453                                      replacements};
1454   }
1455 
1456   // To keep track of replacements for now just record the map from the original
1457   // untiled value to the result number of the for loop. Since the loop gets
1458   // potentially replaced during fusion, keeping the value directly wont work.
1459   DenseMap<Value, size_t> origValToResultNumber;
1460   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1461     origValToResultNumber[result] = index;
1462   }
1463 
1464   // 2. Typically, the operands of the tiled operation are slices of the
1465   //    operands of the untiled operation. These are expressed in IR using
1466   //    `tensor.extract_slice` operations with source being the operands of the
1467   //    untiled operation. Create a worklist of these `tensor.extract_slice`
1468   //    operations. If the producers of the source of the `tensor.extract_slice`
1469   //    can be tiled such that the tiled value is generated in-place, that
1470   //    effectively tiles + fuses the operations.
1471   struct WorklistItem {
1472     tensor::ExtractSliceOp candidateSlice;
1473     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1474   };
1475 
1476   SliceTrackingListener sliceTracker =
1477       SliceTrackingListener(options.cleanupPatterns);
1478 
1479   if (failed(
1480           sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1481     return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1482   }
1483   OpBuilder::InsertionGuard g(rewriter);
1484   while (!sliceTracker.worklist.empty()) {
1485     auto candidateSlice = sliceTracker.worklist.front();
1486     sliceTracker.worklist.pop_front();
1487 
1488     auto [fusableProducer, destinationInitArg] =
1489         getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1490                                           loops);
1491     if (!fusableProducer)
1492       continue;
1493 
1494     std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1495         options.fusionControlFn(candidateSlice, fusableProducer,
1496                                 destinationInitArg.has_value());
1497     if (!controlFnResult)
1498       continue;
1499 
1500     WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1501 
1502     // The operands of the fused producer might themselved be slices of
1503     // values produced by operations that implement the `TilingInterface`.
1504     // Add these operations to the worklist.
1505     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1506         tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1507                                    loops);
1508     if (!fusedResult)
1509       continue;
1510 
1511     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1512 
1513     if (worklistItem.controlFnResult.yieldProducerReplacement) {
1514       // Reconstruct and yield all opResult of fusableProducerOp by default. The
1515       // caller can specific which one to yield by designating optional argument
1516       // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1517       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1518       FailureOr<SmallVector<Operation *>> newSlices =
1519           yieldReplacementForFusedProducer(rewriter,
1520                                            worklistItem.candidateSlice,
1521                                            fusedResult.value(), loops);
1522       if (failed(newSlices)) {
1523         return rewriter.notifyMatchFailure(
1524             fusableProducerOp, "failed to replacement value for this "
1525                                "operation from within the tiled loop");
1526       }
1527       worklistCandidates.append(newSlices.value());
1528       for (auto [index, result] :
1529            llvm::enumerate(fusableProducerOp->getResults())) {
1530         origValToResultNumber[result] = loops.front()->getNumResults() -
1531                                         fusableProducerOp->getNumResults() +
1532                                         index;
1533       }
1534     }
1535     if (Operation *tiledAndFusedOp =
1536             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1537       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1538       tiledAndFusedOps.insert(tiledAndFusedOp);
1539     }
1540 
1541     if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1542       return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1543     }
1544   }
1545 
1546   DenseMap<Value, Value> replacements;
1547   for (auto [origVal, resultNumber] : origValToResultNumber) {
1548     replacements[origVal] = loops.front()->getResult(resultNumber);
1549   }
1550 
1551   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1552                                    replacements};
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // tileAndFuseConsumerUsingSCF implementation.
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// A utility function that checks whether the only use of the result of a
1560 /// tensor.insert_slice op is in a scf.yield op.
1561 static LogicalResult
1562 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1563   Value result = candidateSliceOp.getResult();
1564   Value::use_range uses = result.getUses();
1565   if (!llvm::hasSingleElement(uses)) {
1566     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1567     return failure();
1568   }
1569   OpOperand &operandUse = (*uses.begin());
1570   Operation *userOp = operandUse.getOwner();
1571   if (!isa<scf::YieldOp>(userOp)) {
1572     LLVM_DEBUG(llvm::dbgs()
1573                << "Expected scf.yield to be the only user, but got -> "
1574                << (*userOp));
1575     return failure();
1576   }
1577   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1578     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1579                                "be in the same block\n");
1580     return failure();
1581   }
1582   return success();
1583 }
1584 
1585 /// An utility to get the first user of the given loopOp. If any of user stay in
1586 /// different block of loopOp, return failure.
1587 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1588   if (!isa<LoopLikeOpInterface>(loopOp))
1589     return failure();
1590   Operation *firstUserOfLoop = nullptr;
1591   for (Operation *userOp : loopOp->getUsers()) {
1592     // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1593     // block with any other types of operation. Thus, just redirecting to its
1594     // parent `InParallelOp`. E.g.
1595     //
1596     // ```
1597     // %1 = scf.for {
1598     //   ...
1599     // }
1600     // %2 = consumerOp ins(%1, ...)
1601     // scf.forall.in_parallel {
1602     //    tensor.parallel_insert_slice %1
1603     // }
1604     // ```
1605     // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1606     // same block with `consumerOp`.
1607     if (isa<tensor::ParallelInsertSliceOp>(userOp))
1608       userOp = userOp->getParentOfType<scf::InParallelOp>();
1609 
1610     if (loopOp->getBlock() != userOp->getBlock())
1611       return failure();
1612 
1613     if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1614       firstUserOfLoop = userOp;
1615   }
1616   return firstUserOfLoop;
1617 }
1618 
1619 /// This utility currently checks whether the first userOp of loop is NOT before
1620 /// the last defineOp of consumer operand. Because that we need to move the
1621 /// whole loop structure right before the `firstUserOfLoop`. This utility thus
1622 /// helps ensuring that no invalid IR is formed, i.e. no backward slice of
1623 /// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1624 ///
1625 /// ```
1626 /// %0 = scf.for() {
1627 ///   ...
1628 /// }
1629 /// ...
1630 /// %1 = firstUserOfLoop(%0)
1631 /// ...
1632 /// %2 = lastDefOfConsumerOperand
1633 /// ...
1634 /// %3 = consumerOp(%2)
1635 /// ```
1636 ///
1637 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
1638 /// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
1639 /// use-def chain violation:
1640 ///
1641 /// ```
1642 /// %0:2 = scf.for() {
1643 ///    // use before define error
1644 ///    %3 = tiledConsumerOp(%2)
1645 /// }
1646 /// %1 = firstUserOfLoop(%0)
1647 /// ...
1648 /// %2 = lastDefOfConsumerOperand
1649 /// ```
1650 ///
1651 /// @param loopOp: loop operation
1652 /// @param consumerOp: consumer operation
1653 /// @param reorderOperations: the flag controls whether to reorder the backward
1654 /// slice w.r.t. the defineOp of `consumerOp` operands.
1655 /// @return: computed backward slice of consumerOp, but excluding those already
1656 /// dominates `firstUserOfLoop`.
1657 static FailureOr<llvm::SetVector<Operation *>>
1658 checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
1659                        bool reorderOperations) {
1660   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1661   if (failed(firstUserOfLoop))
1662     return failure();
1663 
1664   BackwardSliceOptions options;
1665   DominanceInfo dominanceInfo;
1666   options.inclusive = true;
1667   options.omitBlockArguments = true;
1668   bool includeLoopOp = false;
1669   options.filter = [&](Operation *op) {
1670     if (op == loopOp) {
1671       includeLoopOp = true;
1672       return false;
1673     }
1674     // Cut off the slice to not include any operation that already dominates
1675     // firstUserOfLoop.
1676     return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1677   };
1678   llvm::SetVector<Operation *> slice;
1679   for (auto operand : consumerOp->getOperands()) {
1680     getBackwardSlice(operand, &slice, options);
1681   }
1682 
1683   if (!slice.empty()) {
1684     // If consumerOp has one producer, which is also the user of loopOp.
1685     // E.g.
1686     // ```
1687     //  %0 = %loopOp
1688     //  %1 = consumerOp1 ins(%0)
1689     //  %2 = consumerOp2 ins(%0, %1)
1690     // ```
1691     // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1692     // consumerOp1 has already been fused into loopOp before.
1693     if (includeLoopOp || !reorderOperations)
1694       return failure();
1695   }
1696 
1697   return slice;
1698 }
1699 
1700 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1701 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1702 /// Returns failure otherwise.
1703 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1704                                                       Operation *loopOp,
1705                                                       unsigned resultNumber) {
1706   if (!isa<LoopLikeOpInterface>(loopOp))
1707     return failure();
1708   Value val = loopOp->getResult(resultNumber);
1709   Block *loopBlock = loopOp->getBlock();
1710   for (OpOperand &opOperand : val.getUses()) {
1711     Operation *consumerOp = opOperand.getOwner();
1712     // Step 1. Check if the user is tilable.
1713     if (!isa<TilingInterface>(consumerOp) ||
1714         !isa<DestinationStyleOpInterface>(consumerOp)) {
1715       // TODO: We have to init result of consumer before scf.for, use
1716       // DestinationStyleOpInterface to get result shape from init for now. Add
1717       // support for other op such as op has InferTypeOpInterface.
1718       continue;
1719     }
1720     // Step 2. Check if user stay in the same block.
1721     if (loopBlock != consumerOp->getBlock())
1722       continue;
1723     // Step 3. Check if user has succeeding user. Otherwise, it usually
1724     // represents already tiled.
1725     if (consumerOp->use_empty())
1726       continue;
1727     // Step 4. Check assumption for loop with `reorderOperations` enabled.
1728     FailureOr<llvm::SetVector<Operation *>> slice =
1729         checkAssumptionForLoop(loopOp, consumerOp, true);
1730     if (failed(slice))
1731       continue;
1732     // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
1733     if (!slice->empty()) {
1734       mlir::topologicalSort(*slice);
1735       FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1736       assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1737       for (auto op : *slice) {
1738         rewriter.moveOpBefore(op, *firstUserOfLoop);
1739       }
1740     }
1741     return &opOperand;
1742   }
1743   return failure();
1744 }
1745 
1746 /// Find the perfectly nested loops outside of given loop(included) sorted from
1747 /// outer to inner.
1748 ///
1749 /// E.g.
1750 ///
1751 /// ```
1752 ///  %0 = scf.for()
1753 ///    %1 = scf.for()
1754 ///      %2 = scf.for()
1755 ///         %3 = ...
1756 ///         yield %3
1757 ///      yield %2
1758 ///    yield %1
1759 /// ```
1760 ///
1761 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1762 /// target inner loop is %2.
1763 static SmallVector<scf::ForOp>
1764 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1765   SmallVector<scf::ForOp> nestLoops = {loop};
1766   auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1767 
1768   // Check if it is the ForOp that yield the result of inner loop.
1769   auto isForOpYieldResultOfInnerLoop =
1770       [](scf::ForOp outerLoop) -> LogicalResult {
1771     Block *body = outerLoop.getBody();
1772     if (!llvm::hasSingleElement(body->without_terminator()))
1773       return failure();
1774     auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1775     auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1776     if (!innerForOp)
1777       return failure();
1778     // All of innerForOp results should be yielded.
1779     return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1780   };
1781 
1782   while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1783     nestLoops.push_back(outerLoop);
1784     outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1785   }
1786   // sorted from outer to inner
1787   return {nestLoops.rbegin(), nestLoops.rend()};
1788 }
1789 
1790 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1791 /// tensor.insert_slice. This function makes the following assumptions :
1792 /// 1.  tensor.insert_slice has scf.yield as its only user.
1793 /// 2.  scf.for's corresponding result has only one use.
1794 static FailureOr<OpOperand *>
1795 getUntiledConsumerFromSlice(RewriterBase &rewriter,
1796                             tensor::InsertSliceOp candidateSliceOp) {
1797   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1798     return failure();
1799   Value sliceResult = candidateSliceOp.getResult();
1800   // Step 1. Fetch the corresponding output.
1801   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1802   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1803   // Step 2. Check containing op is scf.for.
1804   Operation *containingOp = candidateSliceOp->getParentOp();
1805   auto forOp = dyn_cast<scf::ForOp>(containingOp);
1806   if (!forOp)
1807     return failure();
1808   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1809 
1810   return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1811 }
1812 
1813 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1814 /// by a tensor.parallel_insert_slice.
1815 static FailureOr<OpOperand *>
1816 getUntiledConsumerFromSlice(RewriterBase &rewriter,
1817                             tensor::ParallelInsertSliceOp candidateSliceOp) {
1818   // Step 1. Fetch the corresponding output
1819   Value sliceDest = candidateSliceOp.getDest();
1820   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1821   if (!iterArg)
1822     return failure();
1823   Operation *containingOp = iterArg.getOwner()->getParentOp();
1824   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1825     return failure();
1826   // Step 2. Check that the containing op is scf.forall.
1827   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1828   if (!forallOp)
1829     return failure();
1830   unsigned resultNumber =
1831       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1832           .getResultNumber();
1833 
1834   return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1835 }
1836 
1837 /// A utility to fetch an untiled consumer of
1838 /// tensor.insert_slice/tensor.parallel_insert_slice.
1839 static FailureOr<OpOperand *>
1840 getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
1841   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1842     return getUntiledConsumerFromSlice(rewriter, insertSlice);
1843   } else if (auto parallelInsertSlice =
1844                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1845     return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1846   } else {
1847     return failure();
1848   }
1849 }
1850 
1851 /// Implementation of fusing consumer of a single slice by computing the
1852 /// slice of the consumer in-place for scf loop.
1853 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1854 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1855                                       Operation *candidateSliceOp) {
1856   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1857           candidateSliceOp))
1858     return failure();
1859 
1860   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1861 
1862   // 1. Get the consumer of scf.for for the result yielded by
1863   // tensor.insert_slice/parallel_insert_slice.
1864   FailureOr<OpOperand *> maybeConsumerOpOperand =
1865       getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
1866   if (failed(maybeConsumerOpOperand)) {
1867     return rewriter.notifyMatchFailure(candidateSliceOp,
1868                                        "could not fetch consumer to fuse");
1869   }
1870   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1871   Operation *consumerOp = consumerOpOperand->getOwner();
1872   unsigned operandNumber = consumerOpOperand->getOperandNumber();
1873   unsigned resultNumber = 0;
1874   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1875     resultNumber = producerResult.getResultNumber();
1876   } else {
1877     return rewriter.notifyMatchFailure(
1878         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1879   }
1880 
1881   // There are two possible cases regarding `oldLoopOp` here:
1882   // 1. single `scf.forall` or `scf.for`.
1883   // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1884   // top-level loop is the outer-most one of these nested loops.
1885   LoopLikeOpInterface innerMostLoop =
1886       candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1887   SmallVector<LoopLikeOpInterface> nestedLoops;
1888   if (isInsertSliceOp) {
1889     nestedLoops = llvm::map_to_vector(
1890         getPerfectlyNestedLoopsOutsideOf(
1891             cast<scf::ForOp>(innerMostLoop.getOperation())),
1892         [](scf::ForOp forOp) {
1893           return cast<LoopLikeOpInterface>(forOp.getOperation());
1894         });
1895   } else {
1896     nestedLoops = {innerMostLoop};
1897   }
1898 
1899   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1900 
1901   // Check assumption for loop with `reorderOperations` disabled.
1902   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
1903     return rewriter.notifyMatchFailure(
1904         outerMostLoop, "the first user of loop should not dominate any define "
1905                        "of consumer operand(s)");
1906   }
1907 
1908   OpBuilder::InsertionGuard g(rewriter);
1909 
1910   // 2. Check consumer is not using scf loop's output as init.
1911   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1912   if (!dstOp)
1913     return rewriter.notifyMatchFailure(consumerOp,
1914                                        "consumer op is not DPS operation");
1915   SmallVector<Value> dpsInits =
1916       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1917   if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1918     return rewriter.notifyMatchFailure(
1919         consumerOp,
1920         "consumer op taking the result of scf.for as init is not supported");
1921   }
1922   SmallVector<Value> newInits = dpsInits;
1923 
1924   Location loc = outerMostLoop->getLoc();
1925 
1926   // 3. Move the whole loop structure right before firstUserOfLoop, the
1927   // dominance should be already ensured by `checkAssumptionForLoop`.
1928   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
1929   if (failed(firstUserOfLoop)) {
1930     return rewriter.notifyMatchFailure(
1931         outerMostLoop, "could not find the first user of outer most loop");
1932   }
1933   rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
1934 
1935   // 4. Set insertion point before terminator op of the loop and create a new
1936   // tensor.insert_slice. In the scf.for case this is a clone of the
1937   // candidateSliceOp whereas in the scf.forall case this is created from the
1938   // operands of tensor.parallel_insert_slice.
1939   tensor::InsertSliceOp clonedInsertSliceOp;
1940   if (auto sliceOp =
1941           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1942     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1943     rewriter.setInsertionPoint(newForallOp.getTerminator());
1944     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1945         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1946         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1947   } else {
1948     rewriter.setInsertionPoint(candidateSliceOp);
1949     clonedInsertSliceOp =
1950         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1951   }
1952 
1953   // 5.a. Clone consumer op.
1954   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
1955 
1956   // 5.b. Replace all uses of the loop result with the result of the cloned
1957   // tensor.insert_slice.
1958   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1959   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1960     operandToReplace.set(clonedInsertSliceOp.getResult());
1961   });
1962 
1963   // 6. Perform tiling of the cloned consumer and replace the operand at
1964   // `operandNumber` with the source of the cloned tensor.insert_slice op.
1965   auto ossSliceOp =
1966       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1967   FailureOr<TilingResult> tileAndFuseResult =
1968       tensor::replaceInsertSliceWithTiledConsumer(
1969           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1970   if (failed(tileAndFuseResult)) {
1971     return failure();
1972   }
1973   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1974   rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
1975                               clonedInsertSliceOp.getSource());
1976 
1977   // 7. Reconstruct [nested] loop with new inits.
1978   YieldTiledValuesFn newYieldValuesFn =
1979       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1980           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1981           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1982           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1983     OpBuilder::InsertionGuard g(innerRewriter);
1984     // 8. Set inner insertPoint right before tiled consumer op.
1985     innerRewriter.setInsertionPoint(tiledConsumerOp);
1986 
1987     SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1988     SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1989     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1990 
1991     // 9. Check all insert stride is 1.
1992     if (llvm::any_of(strides, [](OpFoldResult stride) {
1993           return !isConstantIntValue(stride, 1);
1994         })) {
1995       return rewriter.notifyMatchFailure(
1996           candidateSliceOp, "containingOp's result yield with stride");
1997     }
1998 
1999     // 10. Try to get iter domain position from input position. Use
2000     // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
2001     // may require index computation based on the result size. The sizes and
2002     // offsets should be the same either way, but using tiledConsumerOp could
2003     // lead to some chained unnecessary extra index computation.
2004     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2005     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2006             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2007             iterDomainSizes))) {
2008       return rewriter.notifyMatchFailure(
2009           clonedConsumerOp,
2010           "can't get iter domain position from input position");
2011     }
2012 
2013     // 11. Try to fetch the offset and size for all results of the cloned
2014     // consumer. This would then be used to form the corresponding
2015     // tensor.insert_slice/parallel_insert_slice later.
2016     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2017     SmallVector<SmallVector<OpFoldResult>> resultOffsets(
2018         totalNumResultsOfConsumer);
2019     SmallVector<SmallVector<OpFoldResult>> resultSizes(
2020         totalNumResultsOfConsumer);
2021     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2022       if (failed(tiledConsumerOp.getResultTilePosition(
2023               rewriter, idx, iterDomainOffsets, iterDomainSizes,
2024               resultOffsets[idx], resultSizes[idx]))) {
2025         return rewriter.notifyMatchFailure(
2026             tiledConsumerOp,
2027             "can't get result domain position from iter domain position");
2028       }
2029     }
2030 
2031     // 12. Create `extract_slice` for `iter_args` for DPS operation if
2032     // necessary.
2033     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2034             tiledConsumerOp.getOperation())) {
2035       rewriter.setInsertionPoint(tiledDestStyleOp);
2036       for (const auto &&[index, newRegionArg] :
2037            llvm::enumerate(newRegionIterArgs)) {
2038         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2039             loc, newRegionArg, resultOffsets[index], resultSizes[index],
2040             SmallVector<OpFoldResult>(resultOffsets[index].size(),
2041                                       rewriter.getIndexAttr(1)));
2042         // Make a copy of index to avoid a capturing structured binding, which
2043         // is a C++20 extension.
2044         auto dstNumber = index;
2045         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2046           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2047         });
2048       }
2049     }
2050 
2051     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2052     // caller.
2053     Block *block = rewriter.getInsertionPoint()->getBlock();
2054     rewriter.setInsertionPoint(block->getTerminator());
2055     for (const auto &&[index, result] :
2056          llvm::enumerate(tiledConsumerOp->getResults())) {
2057       tiledResult.push_back(result);
2058       tiledOffset.emplace_back(resultOffsets[index]);
2059       tiledSizes.emplace_back(resultSizes[index]);
2060     }
2061     return success();
2062   };
2063   // 14. Add new inits to [nested] loops.
2064   if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2065                                        newYieldValuesFn))) {
2066     return rewriter.notifyMatchFailure(tiledConsumerOp,
2067                                        "unable to add new inits to nest loop");
2068   }
2069 
2070   // 15. Replace the result of scf loop and consumer op with new loop's results.
2071 
2072   for (auto &&[oldResult, newResult] : llvm::zip(
2073            consumerOp->getResults(),
2074            nestedLoops.front()->getResults().take_back(newInits.size()))) {
2075     rewriter.replaceAllUsesWith(oldResult, newResult);
2076   }
2077 
2078   // 16. Need to erase the old scf loop and the cloned consumer op.
2079   rewriter.eraseOp(clonedConsumerOp);
2080 
2081   return scf::SCFFuseConsumerOfSliceResult{
2082       consumerOpOperand,
2083       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2084       tileAndFuseResult->tiledOps};
2085 }
2086 
2087 //===----------------------------------------------------------------------===//
2088 // lowerToLoopsUsingSCFForOp implementation.
2089 //===----------------------------------------------------------------------===//
2090 
2091 FailureOr<SmallVector<scf::ForOp>>
2092 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2093                                      TilingInterface op) {
2094   // TODO: Handle cases where the op has results if needed.
2095   if (op->getNumResults() > 0) {
2096     return rewriter.notifyMatchFailure(
2097         op, "unable to lower to loops operations with return values");
2098   }
2099 
2100   SmallVector<Range> domain = op.getIterationDomain(rewriter);
2101   SmallVector<Value> ivs;
2102   SmallVector<scf::ForOp> loops;
2103   Location loc = op.getLoc();
2104   for (auto loopRange : domain) {
2105     Value offsetVal =
2106         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2107     Value sizeVal =
2108         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2109     Value strideVal =
2110         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2111     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2112                                             strideVal, ValueRange{});
2113     loops.push_back(loop);
2114     ivs.push_back(loop.getInductionVar());
2115     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2116   }
2117   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2118     return failure();
2119   }
2120   return loops;
2121 }
2122