xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements tiling using the TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Analysis/TopologicalSortUtils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Arith/Utils/Utils.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/SCF/Utils/Utils.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
28 #include "mlir/Interfaces/TilingInterface.h"
29 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include <optional>
35 
36 #define DEBUG_TYPE "tile-using-interface"
37 
38 using namespace mlir;
39 
40 scf::SCFTilingOptions &
41 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
42   assert(!tileSizeComputationFunction && "tile sizes already set");
43   auto tileSizes = llvm::to_vector(ts);
44   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45     return tileSizes;
46   };
47   return *this;
48 }
49 
50 scf::SCFTilingOptions &
51 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
52   assert(!numThreadsComputationFunction && "number of threads already set");
53   auto numThreads = llvm::to_vector(nt);
54   numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55     return numThreads;
56   };
57   return *this;
58 }
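
// Illustrative usage of the two setters above (a sketch, not part of the
// original file); `rewriter` and `tilingInterfaceOp` are assumed to be
// available at the call site:
//
//   scf::SCFTilingOptions options;
//   options.setTileSizes(
//       {rewriter.getIndexAttr(32), rewriter.getIndexAttr(64)});
//   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
//   FailureOr<scf::SCFTilingResult> tiled =
//       scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);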
59 
60 /// Helper method to adjust the interchange vector to match the iteration
61 /// domain.
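/// For illustration (hypothetical values, not taken from a caller in this
/// file): with `interchangeVector = [2, 0, 1]` and `iterationDomainSize = 5`,
/// the filled vector is `[2, 0, 1, 3, 4]`.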
62 static SmallVector<int64_t>
63 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
64                       size_t iterationDomainSize) {
65   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
66   if (filledVector.size() < iterationDomainSize) {
67     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68     filledVector.append(range.begin(), range.end());
69   }
70   if (filledVector.size() > iterationDomainSize)
71     filledVector.resize(iterationDomainSize);
72   return filledVector;
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // tileUsingSCF implementation.
77 //===----------------------------------------------------------------------===//
78 
79 /// Verify the tile size options are set in a consistent manner.
80 static LogicalResult
81 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
82                       const scf::SCFTilingOptions &options) {
83   // Specifying the number of threads is only supported with `scf.forall` loops.
84   if (options.numThreadsComputationFunction &&
85       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
86     return rewriter.notifyMatchFailure(
87         loc, "number of threads can only be specified when loop type is "
88              "set to use `scf.forall`");
89   }
90 
91   // If specified, check that the interchange vector is a permutation.
92   if (!options.interchangeVector.empty()) {
93     if (!isPermutationVector(options.interchangeVector)) {
94       return rewriter.notifyMatchFailure(
95           loc, "invalid interchange vector, not a permutation of the entire "
96                "iteration space");
97     }
98   }
99   return success();
100 }
101 
102 /// Method to instantiate the tile sizes and/or number of threads specified
103 /// by the user.
104 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
105 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
106                               ArrayRef<Range> iterationDomain,
107                               const scf::SCFTilingOptions &options) {
108   OpFoldResult zero = rewriter.getIndexAttr(0);
109   SmallVector<OpFoldResult> tileSizes, numThreads;
110   size_t numLoops = iterationDomain.size();
111 
112   // Check whether the number of threads to use is specified.
113   if (options.numThreadsComputationFunction) {
114     numThreads = options.numThreadsComputationFunction(rewriter, op);
115     numThreads.resize(numLoops, zero);
116 
117     // If the tile sizes are also specified, use those.
118     if (options.tileSizeComputationFunction) {
119       tileSizes = options.tileSizeComputationFunction(rewriter, op);
120       tileSizes.resize(numLoops, zero);
121       return {tileSizes, numThreads};
122     }
123 
124     // Compute the tile sizes from the iteration domain and number
125     // of threads as follows:
126     // - niters = ceilDiv(ub - lb, step)
127     // - tileSize = ceilDiv(niters, numThreads)
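    // For example (illustrative numbers only): for a loop with domain
    // `[0, 100)` and `numThreads = 4`, the computed tile size is
    // `ceilDiv(100 - 0, 4) = 25`.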
128     AffineExpr s0, s1, s2;
129     bindSymbols(rewriter.getContext(), s0, s1, s2);
130     // TODO: The step here is assumed to be 1.
131     AffineExpr numItersExpr = (s1 - s0);
132     AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
133     tileSizes.resize(numLoops, zero);
134     for (auto [index, range, nt] :
135          llvm::enumerate(iterationDomain, numThreads)) {
136       if (isConstantIntValue(nt, 0))
137         continue;
138 
139       tileSizes[index] = affine::makeComposedFoldedAffineApply(
140           rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141     }
142     tileSizes.resize(numLoops, zero);
143     return {tileSizes, numThreads};
144   }
145 
146   // Enforce the convention that "tiling by zero"
147   // skips tiling a particular dimension. This convention is significantly
148   // simpler to handle than adjusting affine maps to account for missing
149   // dimensions.
150   assert(options.tileSizeComputationFunction &&
151          "expected tile sizes to be specified");
152   tileSizes = options.tileSizeComputationFunction(rewriter, op);
153   tileSizes.resize(numLoops, zero);
154 
155   return {tileSizes, numThreads};
156 }
157 
158 /// Checks if any of the tiled loops are not parallel and emits a warning if so.
159 static void checkSafeToTileToForall(TilingInterface op,
160                                     ArrayRef<OpFoldResult> tileSizes,
161                                     ArrayRef<OpFoldResult> numThreads) {
162   auto iterators = op.getLoopIteratorTypes();
163   assert(iterators.size() == tileSizes.size() &&
164          "expected as many tile size values as number of loops");
165   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166          "when specified, expected number of threads to use for each loop");
167 
168   for (auto [index, iterator, tileSize] :
169        llvm::enumerate(iterators, tileSizes)) {
170     // If num threads is specified, check that it is greater than one only for
171     // parallel dimensions.
172     if (!numThreads.empty()) {
173       if (std::optional<int64_t> constNumThreads =
174               getConstantIntValue(numThreads[index])) {
175         if (constNumThreads.value() > 1 &&
176             iterator != utils::IteratorType::parallel) {
177           op.emitWarning() << "tiling is not thread safe at axis #" << index;
178         }
179       }
180       continue;
181     }
182 
183     if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184       if (constTileSize.value() > 0 &&
185           iterator != utils::IteratorType::parallel) {
186         op.emitWarning() << "tiling is not thread safe at axis #" << index;
187       }
188     }
189   }
190 }
191 
192 /// Check if `stride` evenly divides the trip count `size - offset`.
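/// For example (illustrative values): a `Range` with `offset = 0`,
/// `size = 128`, and `stride = 32` divides evenly, while `size = 130` with the
/// same offset and stride does not. Non-constant values conservatively return
/// false.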
193 static bool tileDividesIterationDomain(Range loopRange) {
194   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
195   if (!offsetAsInt)
196     return false;
197   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
198   if (!sizeAsInt)
199     return false;
200   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
201   if (!strideAsInt)
202     return false;
203   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204 }
205 
206 /// Returns the bounded tile size given the current `offset`, `loopRange` and
207 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
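/// For example (illustrative values): with a loop range of `[0, 100)`,
/// `tileSize = 32`, and `offset = 96`, the bounded tile size is
/// `min(100 - 96, 32) = 4`.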
208 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
209                                        Range loopRange, OpFoldResult offset,
210                                        OpFoldResult tileSize) {
211   std::optional<int64_t> ts = getConstantIntValue(tileSize);
212   if (ts && ts.value() == 1)
213     return tileSize;
214 
215   if (tileDividesIterationDomain(
216           Range{loopRange.offset, loopRange.size, tileSize}))
217     return tileSize;
218 
219   // The tile size to use (to avoid out of bounds access) is the minimum of
220   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
221   // loop.
222   AffineExpr s0, s1, d0;
223   bindDims(b.getContext(), d0);
224   bindSymbols(b.getContext(), s0, s1);
225   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
226   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
227   return affine::makeComposedFoldedAffineMin(
228       b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
229 }
230 
231 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is less
232 /// than `iterationSize`.
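/// For example (illustrative values): with `tileSize = 2`, `numThreads = 7`,
/// and `iterationSize = 10`, the maximum offset is `2 * 6 = 12 >= 10`, so this
/// returns false; with `numThreads = 5` it is `2 * 4 = 8 < 10` and this
/// returns true. Non-constant values conservatively return false.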
233 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
234                                            OpFoldResult numThreads,
235                                            OpFoldResult iterationSize) {
236   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
237   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
238   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
239   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240     return false;
241   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242 }
243 
244 /// Compute the `OpFoldResult`s that represent the multi-dimensional
245 /// `offset`s and `size`s of the tile of the iteration space that the
246 /// innermost loop body of the generated tiled loops corresponds to.
247 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
248 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
249                       ArrayRef<Range> iterationDomain,
250                       ArrayRef<OpFoldResult> tileSizes,
251                       ArrayRef<OpFoldResult> numThreads) {
252   SmallVector<OpFoldResult> offsets, sizes;
253   int materializedLoopNum = 0;
254 
255   if (!numThreads.empty()) {
256     AffineExpr d0, d1, s0, s1;
257     AffineExpr offsetExpr, residualTileSizeExpr;
258     bindDims(rewriter.getContext(), d0, d1);
259     bindSymbols(rewriter.getContext(), s0, s1);
260     offsetExpr = d0 + d1 * s0;
261     residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 
263     for (auto [nt, tileSize, loopRange] :
264          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
265 
266       // For non-tiled dimensions, set the offset and size to
267       // `loopRange.offset`/`loopRange.size`.
268       if (isConstantIntValue(nt, 0)) {
269         offsets.push_back(loopRange.offset);
270         sizes.push_back(loopRange.size);
271         continue;
272       }
273 
274       Value iv = ivs[materializedLoopNum++];
275       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
276           rewriter, loc, offsetExpr,
277           ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
278       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
279           rewriter, loc, residualTileSizeExpr,
280           {loopRange.offset, nt, tileSize, loopRange.size});
281 
282       OpFoldResult size = tileSize;
283       if (!isConstantIntValue(residualTileSize, 0)) {
284         OpFoldResult sizeMinusOffsetPerThread =
285             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286                                                   {offset, loopRange.size});
287         size = affine::makeComposedFoldedAffineMin(
288             rewriter, loc,
289             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
290             {sizeMinusOffsetPerThread, tileSize});
291       }
292 
293       // Consider the case where the original loop was `[0, 10)`.
294       // If the number of threads is `7`, the tile size would be computed as
295       // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
296       // - `offset = 0 + 6 * 2 = 12`
297       // - `tileSize = min(2, 10 - 12) = -2`
298       // To avoid negative tile sizes, we need to do a further
299       // `nonNegativeTileSize = affine.max(0, tileSize)`.
300       // This `max` can be avoided if
301       //  `tileSize * (numThreads - 1) < (ub - lb)`
302       if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
303         AffineMap maxMap =
304             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
305         size = affine::makeComposedFoldedAffineMax(
306             rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307       }
308 
309       offsets.push_back(offset);
310       sizes.push_back(size);
311     }
312     return {offsets, sizes};
313   } else {
314     for (auto [tileSize, loopRange] :
315          llvm::zip_equal(tileSizes, iterationDomain)) {
316 
317       // For non-tiled dimensions, set the offset and size to
318       // `loopRange.offset`/`loopRange.size`.
319       if (isConstantIntValue(tileSize, 0)) {
320         offsets.push_back(loopRange.offset);
321         sizes.push_back(loopRange.size);
322         continue;
323       }
324 
325       Value iv = ivs[materializedLoopNum++];
326       OpFoldResult offset = getAsOpFoldResult(iv);
327       offsets.push_back(offset);
328       OpFoldResult size =
329           getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
330       sizes.push_back(size);
331     }
332     return {offsets, sizes};
333   }
334 }
335 
336 /// Function to return the bounds of the loops to be generated.
337 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
338                   SmallVector<OpFoldResult>>
339 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
340               ArrayRef<OpFoldResult> tileSizes) {
341   SmallVector<OpFoldResult> lbs, ubs, steps;
342   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343     // No loop if the tile size is 0.
344     if (isConstantIntValue(tileSize, 0))
345       continue;
346     lbs.push_back(loopRange.offset);
347     ubs.push_back(loopRange.size);
348     steps.push_back(tileSize);
349   }
350   return {lbs, ubs, steps};
351 }
352 
353 /// A function that allows returning additional yielded values during
354 /// `yieldTiledValuesAndReplaceLoop`.
355 /// - `ivs` induction variables for the loops.
356 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
357 /// - `tiledValues` the tiled values to return. Must be of same size as
358 ///   `newBbArgs`; each element of this array is inserted into the corresponding
359 ///   element in `newBbArgs`.
360 /// - `resultOffsets` is of the same size as `tiledValues` and represents
361 ///   the offsets to use when inserting corresponding element from `tiledValues`
362 ///   into the element from `newBbArgs`.
363 /// - `resultSizes` is of the same size as `tiledValues` and represents
364 ///   the size of the corresponding element from `tiledValues` inserted into
365 ///   the element from `newBbArgs`.
366 /// If the method needs to return `failure()`, it is expected to clean up any
367 /// operations it inserted.
368 using YieldTiledValuesFn = std::function<LogicalResult(
369     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
370     SmallVector<Value> &tiledValues,
371     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
372     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
373 
374 /// Clones the operation and updates the destination if the operation
375 /// implements the `DestinationStyleOpInterface`.
376 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
377                                                   Operation *op,
378                                                   ValueRange newDestArgs) {
379   Operation *clonedOp = rewriter.clone(*op);
380   if (newDestArgs.empty())
381     return clonedOp;
382   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384   return clonedOp;
385 }
386 
387 /// Generate the tile-loop nest using `scf.for` operation.
388 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389 /// - `tileSizes` is the tile sizes to use. A zero tile size means the loop is
390 ///    not tiled.
391 /// - `destinationTensors` are the init values to use for the outermost loop.
392 /// - `yieldTiledValuesFn` is called to generate the body of the innermost
393 ///    loop.
394 /// - `loops` is an in-out parameter into which the generated loops are
395 ///    populated.
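///
/// A rough sketch of the generated structure (illustrative only, not the
/// verbatim output of this function):
///
///   %0:n = scf.for %iv0 = %lb0 to %ub0 step %s0
///       iter_args(%arg0 = %init0, ...) {
///     %1:n = scf.for %iv1 = %lb1 to %ub1 step %s1
///         iter_args(%arg1 = %arg0, ...) {
///       // ... body created by `yieldTiledValuesFn` ...
///       %slice = tensor.insert_slice %tiledValue into %arg1[...] [...] [...]
///       scf.yield %slice, ...
///     }
///     scf.yield %1#0, ...
///   }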
396 static LogicalResult generateLoopNestUsingForOp(
397     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
398     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
399     YieldTiledValuesFn yieldTiledValuesFn,
400     SmallVector<LoopLikeOpInterface> &loops) {
401   assert(!loopRanges.empty() && "unexpected empty loop ranges");
402   assert(loopRanges.size() == tileSizes.size() &&
403          "expected as many tile sizes as loop ranges");
404   OpBuilder::InsertionGuard guard(rewriter);
405 
406   SmallVector<OpFoldResult> lbs, ubs, steps;
407   std::tie(lbs, ubs, steps) =
408       getLoopBounds(rewriter, loc, loopRanges, tileSizes);
409   SmallVector<Value> lbVals =
410       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
411   SmallVector<Value> ubVals =
412       getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
413   SmallVector<Value> stepVals =
414       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
415 
416   SmallVector<Value> ivs;
417   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418     auto loop =
419         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
420                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
421                                        Value iv, ValueRange /*iterArgs*/) {});
422     loops.push_back(loop);
423     ivs.push_back(loop.getInductionVar());
424     rewriter.setInsertionPointToEnd(loop.getBody());
425     destinationTensors = loop.getRegionIterArgs();
426   }
427 
428   SmallVector<Value> tiledResults;
429   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431                                 tiledResults, resultOffsets, resultSizes))) {
432     return rewriter.notifyMatchFailure(
433         loc, "failed to generate inner tile loop body");
434   }
435   if (loops.empty())
436     return success();
437 
438   assert(tiledResults.size() == destinationTensors.size() &&
439          "Number of results of body should be equal to number of iter args");
440 
441   // Yield all the results of the tiled operation.
442   SmallVector<Value> yieldedValues;
443   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445                        resultSizes)) {
446     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
447                                            rewriter.getIndexAttr(1));
448     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449         loc, tiledValue, destinationTensor, resultOffset, resultSize,
450         resultStride);
451     yieldedValues.push_back(insertSlice);
452   }
453   rewriter.create<scf::YieldOp>(loc, yieldedValues);
454 
455   // Add the scf.yield operations for all the outer loops.
456   for (auto [outerLoop, innerLoop] :
457        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
458                        MutableArrayRef(loops).drop_front())) {
459     rewriter.setInsertionPointToEnd(
460         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
462   }
463   return success();
464 }
465 
466 /// Generate the tile-loop nest using `scf.forall` operation.
467 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468 /// - `tileSizes` is the tile sizes to use. A zero tile size means the loop is
469 ///    not tiled.
470 /// - `destinationTensors` are the init values to use for the outermost loop.
471 /// - `mappingVector` is the mapping attributes to use for loop construction.
472 ///   Can be empty.
473 /// - `tiledBodyFn` is called to generate the body of the innermost
474 ///    loop.
475 /// - `loops` is an in-out parameter into which the generated loops are
476 ///    populated.
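///
/// A rough sketch of the generated structure when `numThreads` is used
/// (illustrative only, not the verbatim output of this function):
///
///   %0:n = scf.forall (%iv0, %iv1) in (%nt0, %nt1)
///       shared_outs(%out0 = %init0, ...) {
///     // ... body created by `tiledBodyFn` ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %tiledValue into %out0[...] [...] [...]
///     }
///   }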
477 static LogicalResult generateLoopNestUsingForallOp(
478     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
479     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
480     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
481     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
482   assert(!loopRanges.empty() && "unexpected empty loop ranges");
483   assert(loopRanges.size() == tileSizes.size() &&
484          "expected as many tile sizes as loop ranges");
485   OpBuilder::InsertionGuard guard(rewriter);
486   SmallVector<OpFoldResult> offsets(loopRanges.size()),
487       sizes(loopRanges.size());
488 
489   std::optional<ArrayAttr> mappingAttr;
490   if (!mappingVector.empty())
491     mappingAttr = rewriter.getArrayAttr(mappingVector);
492 
493   scf::ForallOp forallOp;
494   bool useNumThreads = !numThreads.empty();
495 
496   if (useNumThreads) {
497     // Prune the zero entries in numThreads.
498     SmallVector<OpFoldResult> nonZeroNumThreads;
499     for (auto nt : numThreads) {
500       if (isConstantIntValue(nt, 0))
501         continue;
502       nonZeroNumThreads.push_back(nt);
503     }
504     forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
505                                               destinationTensors, mappingAttr);
506   } else {
507     SmallVector<OpFoldResult> lbs, ubs, steps;
508     std::tie(lbs, ubs, steps) =
509         getLoopBounds(rewriter, loc, loopRanges, tileSizes);
510     forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
511                                               destinationTensors, mappingAttr);
512   }
513   loops.push_back(forallOp);
514 
515   rewriter.setInsertionPoint(forallOp.getTerminator());
516   destinationTensors = forallOp.getRegionOutArgs();
517 
518   SmallVector<Value> tiledResults;
519   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
520   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
521                          destinationTensors, tiledResults, resultOffsets,
522                          resultSizes)))
523     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
524 
525   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
526   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
527        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
528                        resultSizes)) {
529     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
530                                            rewriter.getIndexAttr(1));
531 
532     rewriter.create<tensor::ParallelInsertSliceOp>(
533         loc, tiledValue, destinationTensor, resultOffset, resultSize,
534         resultStride);
535   }
536   return success();
537 }
538 
539 /// Generate the tile-loop nest using the loop construct specified in `options`.
540 /// - `options`: Tiling options specified.
541 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
542 /// - `tileSizes` is the tile sizes to use. A zero tile size means the loop is
543 ///    not tiled.
544 /// - `destinationTensors` are the init values to use for the outermost loop.
545 /// - `tiledBodyFn` is called to generate the body of the innermost
546 ///    loop.
547 /// - `loops` is an in-out parameter into which the generated loops are
548 ///    populated.
549 static LogicalResult generateLoopNest(
550     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
551     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
552     ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
553     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
554   // If the tile sizes are all zero, no loops are generated. Just call the
555   // callback function to handle untiled case.
556   if (llvm::all_of(tileSizes, isZeroIndex)) {
557     SmallVector<Value> tiledResults;
558     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
559     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
560                        tiledResults, resultOffsets, resultSizes);
561   }
562   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
563     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
564                                       destinationTensors, tiledBodyFn, loops);
565   }
566   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
567     return generateLoopNestUsingForallOp(
568         rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
569         destinationTensors, tiledBodyFn, loops);
570   }
571   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
572 }
573 
574 static FailureOr<SmallVector<Value>>
575 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
576                               ArrayRef<OpFoldResult> tileSizes,
577                               const scf::SCFTilingOptions &options) {
578   SmallVector<Value> initTensors;
579   Location loc = op->getLoc();
580   switch (options.reductionStrategy) {
581   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
582     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
583       return failure();
584     return initTensors;
585   case scf::SCFTilingOptions::ReductionTilingStrategy::
586       PartialReductionOuterReduction: {
587     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
588     if (!redOp) {
589       return rewriter.notifyMatchFailure(
590           op, "PartialReductionOuterReduction tiling strategy is only supported "
591               "for operations implementing PartialReductionOpInterface");
592     }
593     // Get reduction dimensions.
594     // TODO: PartialReductionOpInterface should really query TilingInterface
595     // itself and find reduction dimensions.
596     SmallVector<int> reductionDims;
597     for (auto [idx, iteratorType] :
598          llvm::enumerate(op.getLoopIteratorTypes())) {
599       if (iteratorType == utils::IteratorType::reduction)
600         reductionDims.push_back(idx);
601     }
602     return redOp.generateInitialTensorForPartialReduction(
603         rewriter, loc, tileSizes, reductionDims);
604   }
605   default:
606     return rewriter.notifyMatchFailure(op,
607                                        "unhandled reduction tiling strategy");
608   }
609 }
610 
611 static FailureOr<TilingResult>
612 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
613                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
614                        ArrayRef<OpFoldResult> sizes,
615                        const scf::SCFTilingOptions &options) {
616   switch (options.reductionStrategy) {
617   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
618     return op.getTiledImplementation(rewriter, offsets, sizes);
619   case scf::SCFTilingOptions::ReductionTilingStrategy::
620       PartialReductionOuterReduction: {
621     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
622     if (!redOp) {
623       return rewriter.notifyMatchFailure(
624           op, "PartialReductionOuterReduction tiling strategy is only "
625               "supported for operations "
626               "implementing PartialReductionOpInterface");
627     }
628     // Get reduction dimensions.
629     // TODO: PartialReductionOpInterface should really query TilingInterface
630     // itself and find reduction dimensions.
631     SmallVector<int> reductionDims;
632     for (auto [idx, iteratorType] :
633          llvm::enumerate(op.getLoopIteratorTypes())) {
634       if (iteratorType == utils::IteratorType::reduction)
635         reductionDims.push_back(idx);
636     }
637     return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
638                                         offsets, sizes, reductionDims);
639   }
640   default:
641     return rewriter.notifyMatchFailure(op,
642                                        "unhandled reduction tiling strategy");
643   }
644 }
645 
646 static LogicalResult
647 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
648                       TilingInterface op, ArrayRef<OpFoldResult> offsets,
649                       ArrayRef<OpFoldResult> sizes,
650                       SmallVector<OpFoldResult> &resultOffset,
651                       SmallVector<OpFoldResult> &resultSize,
652                       const scf::SCFTilingOptions &options) {
653 
654   switch (options.reductionStrategy) {
655   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
656     return op.getResultTilePosition(rewriter, index, offsets, sizes,
657                                     resultOffset, resultSize);
658   case scf::SCFTilingOptions::ReductionTilingStrategy::
659       PartialReductionOuterReduction: {
660     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661     if (!redOp) {
662       return rewriter.notifyMatchFailure(
663           op, "PartialReductionOuterReduction tiling strategy is only supported "
664               "for operations implementing PartialReductionOpInterface");
665     }
666     // Get reduction dimensions.
667     // TODO: PartialReductionOpInterface should really query TilingInterface
668     // itself and find reduction dimensions.
669     SmallVector<int> reductionDims;
670     for (auto [idx, iteratorType] :
671          llvm::enumerate(op.getLoopIteratorTypes())) {
672       if (iteratorType == utils::IteratorType::reduction)
673         reductionDims.push_back(idx);
674     }
675     return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
676                                               resultOffset, resultSize,
677                                               reductionDims);
678   }
679   default:
680     return rewriter.notifyMatchFailure(op,
681                                        "unhandled reduction tiling strategy");
682   }
683 }
684 
685 static FailureOr<MergeResult>
686 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
687                    ValueRange partialResults,
688                    const scf::SCFTilingOptions &options) {
689   switch (options.reductionStrategy) {
690   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
691     // No merging of results is needed for the full reduction tiling strategy.
692     return MergeResult{{}, partialResults};
693   case scf::SCFTilingOptions::ReductionTilingStrategy::
694       PartialReductionOuterReduction: {
695     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
696     if (!redOp) {
697       return rewriter.notifyMatchFailure(
698           op, "PartialReductionOuterReduction tiling strategy is only "
699               "supported for operations "
700               "implementing PartialReductionOpInterface");
701     }
702     // Get reduction dimensions.
703     // TODO: PartialReductionOpInterface should really query TilingInterface
704     // itself and find reduction dimensions.
705     SmallVector<int> reductionDims;
706     for (auto [idx, iteratorType] :
707          llvm::enumerate(op.getLoopIteratorTypes())) {
708       if (iteratorType == utils::IteratorType::reduction)
709         reductionDims.push_back(idx);
710     }
711     return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
712                                  reductionDims);
713   }
714   default:
715     return rewriter.notifyMatchFailure(op,
716                                        "unhandled reduction tiling strategy");
717   }
718 }
719 
720 /// Append the specified additional `newInitOperands` operands to the
721 /// loop's existing `init` operands (or similar), and replace `loopOp` with
722 /// the new loop that has the additional init operands. The loop body of
723 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
724 /// is called to get the new tiled values returned, and the offset
725 /// and sizes at which the tiled value is inserted into the
726 /// new region iter_args that correspond to the newly added init operands.
727 template <typename LoopType>
728 FailureOr<LoopLikeOpInterface>
729 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
730                                ValueRange newInitOperands,
731                                YieldTiledValuesFn yieldTiledValuesFn) {
732   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
733 }
734 
735 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
736 template <>
737 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
738     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
739     YieldTiledValuesFn yieldTiledValuesFn) {
740   OpBuilder::InsertionGuard g(rewriter);
741   Location loc = loopOp.getLoc();
742   rewriter.setInsertionPoint(loopOp);
743 
744   auto inits = llvm::to_vector(loopOp.getInitArgs());
745   inits.append(newInitOperands.begin(), newInitOperands.end());
746   auto newLoop = rewriter.create<scf::ForOp>(
747       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
748       inits, [](OpBuilder &, Location, Value, ValueRange) {});
749 
750   // Move the loop body to the new op.
751   Block *loopBody = loopOp.getBody();
752   Block *newLoopBody = newLoop.getBody();
753   rewriter.mergeBlocks(
754       loopBody, newLoopBody,
755       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
756 
757   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
758   rewriter.setInsertionPoint(yieldOp);
759 
760   SmallVector<Value> tiledValues;
761   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
762   ValueRange newRegionIterArgs =
763       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
764   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
765                                 newRegionIterArgs, tiledValues, resultOffsets,
766                                 resultSizes))) {
767     rewriter.eraseOp(newLoop);
768     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
769   }
770 
771   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
772   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
773        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
774                        resultSizes)) {
775     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
776                                            rewriter.getIndexAttr(1));
777     Value insert = rewriter.create<tensor::InsertSliceOp>(
778         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
779         resultStride);
780     newYieldValues.push_back(insert);
781   }
782 
783   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
784   rewriter.replaceOp(loopOp,
785                      newLoop->getResults().take_front(loopOp.getNumResults()));
786   return cast<LoopLikeOpInterface>(newLoop.getOperation());
787 }
788 
789 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
790 template <>
791 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
792     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
793     YieldTiledValuesFn yieldTiledValuesFn) {
794   OpBuilder::InsertionGuard g(rewriter);
795   Location loc = loopOp.getLoc();
796   rewriter.setInsertionPoint(loopOp);
797   auto inits = llvm::to_vector(loopOp.getOutputs());
798   inits.append(newInitOperands.begin(), newInitOperands.end());
799   auto newLoop = rewriter.create<scf::ForallOp>(
800       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
801       loopOp.getMixedStep(), inits, loopOp.getMapping(),
802       [](OpBuilder &, Location, ValueRange) {});
803 
804   // Move the region of the current block to the newly created op.
805   Block *loopBody = loopOp.getBody();
806   Block *newLoopBody = newLoop.getBody();
807   rewriter.mergeBlocks(
808       loopBody, newLoopBody,
809       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
810 
811   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
812   rewriter.setInsertionPoint(terminator);
813   SmallVector<Value> tiledValues;
814   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
815   ValueRange regionIterArgs =
816       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
817   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
818                                 regionIterArgs, tiledValues, resultOffsets,
819                                 resultSizes))) {
820     rewriter.eraseOp(newLoop);
821     return rewriter.notifyMatchFailure(loopOp,
822                                        "failed to get yielded tiled values");
823   }
824 
825   // Update the terminator.
826   rewriter.setInsertionPointToEnd(terminator.getBody());
827 
828   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
829            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
830     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
831                                            rewriter.getIndexAttr(1));
832     rewriter.create<tensor::ParallelInsertSliceOp>(
833         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
834         resultStride);
835   }
836 
837   rewriter.replaceOp(loopOp,
838                      newLoop->getResults().take_front(loopOp.getNumResults()));
839   return cast<LoopLikeOpInterface>(newLoop.getOperation());
840 }
841 
842 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
843 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
844 /// supported loop type.
845 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
846     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
847     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
848   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
849              loopLikeOp.getOperation())
850       .Case<scf::ForOp, scf::ForallOp>(
851           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
852             return yieldTiledValuesAndReplaceLoop(
853                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
854           })
855       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
856         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
857       });
858 }
859 
860 /// Method to add new init values to a loop nest. Updates `loops` in-place
861 /// with new loops that use the `newInitValues`. The outer loops are updated
862 /// to yield the new result values of the inner loop. For the innermost loop,
863 /// the callback `getNewTiledYieldsFn` is invoked to get the additional values
864 /// to yield from the innermost loop.
865 static LogicalResult addInitOperandsToLoopNest(
866     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
867     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
868   SmallVector<scf::ForOp> newLoops;
869   if (loops.empty())
870     return success();
871   OpBuilder::InsertionGuard g(rewriter);
872   rewriter.setInsertionPoint(loops.front());
873 
874   SmallVector<Value> ivs;
875   for (auto &loop : loops.drop_back()) {
876     rewriter.setInsertionPoint(loop);
877 
878     // If loops.size() > 1, we assume that scf.for is used for the outer loops.
879     auto forLoop = cast<scf::ForOp>(loop.getOperation());
880 
881     // Create a new loop with the new init values for this loop.
882     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
883     newInits.append(newInitValues.begin(), newInitValues.end());
884     auto newLoop = rewriter.create<scf::ForOp>(
885         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
886         forLoop.getStep(), newInits,
887         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
888 
889     // Merge the body of the new loop with the body of the old loops.
890     SmallVector<Value> sourceBlockArgs;
891     sourceBlockArgs.push_back(newLoop.getInductionVar());
892     auto newRegionIterArgs = newLoop.getRegionIterArgs();
893     sourceBlockArgs.append(
894         newRegionIterArgs.begin(),
895         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
896     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
897     rewriter.replaceOp(
898         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
899     loop = newLoop;
900     ivs.push_back(newLoop.getInductionVar());
901     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
902   }
903 
904   // Update the loop body of the innermost loop to get new yield values.
905   LoopLikeOpInterface innerMostLoop = loops.back();
906   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
907       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
908                                      getNewTiledYieldsFn);
909 
910   if (failed(newInnerMostLoop))
911     return innerMostLoop.emitOpError("failed to return additional yields");
912   loops.back() = newInnerMostLoop.value();
913 
914   // Make all loops except the innermost yield the values returned by their
915   // respective inner loop.
916   for (auto [outerLoop, innerLoop] :
917        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
918     // Again assume that all the outer loops are scf.for operations.
919     auto outerForLoop = cast<scf::ForOp>(outerLoop);
920     auto outerLoopYield =
921         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
922     SmallVector<Value> newYields =
923         llvm::to_vector(outerLoopYield.getOperands());
924     ValueRange additionalYields =
925         innerLoop->getResults().take_back(newInitValues.size());
926     newYields.append(additionalYields.begin(), additionalYields.end());
927     rewriter.setInsertionPoint(outerLoopYield);
928     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
929   }
930   return success();
931 }
932 
933 /// Implementation of the tiling transformation of `op` (which implements
934 /// `TilingInterface`) using the loop construct specified in `options`.
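///
/// For example (an illustrative sketch, not the exact output): tiling a
/// `linalg.matmul` on tensors with tile sizes `[32, 64, 0]` and
/// `LoopType::ForOp` produces IR roughly of the form
///
///   scf.for %i = %c0 to %M step %c32 iter_args(%arg0 = %init) {
///     scf.for %j = %c0 to %N step %c64 iter_args(%arg1 = %arg0) {
///       %lhs = tensor.extract_slice %A[%i, 0] [32, %K] [1, 1]
///       %rhs = tensor.extract_slice %B[0, %j] [%K, 64] [1, 1]
///       %out = tensor.extract_slice %arg1[%i, %j] [32, 64] [1, 1]
///       %res = linalg.matmul ins(%lhs, %rhs : ...) outs(%out : ...)
///       %ins = tensor.insert_slice %res into %arg1[%i, %j] [32, 64] [1, 1]
///       scf.yield %ins
///     }
///     scf.yield ...
///   }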
935 FailureOr<scf::SCFTilingResult>
936 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
937                         const scf::SCFTilingOptions &options) {
938   if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
939     return failure();
940   }
941 
942   OpBuilder::InsertionGuard guard(rewriter);
943   rewriter.setInsertionPointAfter(op);
944 
945   // 1. Get the range of the loops that are represented by the operation.
946   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
947 
948   // 2. Materialize the tile sizes and/or number of threads.
949   SmallVector<OpFoldResult> tileSizes, numThreads;
950   std::tie(tileSizes, numThreads) =
951       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
952 
953   // Check if it is safe to tile. This is a holdover from previous iterations
954   // of tiling to scf.forall. Consider dropping it.
955   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
956     checkSafeToTileToForall(op, tileSizes, numThreads);
957   }
958 
959   // 3. If there is an interchange specified, permute the iteration domain and
960   // the tile sizes.
961   SmallVector<int64_t> interchangeVector;
962   if (!options.interchangeVector.empty()) {
963     interchangeVector = fillInterchangeVector(options.interchangeVector,
964                                               iterationDomain.size());
965     assert(isPermutationVector(interchangeVector) &&
966            "expected interchange vector to be a permutation");
967 
968     applyPermutationToVector(iterationDomain, interchangeVector);
969     applyPermutationToVector(tileSizes, interchangeVector);
970     if (!numThreads.empty())
971       applyPermutationToVector(numThreads, interchangeVector);
972   }
973 
974   FailureOr<TilingResult> tilingResult;
975   // 4. Define the lambda function used later to generate the body of the
976   // innermost tiled loop.
977   YieldTiledValuesFn innerYieldTiledValuesFn =
978       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
979           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
980           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
981           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
982       -> LogicalResult {
983     // 4a. Compute the `offsets` and `sizes` to use for tiling.
984     SmallVector<OpFoldResult> offsets, sizes;
985     std::tie(offsets, sizes) = getTileOffsetAndSizes(
986         rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
987 
988     // 4b. If interchange was provided, apply inverse of the interchange
989     //     to get back the offsets/sizes in the order to be specified.
990     if (!interchangeVector.empty()) {
991       auto inversePermutation = invertPermutationVector(interchangeVector);
992       applyPermutationToVector(offsets, inversePermutation);
993       applyPermutationToVector(sizes, inversePermutation);
994     }
995 
996     // 5. Generate the tiled implementation within the inner most loop.
997 
998     // 5a. Clone the operation within the loop body.
999     auto clonedOp = cast<TilingInterface>(
1000         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
1001 
1002     // 5b. Return the cloned op early if tiling is not happening. We cannot
1003     // return the original op because that could lead to
1004     // `rewriter.replaceOp(op, op->getResults())` and crash the callers.
1005     if (llvm::all_of(tileSizes, isZeroIndex)) {
1006       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1007       tilingResult =
1008           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1009                        /*generatedSlices=*/{}};
1010       return success();
1011     }
1012 
1013     // 5c. Tile the cloned operation.
1014     tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1015                                           offsets, sizes, options);
1016     if (failed(tilingResult)) {
1017       rewriter.eraseOp(clonedOp);
1018       return op.emitOpError("failed to tile operation");
1019     }
1020 
1021     // 5d. Delete the cloned operation.
1022     rewriter.eraseOp(clonedOp);
1023 
1024     // 5e. Compute the offsets at which the result values are to be inserted
1025     //     back into their destinations.
1026     for (auto [index, tiledValue] :
1027          llvm::enumerate(tilingResult->tiledValues)) {
1028       tiledResults.push_back(tiledValue);
1029       SmallVector<OpFoldResult> resultOffset, resultSize;
1030       if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1031                                        sizes, resultOffset, resultSize,
1032                                        options))) {
1033         for (auto op : tilingResult->tiledOps) {
1034           rewriter.eraseOp(op);
1035         }
1036         return rewriter.notifyMatchFailure(
1037             op, "failed to get slice of result produced");
1038       }
1039       resultOffsets.emplace_back(std::move(resultOffset));
1040       resultSizes.emplace_back(std::move(resultSize));
1041     }
1042 
1043     return success();
1044   };
1045 
1046   // 6. Find the destination tensors to use for the operation.
1047   FailureOr<SmallVector<Value>> maybeInits =
1048       createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1049   if (failed(maybeInits)) {
1050     return rewriter.notifyMatchFailure(
1051         op, "unable to create initial tensors for tiling");
1052   }
1053   SmallVector<Value> &initTensors = maybeInits.value();
1054 
1055   // 7. Generate the tiled loops nest using the callback defined above.
1056   SmallVector<LoopLikeOpInterface> loops;
1057   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1058                               tileSizes, numThreads, initTensors,
1059                               innerYieldTiledValuesFn, loops)))
1060     return op.emitOpError("failed to generate tiling loops");
1061   assert(succeeded(tilingResult) &&
1062          "expected tiling result to be computed after loop generation");
1063 
1064   SmallVector<Value> partialResults;
1065   if (loops.empty()) {
1066     // If loops are empty, the tiled op is used as the replacement for the
1067     // untiled op.
1068     partialResults = tilingResult->tiledValues;
1069   } else {
1070     partialResults = llvm::map_to_vector(loops.front()->getResults(),
1071                                          [](OpResult r) -> Value { return r; });
1072   }
1073 
1074   FailureOr<MergeResult> mergeResult =
1075       mergeTilingResults(rewriter, op, partialResults, options);
1076   if (failed(mergeResult)) {
1077     return rewriter.notifyMatchFailure(
1078         op, "Failed to merge partial results from tiling");
1079   }
1080 
1081   return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1082                               mergeResult.value(),
1083                               tilingResult->generatedSlices};
1084 }
1085 
1086 FailureOr<scf::SCFTilingResult>
1087 mlir::scf::tileReductionUsingScf(RewriterBase &b,
1088                                  PartialReductionOpInterface op,
1089                                  ArrayRef<OpFoldResult> tileSizes) {
1090   SCFTilingOptions options;
1091   options.setLoopType(SCFTilingOptions::LoopType::ForOp);
1092   options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
1093                                          PartialReductionOuterReduction);
1094   options.setTileSizes(tileSizes);
1095 
1096   TilingInterface tilingInterfaceOp =
1097       dyn_cast<TilingInterface>(op.getOperation());
1098   if (!tilingInterfaceOp) {
1099     return b.notifyMatchFailure(
1100         op,
1101         "Operation implementing PartialReductionOpInterface should implement "
1102         "TilingInterface");
1103   }
1104 
1105   return tileUsingSCF(b, tilingInterfaceOp, options);
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // tileConsumerAndFuseProducersUsingSCF implementation.
1110 //===----------------------------------------------------------------------===//
1111 
1112 /// Return the untiled producer whose slice is used in a tiled consumer. The
1113 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1114 /// `iter_args` of the outermost loop that is encountered. Traversing the
1115 /// iter_args indicates that this is a destination operand of the consumer. If
1116 /// there was no loop traversal needed, the second value of the returned tuple
1117 /// is empty.
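/// For example (matching the fusion example later in this file): if `source`
/// is the operand of a `tensor.extract_slice` of `%arg1`, where `%arg1` is the
/// `iter_args` of an inner `scf.for` tied to `%arg0` of the outer `scf.for`,
/// which is in turn initialized with `%1 = linalg.fill`, then the returned
/// `OpResult` is `%1` and the second tuple element is the outer loop's init
/// operand.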
1118 static std::tuple<OpResult, std::optional<OpOperand *>>
1119 getUntiledProducerFromSliceSource(OpOperand *source,
1120                                   ArrayRef<LoopLikeOpInterface> loops) {
1121   std::optional<OpOperand *> destinationIterArg;
1122   auto loopIt = loops.rbegin();
1123   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1124     auto loop = *loopIt;
1125     if (iterArg.getOwner()->getParentOp() != loop)
1126       break;
1127     source = loop.getTiedLoopInit(iterArg);
1128     loopIt++;
1129   }
1130   if (loopIt == loops.rend())
1131     destinationIterArg = source;
1132   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1133 }
1134 
1135 /// Implementation of fusing producer of a single slice by computing the
1136 /// slice of the producer in-place.
1137 std::optional<scf::SCFFuseProducerOfSliceResult>
1138 mlir::scf::tileAndFuseProducerOfSlice(
1139     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1140     MutableArrayRef<LoopLikeOpInterface> loops) {
1141   // 1. Get the producer of the source (potentially walking through
1142   // `iter_args` of nested `scf.for`)
1143   auto [fusableProducer, destinationInitArg] =
1144       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1145                                         loops);
1146   if (!fusableProducer)
1147     return std::nullopt;
1148   unsigned resultNumber = fusableProducer.getResultNumber();
1149 
1150   OpBuilder::InsertionGuard g(rewriter);
1151   rewriter.setInsertionPoint(candidateSliceOp);
1152 
1153   // 2. Clone the fused producer
1154   // 2a. Compute the destination operands to use for the cloned operation.
1155   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1156   Operation *fusableProducerOp = fusableProducer.getOwner();
1157   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1158       failed(tensor::getOrCreateDestinations(
1159           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1160           origDestinationTensors)))
1161     return std::nullopt;
1162 
1163   clonedOpDestinationTensors = origDestinationTensors;
1164   if (destinationInitArg &&
1165       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1166     // 2b. If the producer is also destination style, then to maintain the
1167     // destination passing style, update the destination of the producer to be
1168     // the source of the slice.
1169     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1170   }
1171   // 2c. Clone the fused producer.
1172   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1173       rewriter, fusableProducerOp, clonedOpDestinationTensors);
1174   // 2d. Update the source of the candidateSlice to be the cloned producer.
1175   //     It is easier to clone the slice with a different source, since
1176   //     replacement and DCE of the cloned ops becomes easier.
1177   SmallVector<Value> candidateSliceOpOperands =
1178       llvm::to_vector(candidateSliceOp->getOperands());
1179   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1180   tensor::ExtractSliceOp clonedCandidateSliceOp =
1181       mlir::clone(rewriter, candidateSliceOp,
1182                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1183 
1184   // 3. Generate the tiled implementation of the producer of the source
1185   FailureOr<TilingResult> tileAndFuseResult =
1186       tensor::replaceExtractSliceWithTiledProducer(
1187           rewriter, clonedCandidateSliceOp,
1188           clonedProducerOp->getResult(resultNumber));
1189   if (failed(tileAndFuseResult))
1190     return std::nullopt;
  // Note: Do not delete the candidateSliceOp, since it is passed in from the
  // caller.
1193   rewriter.replaceAllUsesWith(candidateSliceOp,
1194                               tileAndFuseResult->tiledValues[0]);
1195   rewriter.eraseOp(clonedCandidateSliceOp);
1196   rewriter.eraseOp(clonedProducerOp);
1197 
  // 4. If the slice is for a destination operand, for example,
1199   //
1200   // ```mlir
1201   // %0 = linalg.init
1202   // %1 = linalg.fill .. outs(%0 : )
1203   // %2 = scf.for .. iter_args(%arg0 = %1) {
1204   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1205   //     %4 = tensor.extract_slice %arg1 [..]
1206   //     .. = linalg.matmul .. outs(%4 : )
1207   //   }
1208   // }
1209   // ```
1210   //
1211   // the IR is currently
1212   //
1213   // ```
1214   // %0 = linalg.init
1215   // %1 = linalg.fill
1216   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1217   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1218   //     %4 = tensor.extract_slice %arg1[..]
1219   //     %5 = linalg.fill .. outs(%4 : )
1220   //     .. = linalg.matmul .. outs(%5 : )
1221   //   }
1222   // }
1223   // ```
1224   //
1225   // The untiled `linalg.fill` is still used as the `init_value` since it
1226   // was originally a destination operand of the untiled `linalg.matmul`.
1227   // When fusing an operand that is a destination operand, the iter_arg of
  // the outermost loop should be changed to use the destination of the
  // fused operation. With this the IR will be:
1230   //
1231   // ```
1232   // %0 = linalg.init
1233   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1234   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
1235   //     %3 = tensor.extract_slice %arg1[..]
1236   //     %4 = linalg.fill .. outs(%3 : )
1237   //     .. = linalg.matmul .. outs(%4 : )
1238   //   }
1239   // }
1240   // ```
1241   if (destinationInitArg &&
1242       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1243     loops.front()
1244         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1245         .set(origDestinationTensors[resultNumber]);
1246   }
1247   return scf::SCFFuseProducerOfSliceResult{
1248       fusableProducer, tileAndFuseResult->tiledValues[0],
1249       tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1250 }
1251 
1252 /// Reconstruct the fused producer from within the tiled-and-fused code.
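/// This appends new `iter_args` (initialized with the producer's
/// destinations) to the tile loop nest and yields the tiled producer values
/// inserted into them, so the producer's result becomes available as an
/// additional loop result. A sketch of the effect (placeholder op names,
/// details elided):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
///   %1 = tiled_producer ...
///   ...
/// }
/// ```
///
/// becomes
///
/// ```
/// %0:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %producer_dest) {
///   %1 = tiled_producer ...
///   ...
///   %2 = tensor.insert_slice %1 into %arg1 [...]
///   scf.yield ..., %2
/// }
/// ```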
1253 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1254     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1255     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1256     MutableArrayRef<LoopLikeOpInterface> loops,
1257     ArrayRef<unsigned> yieldResultNumber) {
1258   if (loops.empty())
    return SmallVector<Operation *>{};
1260 
1261   Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1262             *tiledOwner = fusedProducerInfo.tiledOps[0];
1263 
1264   Location loc = originalOwner->getLoc();
  // a. Collect all init values to be appended.
1266   SmallVector<unsigned> initNumberList =
1267       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1268                                       0, originalOwner->getNumResults()))
1269                                 : llvm::to_vector(yieldResultNumber);
1270   SmallVector<Value> initValueList;
1271   for (const auto &resultNumber : initNumberList) {
1272     FailureOr<Value> initValue = tensor::getOrCreateDestination(
1273         rewriter, loc, originalOwner->getResult(resultNumber));
1274     if (succeeded(initValue)) {
1275       initValueList.push_back(initValue.value());
1276     } else {
1277       return failure();
1278     }
1279   }
1280 
1281   SmallVector<Operation *> generatedSlices;
1282   YieldTiledValuesFn newYieldValuesFn =
1283       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1284           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1285           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1286           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1287     OpBuilder::InsertionGuard g(innerRewriter);
1288 
1289     // get sliceOp tile information
1290     SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1291                               sliceSizes = sliceOp.getMixedSizes();
1292 
    // Expect all strides of sliceOp to be 1.
1294     if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1295           return !isConstantIntValue(ofr, 1);
1296         }))
1297       return failure();
1298 
1299     unsigned sliceResultNumber =
1300         fusedProducerInfo.origProducer.getResultNumber();
1301 
1302     auto tilableOp = cast<TilingInterface>(originalOwner);
    // b. Get the iteration domain offset and sizes based on the sliceOp
    // tile.
1304     SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
    // Skip tensor.pack/unpack/pad, which expect a single opResult.
1306     if (tilableOp->getNumResults() > 1 &&
1307         failed(tilableOp.getIterationDomainTileFromResultTile(
1308             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1309             iterDomainOffset, iterDomainSizes))) {
      // In theory, it is unnecessary to raise an error here. Even though it
      // fails to reconstruct the result tensor, it should not break the
      // current fusion anyway. The reason why we must return failure for now
      // is that the callback function `newYieldValuesFn` is called after the
      // new init operand(s) have already been appended. It will take more
      // refactoring to make sure the init operands are added consistently in
      // the future. For more details, please refer to:
1317       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1318       return failure();
1319     }
1320 
    // c. Calculate the offsets and sizes of all OpResults based on the
    // iteration domain tile.
1323     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1324     for (const auto &resultNumber : initNumberList) {
1325       if (resultNumber == sliceResultNumber) {
1326         offsetList.push_back(sliceOffset);
1327         sizesList.push_back(sliceSizes);
1328       } else {
1329         assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1330         // infer result tile according to the iteration domain tile
1331         SmallVector<OpFoldResult> offset, sizes;
1332         if (failed(tilableOp.getResultTilePosition(
1333                 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1334                 offset, sizes))) {
1335           return failure();
1336         }
1337         offsetList.push_back(offset);
1338         sizesList.push_back(sizes);
1339       }
1340     }
1341 
    // d. Create `extract_slice` ops for the `iter_args` of DPS operations if
    // necessary.
1344     if (auto tiledDestStyleOp =
1345             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1346       rewriter.setInsertionPoint(tiledDestStyleOp);
1347       for (const auto &&[index, newRegionArg] :
1348            llvm::enumerate(newRegionIterArgs)) {
1349         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1350             loc, newRegionArg, offsetList[index], sizesList[index],
1351             SmallVector<OpFoldResult>(offsetList[index].size(),
1352                                       rewriter.getIndexAttr(1)));
1353         generatedSlices.push_back(destSlice);
1354         unsigned resultNumber = initNumberList[index];
1355         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1356           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1357         });
1358       }
1359     }
1360 
    // e. Prepare tiled offsets and sizes for the later `insert_slice`
    // creation by the caller.
1363     Block *block = rewriter.getInsertionPoint()->getBlock();
1364     rewriter.setInsertionPoint(block->getTerminator());
1365     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1366       tiledResult.push_back(tiledOwner->getResult(resultNumber));
1367       tiledOffset.emplace_back(offsetList[index]);
1368       tiledSizes.emplace_back(sizesList[index]);
1369     }
1370     return success();
1371   };
1372 
1373   if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1374                                        newYieldValuesFn))) {
1375     return failure();
1376   }
1377   return generatedSlices;
1378 }
1379 
1380 namespace {
1381 
1382 //===----------------------------------------------------------------------===//
1383 // SliceTrackingListener
1384 //===----------------------------------------------------------------------===//
1385 
1386 /// This class is a listener for tracking the insertion and removal of
1387 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1388 /// fusion algorithm to apply cleanup patterns in between fusion steps.
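/// A minimal usage sketch (hypothetical driver code, mirroring how
/// `tileConsumerAndFuseProducersUsingSCF` uses it below):
///
/// ```
/// SliceTrackingListener tracker(cleanupPatterns);
/// if (failed(tracker.insertAndApplyPatterns(initialSlices)))
///   return failure();
/// while (!tracker.worklist.empty()) {
///   tensor::ExtractSliceOp slice = tracker.worklist.front();
///   tracker.worklist.pop_front();
///   // ... try to fuse the producer of `slice`; any newly generated slices
///   // are fed back through insertAndApplyPatterns(...) ...
/// }
/// ```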
1389 class SliceTrackingListener : public RewriterBase::Listener {
1390 public:
1391   explicit SliceTrackingListener(
1392       std::optional<FrozenRewritePatternSet> patterns);
1393   SliceTrackingListener() = default;
1394 
1395   /// Adds the given list of operations to the worklist, and if present,
1396   /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any ops newly inserted by the
  /// pattern set.
1399   LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1400 
1401   /// Add to the new operation worklist if it is an extract_slice.
1402   void notifyOperationInserted(Operation *op,
1403                                OpBuilder::InsertPoint previous) override;
1404 
1405   /// Shared helper for operation removal from the worklist.
1406   void removeOp(Operation *op);
1407 
1408   /// Remove the operation from the worklist.
1409   void notifyOperationErased(Operation *op) override;
1410 
1411   /// Remove the operation from the worklist.
1412   void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1413 
1414   /// The worklist for this transformation keeps track of the slices to visit
1415   /// next for fusion.
1416   std::deque<tensor::ExtractSliceOp> worklist;
1417 
1418 private:
1419   /// Optional pattern set to apply when adding new operations to the
1420   /// worklist.
1421   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1422 };
1423 
1424 SliceTrackingListener::SliceTrackingListener(
1425     std::optional<FrozenRewritePatternSet> p) {
1426   patterns = std::move(p);
1427 }
1428 
1429 LogicalResult
1430 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1431   for (Operation *op : ops) {
1432     if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1433       worklist.push_back(slice);
1434   }
1435 
1436   if (!patterns)
1437     return success();
1438 
1439   GreedyRewriteConfig config;
1440   config.listener = this;
1441   config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1442   return applyOpPatternsGreedily(ops, patterns.value(), config);
1443 }
1444 
1445 void SliceTrackingListener::notifyOperationInserted(
1446     Operation *op, OpBuilder::InsertPoint previous) {
1447   auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1448   if (!slice)
1449     return;
1450   worklist.push_back(slice);
1451 }
1452 
1453 // Scan the worklist for the given op and remove it if present. The
1454 // expectation is for the worklist to be small and for removal to be
1455 // relatively rare.
1456 void SliceTrackingListener::removeOp(Operation *op) {
1457   if (!isa<tensor::ExtractSliceOp>(op))
1458     return;
1459   auto iter = worklist.begin();
1460   while (iter != worklist.end()) {
1461     if (*iter == op)
1462       break;
1463     iter++;
1464   }
1465   if (iter == worklist.end())
1466     return;
1467 
1468   worklist.erase(iter);
1469 }
1470 
1471 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1472   removeOp(op);
1473 }
1474 
1475 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1476                                                     ValueRange replacement) {
1477   removeOp(op);
1478 }
1479 
1480 //===----------------------------------------------------------------------===//
1481 // ReplacementListener
1482 //===----------------------------------------------------------------------===//
1483 
/// Listener that tracks updates to replacements for values which can be
/// mutated. This listener runs on top of the existing listener for the
/// rewriter, to make sure external users can still run their own listeners.
1487 class ReplacementListener : public RewriterBase::ForwardingListener {
1488 public:
1489   ReplacementListener(DenseMap<Value, Value> &replacements,
1490                       OpBuilder::Listener *listener)
1491       : ForwardingListener(listener), replacements(replacements) {}
1492 
1493   void updateReplacementValues(ValueRange origValues,
1494                                ValueRange replaceValues) {
1495     // This can probably be written better, but just iterates over the map
1496     // and the new replacements for now.
1497     for (auto &[key, val] : replacements) {
1498       for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1499         if (val == orig) {
1500           val = replace;
1501         }
1502       }
1503     }
1504   }
1505 
1506   void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1507     ForwardingListener::notifyOperationReplaced(op, newOp);
1508     updateReplacementValues(op->getResults(), newOp->getResults());
1509   }
1510 
1511   void notifyOperationReplaced(Operation *op, ValueRange values) override {
1512     ForwardingListener::notifyOperationReplaced(op, values);
1513     updateReplacementValues(op->getResults(), values);
1514   }
1515 
1516 private:
1517   DenseMap<Value, Value> &replacements;
1518 };
1519 
1520 } // namespace
1521 
1522 /// Implementation of tile consumer and fuse producer greedily.
1523 FailureOr<scf::SCFTileAndFuseResult>
1524 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1525     RewriterBase &rewriter, TilingInterface consumer,
1526     const scf::SCFTileAndFuseOptions &options) {
1527   // This transformation is only valid for ops that return values (i.e. not
1528   // valid to use with operations that have memref operands).
1529   if (!consumer->getNumResults()) {
1530     return rewriter.notifyMatchFailure(
1531         consumer, "invalid pattern for op with no results");
1532   }
1533 
1534   // 1. First tile the consumer.
1535   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1536   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1537 
1538   FailureOr<scf::SCFTilingResult> tilingResult =
1539       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1540 
1541   if (failed(tilingResult))
1542     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1543   for (auto *tiledOp : tilingResult->tiledOps)
1544     tiledAndFusedOps.insert(tiledOp);
1545 
1546   DenseMap<Value, Value> replacements;
1547   for (auto [origVal, replacement] : llvm::zip_equal(
1548            consumer->getResults(), tilingResult->mergeResult.replacements)) {
1549     replacements[origVal] = replacement;
1550   }
1551 
1552   // If there are no loops generated, fusion is immaterial.
1553   auto &loops = tilingResult->loops;
1554   if (loops.empty()) {
1555     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1556                                      replacements};
1557   }
1558 
1559   // Since the loop gets potentially replaced during fusion, we need to track
1560   // the mutation of replacement values. To do this, we attach a listener to
1561   // update the replacements as they happen.
1562   OpBuilder::Listener *previousListener = rewriter.getListener();
1563   auto resetListener =
1564       llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1565   ReplacementListener replaceListener(replacements, previousListener);
1566   rewriter.setListener(&replaceListener);
1567 
1568   // 2. Typically, the operands of the tiled operation are slices of the
1569   //    operands of the untiled operation. These are expressed in IR using
1570   //    `tensor.extract_slice` operations with source being the operands of
1571   //    the untiled operation. Create a worklist of these
1572   //    `tensor.extract_slice` operations. If the producers of the source of
1573   //    the `tensor.extract_slice` can be tiled such that the tiled value is
1574   //    generated in-place, that effectively tiles + fuses the operations.
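  //    For example, after tiling a matmul the IR typically looks like the
  //    following sketch (operand details elided):
  //
  //    ```mlir
  //    %0 = linalg.fill ... outs(%empty : ...)
  //    %1 = scf.for ... iter_args(%arg0 = %0) {
  //      %2 = tensor.extract_slice %arg0 [...]
  //      %3 = linalg.matmul ... outs(%2 : ...)
  //      ...
  //    }
  //    ```
  //
  //    Here `%2` seeds the worklist; its untiled producer is the
  //    `linalg.fill`, which can be tiled and fused in place of the slice.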
1575   struct WorklistItem {
1576     tensor::ExtractSliceOp candidateSlice;
1577     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1578   };
1579 
1580   SliceTrackingListener sliceTracker =
1581       SliceTrackingListener(options.cleanupPatterns);
1582 
1583   if (failed(
1584           sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1585     return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1586   }
1587   OpBuilder::InsertionGuard g(rewriter);
1588   while (!sliceTracker.worklist.empty()) {
1589     auto candidateSlice = sliceTracker.worklist.front();
1590     sliceTracker.worklist.pop_front();
1591 
1592     auto [fusableProducer, destinationInitArg] =
1593         getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1594                                           loops);
1595     if (!fusableProducer)
1596       continue;
1597 
1598     std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1599         options.fusionControlFn(candidateSlice, fusableProducer,
1600                                 destinationInitArg.has_value());
1601     if (!controlFnResult)
1602       continue;
1603 
1604     WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1605 
    // The operands of the fused producer might themselves be slices of
1607     // values produced by operations that implement the `TilingInterface`.
1608     // Add these operations to the worklist.
1609     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1610         tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1611                                    loops);
1612     if (!fusedResult)
1613       continue;
1614 
1615     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1616 
1617     if (worklistItem.controlFnResult.yieldProducerReplacement) {
      // Reconstruct and yield all opResults of fusableProducerOp by default.
      // The caller can specify which ones to yield via the optional
      // `yieldResultNumber` argument of `yieldReplacementForFusedProducer`.
1622       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1623       FailureOr<SmallVector<Operation *>> newSlices =
1624           yieldReplacementForFusedProducer(rewriter,
1625                                            worklistItem.candidateSlice,
1626                                            fusedResult.value(), loops);
1627       if (failed(newSlices)) {
1628         return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to yield replacement value for this "
                               "operation from within the tiled loop");
1631       }
1632       worklistCandidates.append(newSlices.value());
1633       for (auto [index, result] :
1634            llvm::enumerate(fusableProducerOp->getResults())) {
1635         replacements[result] = loops.front()->getResult(
1636             loops.front()->getNumResults() -
1637             fusableProducerOp->getNumResults() + index);
1638       }
1639     }
1640     if (Operation *tiledAndFusedOp =
1641             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1642       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1643       tiledAndFusedOps.insert(tiledAndFusedOp);
1644     }
1645 
1646     if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1647       return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1648     }
1649   }
1650 
1651   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1652                                    replacements};
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // tileAndFuseConsumerUsingSCF implementation.
1657 //===----------------------------------------------------------------------===//
1658 
1659 /// A utility function that checks whether the only use of the result of a
1660 /// tensor.insert_slice op is in a scf.yield op.
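/// For example, the following is accepted (a sketch):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
///   ...
///   %1 = tensor.insert_slice %tile into %arg0 [...]
///   scf.yield %1
/// }
/// ```
///
/// whereas any additional user of `%1`, a non-`scf.yield` user, or a
/// `scf.yield` in a different block makes the check fail.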
1661 static LogicalResult
1662 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1663   Value result = candidateSliceOp.getResult();
1664   Value::use_range uses = result.getUses();
1665   if (!llvm::hasSingleElement(uses)) {
1666     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1667     return failure();
1668   }
1669   OpOperand &operandUse = (*uses.begin());
1670   Operation *userOp = operandUse.getOwner();
1671   if (!isa<scf::YieldOp>(userOp)) {
1672     LLVM_DEBUG(llvm::dbgs()
1673                << "Expected scf.yield to be the only user, but got -> "
1674                << (*userOp));
1675     return failure();
1676   }
1677   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1678     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1679                                "be in the same block\n");
1680     return failure();
1681   }
1682   return success();
1683 }
1684 
/// A utility to get the first user of the given loopOp. If any user is in a
/// different block than loopOp, return failure.
1687 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1688   if (!isa<LoopLikeOpInterface>(loopOp))
1689     return failure();
1690   Operation *firstUserOfLoop = nullptr;
1691   for (Operation *userOp : loopOp->getUsers()) {
    // A `ParallelInsertSlice` located inside an `InParallelOp` does not share
    // a parent block with any other type of operation. Thus, just redirect to
    // its parent `InParallelOp`. E.g.
1695     //
1696     // ```
1697     // %1 = scf.for {
1698     //   ...
1699     // }
1700     // %2 = consumerOp ins(%1, ...)
1701     // scf.forall.in_parallel {
1702     //    tensor.parallel_insert_slice %1
1703     // }
1704     // ```
    // where the `InParallelOp`, but not the `ParallelInsertSlice`, is in the
    // same block as `consumerOp`.
1707     if (isa<tensor::ParallelInsertSliceOp>(userOp))
1708       userOp = userOp->getParentOfType<scf::InParallelOp>();
1709 
1710     if (loopOp->getBlock() != userOp->getBlock())
1711       return failure();
1712 
1713     if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1714       firstUserOfLoop = userOp;
1715   }
1716   return firstUserOfLoop;
1717 }
1718 
/// This utility currently checks whether the first userOp of the loop is NOT
/// before the last defineOp of the consumer operand. This is because we need
/// to move the whole loop structure right before the `firstUserOfLoop`. The
/// utility thus helps ensure that no invalid IR is formed, i.e. no backward
/// slice of consumerOp is dominated by the `firstUserOfLoop`. That is:
1724 ///
1725 /// ```
1726 /// %0 = scf.for() {
1727 ///   ...
1728 /// }
1729 /// ...
1730 /// %1 = firstUserOfLoop(%0)
1731 /// ...
1732 /// %2 = lastDefOfConsumerOperand
1733 /// ...
1734 /// %3 = consumerOp(%2)
1735 /// ```
1736 ///
1737 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1738 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1739 /// a.k.a. use-def chain violation:
1740 ///
1741 /// ```
1742 /// %0:2 = scf.for() {
1743 ///    // use before define error
1744 ///    %3 = tiledConsumerOp(%2)
1745 /// }
1746 /// %1 = firstUserOfLoop(%0)
1747 /// ...
1748 /// %2 = lastDefOfConsumerOperand
1749 /// ```
1750 ///
1751 /// @param loopOp: loop operation
1752 /// @param consumerOp: consumer operation
/// @param reorderOperations: a flag that controls whether to reorder the
/// backward slice w.r.t. the defineOp of `consumerOp` operands.
/// @return: computed backward slice of consumerOp, excluding those operations
/// that already dominate `firstUserOfLoop`.
1757 static FailureOr<llvm::SetVector<Operation *>>
1758 checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
1759                        bool reorderOperations) {
1760   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1761   if (failed(firstUserOfLoop))
1762     return failure();
1763 
1764   BackwardSliceOptions options;
1765   DominanceInfo dominanceInfo;
1766   options.inclusive = true;
1767   options.omitBlockArguments = true;
1768   bool includeLoopOp = false;
1769   options.filter = [&](Operation *op) {
1770     if (op == loopOp) {
1771       includeLoopOp = true;
1772       return false;
1773     }
1774     // Cut off the slice to not include any operation that already dominates
1775     // firstUserOfLoop.
1776     return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1777   };
1778   llvm::SetVector<Operation *> slice;
1779   for (auto operand : consumerOp->getOperands()) {
1780     getBackwardSlice(operand, &slice, options);
1781   }
1782 
1783   if (!slice.empty()) {
    // Handle the case where consumerOp has a producer which is also a user of
    // loopOp. E.g.
1786     // ```
1787     //  %0 = %loopOp
1788     //  %1 = consumerOp1 ins(%0)
1789     //  %2 = consumerOp2 ins(%0, %1)
1790     // ```
    // We cannot fuse consumerOp2 into loopOp due to the use-def chain, unless
    // consumerOp1 has already been fused into loopOp before.
1793     if (includeLoopOp || !reorderOperations)
1794       return failure();
1795   }
1796 
1797   return slice;
1798 }
1799 
1800 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1801 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1802 /// Returns failure otherwise.
1803 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1804                                                       Operation *loopOp,
1805                                                       unsigned resultNumber) {
1806   if (!isa<LoopLikeOpInterface>(loopOp))
1807     return failure();
1808   Value val = loopOp->getResult(resultNumber);
1809   Block *loopBlock = loopOp->getBlock();
1810   for (OpOperand &opOperand : val.getUses()) {
1811     Operation *consumerOp = opOperand.getOwner();
1812     // Step 1. Check if the user is tilable.
1813     if (!isa<TilingInterface>(consumerOp) ||
1814         !isa<DestinationStyleOpInterface>(consumerOp)) {
      // TODO: We have to initialize the result of the consumer before the
      // scf.for, so use DestinationStyleOpInterface to get the result shape
      // from the init for now. Add support for other ops, such as ops that
      // implement InferTypeOpInterface.
1818       continue;
1819     }
    // Step 2. Check if the user is in the same block.
1821     if (loopBlock != consumerOp->getBlock())
1822       continue;
    // Step 3. Check if the user has any subsequent users. Otherwise, it has
    // usually already been tiled.
1825     if (consumerOp->use_empty())
1826       continue;
1827     // Step 4. Check assumption for loop with `reorderOperations` enabled.
1828     FailureOr<llvm::SetVector<Operation *>> slice =
1829         checkAssumptionForLoop(loopOp, consumerOp, true);
1830     if (failed(slice))
1831       continue;
    // Step 5. If the backward slice is not empty, move its operations before
    // firstUserOfLoop.
1834     if (!slice->empty()) {
1835       mlir::topologicalSort(*slice);
1836       FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1837       assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1838       for (auto op : *slice) {
1839         rewriter.moveOpBefore(op, *firstUserOfLoop);
1840       }
1841     }
1842     return &opOperand;
1843   }
1844   return failure();
1845 }
/// Find the perfectly nested loops outside of the given loop (inclusive),
/// sorted from outer to inner.
1848 /// from outer to inner.
1849 ///
1850 /// E.g.
1851 ///
1852 /// ```
1853 ///  %0 = scf.for()
1854 ///    %1 = scf.for()
1855 ///      %2 = scf.for()
1856 ///         %3 = ...
1857 ///         yield %3
1858 ///      yield %2
1859 ///    yield %1
1860 /// ```
1861 ///
/// This function will return three perfectly nested loops: %0 + %1 + %2, when
/// the target inner loop is %2.
1864 static SmallVector<scf::ForOp>
1865 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1866   SmallVector<scf::ForOp> nestLoops = {loop};
1867   auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1868 
  // Check if it is the ForOp that yields the result of the inner loop.
1870   auto isForOpYieldResultOfInnerLoop =
1871       [](scf::ForOp outerLoop) -> LogicalResult {
1872     Block *body = outerLoop.getBody();
1873     if (!llvm::hasSingleElement(body->without_terminator()))
1874       return failure();
1875     auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1876     auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1877     if (!innerForOp)
1878       return failure();
1879     // All of innerForOp results should be yielded.
1880     return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1881   };
1882 
1883   while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1884     nestLoops.push_back(outerLoop);
1885     outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1886   }
1887   // sorted from outer to inner
1888   return {nestLoops.rbegin(), nestLoops.rend()};
1889 }
1890 
1891 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions:
1893 /// 1.  tensor.insert_slice has scf.yield as its only user.
1894 /// 2.  scf.for's corresponding result has only one use.
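/// For example (a sketch; `consumer_op` is a placeholder name):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
///   ...
///   %1 = tensor.insert_slice %tile into %arg0 [...]
///   scf.yield %1
/// }
/// %2 = consumer_op ins(%0 : ...) outs(...)
/// ```
///
/// returns the use of `%0` in `consumer_op`.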
1895 static FailureOr<OpOperand *>
1896 getUntiledConsumerFromSlice(RewriterBase &rewriter,
1897                             tensor::InsertSliceOp candidateSliceOp) {
1898   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1899     return failure();
1900   Value sliceResult = candidateSliceOp.getResult();
1901   // Step 1. Fetch the corresponding output.
1902   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1903   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1904   // Step 2. Check containing op is scf.for.
1905   Operation *containingOp = candidateSliceOp->getParentOp();
1906   auto forOp = dyn_cast<scf::ForOp>(containingOp);
1907   if (!forOp)
1908     return failure();
1909   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1910 
1911   return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1912 }
1913 
1914 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1915 /// by a tensor.parallel_insert_slice.
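/// For example (a sketch; `consumer_op` is a placeholder name):
///
/// ```
/// %0 = scf.forall ... shared_outs(%arg0 = %init) {
///   ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %arg0 [...]
///   }
/// }
/// %1 = consumer_op ins(%0 : ...) outs(...)
/// ```
///
/// returns the use of `%0` in `consumer_op`.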
1916 static FailureOr<OpOperand *>
1917 getUntiledConsumerFromSlice(RewriterBase &rewriter,
1918                             tensor::ParallelInsertSliceOp candidateSliceOp) {
1919   // Step 1. Fetch the corresponding output
1920   Value sliceDest = candidateSliceOp.getDest();
1921   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1922   if (!iterArg)
1923     return failure();
1924   Operation *containingOp = iterArg.getOwner()->getParentOp();
1925   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1926     return failure();
1927   // Step 2. Check that the containing op is scf.forall.
1928   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1929   if (!forallOp)
1930     return failure();
1931   unsigned resultNumber =
1932       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1933           .getResultNumber();
1934 
1935   return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1936 }
1937 
1938 /// A utility to fetch an untiled consumer of
1939 /// tensor.insert_slice/tensor.parallel_insert_slice.
1940 static FailureOr<OpOperand *>
1941 getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
1942   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1943     return getUntiledConsumerFromSlice(rewriter, insertSlice);
1944   } else if (auto parallelInsertSlice =
1945                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1946     return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1947   } else {
1948     return failure();
1949   }
1950 }
1951 
1952 /// Implementation of fusing consumer of a single slice by computing the
1953 /// slice of the consumer in-place for scf loop.
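/// A sketch of the `scf.for` case (placeholder op names, details elided):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init0) {
///   %t = tiled_producer ...
///   %1 = tensor.insert_slice %t into %arg0 [...]
///   scf.yield %1
/// }
/// %2 = consumer_op ins(%0 : ...) outs(%init1 : ...)
/// ```
///
/// becomes
///
/// ```
/// %0:2 = scf.for ... iter_args(%arg0 = %init0, %arg1 = %init1) {
///   %t = tiled_producer ...
///   %1 = tensor.insert_slice %t into %arg0 [...]
///   %s = tensor.extract_slice %arg1 [...]
///   %c = tiled_consumer_op ins(%t : ...) outs(%s : ...)
///   %2 = tensor.insert_slice %c into %arg1 [...]
///   scf.yield %1, %2
/// }
/// ```
///
/// and the uses of the original consumer's result are replaced with the new
/// loop result.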
1954 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1955 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1956                                       Operation *candidateSliceOp) {
1957   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1958           candidateSliceOp))
1959     return failure();
1960 
1961   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1962 
1963   // 1. Get the consumer of scf.for for the result yielded by
1964   // tensor.insert_slice/parallel_insert_slice.
1965   FailureOr<OpOperand *> maybeConsumerOpOperand =
1966       getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
1967   if (failed(maybeConsumerOpOperand)) {
1968     return rewriter.notifyMatchFailure(candidateSliceOp,
1969                                        "could not fetch consumer to fuse");
1970   }
1971   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1972   Operation *consumerOp = consumerOpOperand->getOwner();
1973   unsigned operandNumber = consumerOpOperand->getOperandNumber();
1974   unsigned resultNumber = 0;
1975   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1976     resultNumber = producerResult.getResultNumber();
1977   } else {
1978     return rewriter.notifyMatchFailure(
1979         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1980   }
1981 
1982   // There are two possible cases regarding `oldLoopOp` here:
1983   // 1. single `scf.forall` or `scf.for`.
  // 2. the inner-most `scf.for` inside a nest of `scf.for` loops, where the
1985   // top-level loop is the outer-most one of these nested loops.
1986   LoopLikeOpInterface innerMostLoop =
1987       candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1988   SmallVector<LoopLikeOpInterface> nestedLoops;
1989   if (isInsertSliceOp) {
1990     nestedLoops = llvm::map_to_vector(
1991         getPerfectlyNestedLoopsOutsideOf(
1992             cast<scf::ForOp>(innerMostLoop.getOperation())),
1993         [](scf::ForOp forOp) {
1994           return cast<LoopLikeOpInterface>(forOp.getOperation());
1995         });
1996   } else {
1997     nestedLoops = {innerMostLoop};
1998   }
1999 
2000   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2001 
2002   // Check assumption for loop with `reorderOperations` disabled.
2003   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2004     return rewriter.notifyMatchFailure(
        outerMostLoop, "the first user of the loop should not dominate any "
                       "definition of the consumer operand(s)");
2007   }
2008 
2009   OpBuilder::InsertionGuard g(rewriter);
2010 
2011   // 2. Check consumer is not using scf loop's output as init.
2012   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2013   if (!dstOp)
2014     return rewriter.notifyMatchFailure(consumerOp,
2015                                        "consumer op is not DPS operation");
2016   SmallVector<Value> dpsInits =
2017       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2018   if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2019     return rewriter.notifyMatchFailure(
2020         consumerOp,
2021         "consumer op taking the result of scf.for as init is not supported");
2022   }
2023   SmallVector<Value> newInits = dpsInits;
2024 
2025   Location loc = outerMostLoop->getLoc();
2026 
  // 3. Move the whole loop structure right before firstUserOfLoop; the
  // dominance should already be ensured by `checkAssumptionForLoop`.
2029   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2030   if (failed(firstUserOfLoop)) {
2031     return rewriter.notifyMatchFailure(
2032         outerMostLoop, "could not find the first user of outer most loop");
2033   }
2034   rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2035 
2036   // 4. Set insertion point before terminator op of the loop and create a new
2037   // tensor.insert_slice. In the scf.for case this is a clone of the
2038   // candidateSliceOp whereas in the scf.forall case this is created from the
2039   // operands of tensor.parallel_insert_slice.
2040   tensor::InsertSliceOp clonedInsertSliceOp;
2041   if (auto sliceOp =
2042           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2043     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2044     rewriter.setInsertionPoint(newForallOp.getTerminator());
2045     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2046         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2047         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2048   } else {
2049     rewriter.setInsertionPoint(candidateSliceOp);
2050     clonedInsertSliceOp =
2051         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2052   }
2053 
2054   // 5.a. Clone consumer op.
2055   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2056 
2057   // 5.b. Replace all uses of the loop result with the result of the cloned
2058   // tensor.insert_slice.
2059   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2060   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2061     operandToReplace.set(clonedInsertSliceOp.getResult());
2062   });
2063 
2064   // 6. Perform tiling of the cloned consumer and replace the operand at
2065   // `operandNumber` with the source of the cloned tensor.insert_slice op.
2066   auto ossSliceOp =
2067       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2068   FailureOr<TilingResult> tileAndFuseResult =
2069       tensor::replaceInsertSliceWithTiledConsumer(
2070           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2071   if (failed(tileAndFuseResult)) {
2072     return failure();
2073   }
2074   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2075   rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2076                               clonedInsertSliceOp.getSource());
2077 
2078   // 7. Reconstruct [nested] loop with new inits.
2079   YieldTiledValuesFn newYieldValuesFn =
2080       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2081           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2082           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
2083           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2084     OpBuilder::InsertionGuard g(innerRewriter);
2085     // 8. Set inner insertPoint right before tiled consumer op.
2086     innerRewriter.setInsertionPoint(tiledConsumerOp);
2087 
2088     SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2089     SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2090     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2091 
    // 9. Check that all insert strides are 1.
2093     if (llvm::any_of(strides, [](OpFoldResult stride) {
2094           return !isConstantIntValue(stride, 1);
2095         })) {
2096       return rewriter.notifyMatchFailure(
2097           candidateSliceOp, "containingOp's result yield with stride");
2098     }
2099 
2100     // 10. Try to get iter domain position from input position. Use
2101     // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2102     // domain may require index computation based on the result size. The
2103     // sizes and offsets should be the same either way, but using
2104     // tiledConsumerOp could lead to some chained unnecessary extra index
2105     // computation.
2106     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2107     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2108             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2109             iterDomainSizes))) {
2110       return rewriter.notifyMatchFailure(
2111           clonedConsumerOp,
2112           "can't get iter domain position from input position");
2113     }
2114 
2115     // 11. Try to fetch the offset and size for all results of the cloned
2116     // consumer. This would then be used to form the corresponding
2117     // tensor.insert_slice/parallel_insert_slice later.
2118     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2119     SmallVector<SmallVector<OpFoldResult>> resultOffsets(
2120         totalNumResultsOfConsumer);
2121     SmallVector<SmallVector<OpFoldResult>> resultSizes(
2122         totalNumResultsOfConsumer);
2123     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2124       if (failed(tiledConsumerOp.getResultTilePosition(
2125               rewriter, idx, iterDomainOffsets, iterDomainSizes,
2126               resultOffsets[idx], resultSizes[idx]))) {
2127         return rewriter.notifyMatchFailure(
2128             tiledConsumerOp,
2129             "can't get result domain position from iter domain position");
2130       }
2131     }
2132 
    // 12. Create `extract_slice` ops for the `iter_args` of DPS operations
    // if necessary.
2135     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2136             tiledConsumerOp.getOperation())) {
2137       rewriter.setInsertionPoint(tiledDestStyleOp);
2138       for (const auto &&[index, newRegionArg] :
2139            llvm::enumerate(newRegionIterArgs)) {
2140         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2141             loc, newRegionArg, resultOffsets[index], resultSizes[index],
2142             SmallVector<OpFoldResult>(resultOffsets[index].size(),
2143                                       rewriter.getIndexAttr(1)));
2144         // Make a copy of index to avoid a capturing structured binding, which
2145         // is a C++20 extension.
2146         auto dstNumber = index;
2147         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2148           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2149         });
2150       }
2151     }
2152 
2153     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2154     // caller.
2155     Block *block = rewriter.getInsertionPoint()->getBlock();
2156     rewriter.setInsertionPoint(block->getTerminator());
2157     for (const auto &&[index, result] :
2158          llvm::enumerate(tiledConsumerOp->getResults())) {
2159       tiledResult.push_back(result);
2160       tiledOffset.emplace_back(resultOffsets[index]);
2161       tiledSizes.emplace_back(resultSizes[index]);
2162     }
2163     return success();
2164   };
2165   // 14. Add new inits to [nested] loops.
2166   if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2167                                        newYieldValuesFn))) {
2168     return rewriter.notifyMatchFailure(tiledConsumerOp,
2169                                        "unable to add new inits to nest loop");
2170   }
2171 
2172   // 15. Replace the result of scf loop and consumer op with new loop's
2173   // results.
2174 
2175   for (auto &&[oldResult, newResult] : llvm::zip(
2176            consumerOp->getResults(),
2177            nestedLoops.front()->getResults().take_back(newInits.size()))) {
2178     rewriter.replaceAllUsesWith(oldResult, newResult);
2179   }
2180 
2181   // 16. Need to erase the old scf loop and the cloned consumer op.
2182   rewriter.eraseOp(clonedConsumerOp);
2183 
2184   return scf::SCFFuseConsumerOfSliceResult{
2185       consumerOpOperand,
2186       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2187       tileAndFuseResult->tiledOps};
2188 }
2189 
2190 //===----------------------------------------------------------------------===//
2191 // lowerToLoopsUsingSCFForOp implementation.
2192 //===----------------------------------------------------------------------===//
2193 
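/// Lower an op implementing `TilingInterface` (with no results) to a nest of
/// `scf.for` loops over its iteration domain, calling the op's scalar
/// implementation in the inner-most body. A sketch of the output for a 2-d
/// iteration domain:
///
/// ```
/// scf.for %i = %lb0 to %ub0 step %step0 {
///   scf.for %j = %lb1 to %ub1 step %step1 {
///     <scalar implementation of the op at (%i, %j)>
///   }
/// }
/// ```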
2194 FailureOr<SmallVector<scf::ForOp>>
2195 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2196                                      TilingInterface op) {
2197   // TODO: Handle cases where the op has results if needed.
2198   if (op->getNumResults() > 0) {
2199     return rewriter.notifyMatchFailure(
2200         op, "unable to lower to loops operations with return values");
2201   }
2202 
2203   SmallVector<Range> domain = op.getIterationDomain(rewriter);
2204   SmallVector<Value> ivs;
2205   SmallVector<scf::ForOp> loops;
2206   Location loc = op.getLoc();
2207   for (auto loopRange : domain) {
2208     Value offsetVal =
2209         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2210     Value sizeVal =
2211         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2212     Value strideVal =
2213         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2214     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2215                                             strideVal, ValueRange{});
2216     loops.push_back(loop);
2217     ivs.push_back(loop.getInductionVar());
2218     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2219   }
2220   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2221     return failure();
2222   }
2223   return loops;
2224 }
2225