xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision a9ba1b6dd5133aa4432759c203e807d8039b4cbd)
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements tiling using the TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/TilingInterface.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include <optional>
30 
31 #define DEBUG_TYPE "tile-using-interface"
32 
33 using namespace mlir;
34 
35 scf::SCFTilingOptions &
36 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
37   assert(!tileSizeComputationFunction && "tile sizes already set");
38   auto tileSizes = llvm::to_vector(ts);
39   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
40     return tileSizes;
41   };
42   return *this;
43 }
44 
45 scf::SCFTilingOptions &
46 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
47   assert(!numThreadsComputationFunction && "num threads already set");
48   auto numThreads = llvm::to_vector(nt);
49   numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
50     return numThreads;
51   };
52   return *this;
53 }
54 
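// A minimal usage sketch of the options above (illustrative only; it assumes
// a `TilingInterface` op `tilingOp` and a `RewriterBase` `rewriter` are
// available in the caller, and that the caller replaces the untiled op):
//
//   scf::SCFTilingOptions options;
//   options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
//   SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
//                                           rewriter.getIndexAttr(4)};
//   options.setNumThreads(numThreads);
//   FailureOr<scf::SCFTilingResult> result =
//       scf::tileUsingSCF(rewriter, tilingOp, options);
//   if (succeeded(result))
//     rewriter.replaceOp(tilingOp, result->replacements);
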
55 /// Helper method to adjust the interchange vector to match the iteration
56 /// domain.
57 static SmallVector<int64_t>
58 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
59                       size_t iterationDomainSize) {
60   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
61   if (filledVector.size() < iterationDomainSize) {
62     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
63     filledVector.append(range.begin(), range.end());
64   }
65   if (filledVector.size() > iterationDomainSize)
66     filledVector.resize(iterationDomainSize);
67   return filledVector;
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // tileUsingSCF implementation.
72 //===----------------------------------------------------------------------===//
73 
74 /// Verify the tile size options are set in a consistent manner.
75 static LogicalResult
76 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
77                       const scf::SCFTilingOptions &options) {
78   // Specifying number of threads is only supported on `scf.forall` op.
79   if (options.numThreadsComputationFunction &&
80       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
81     return rewriter.notifyMatchFailure(
82         loc, "number of threads can only by specified when loop type is "
83              "set to use `scf.forall`");
84   }
85 
86   // If specified, check that the interchange vector is a permutation.
87   if (!options.interchangeVector.empty()) {
88     if (!isPermutationVector(options.interchangeVector)) {
89       return rewriter.notifyMatchFailure(
90           loc, "invalid interchange vector, not a permutation of the entire "
91                "iteration space");
92     }
93   }
94   return success();
95 }
96 
97 /// Method to instantiate the tile sizes and/or number of threads specified
98 /// by the user.
99 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
100 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
101                               ArrayRef<Range> iterationDomain,
102                               const scf::SCFTilingOptions &options) {
103   OpFoldResult zero = rewriter.getIndexAttr(0);
104   SmallVector<OpFoldResult> tileSizes, numThreads;
105   size_t numLoops = iterationDomain.size();
106 
107   // Check whether the number of threads to use is specified.
108   if (options.numThreadsComputationFunction) {
109     numThreads = options.numThreadsComputationFunction(rewriter, op);
110     numThreads.resize(numLoops, zero);
111 
112     // If the tile sizes are also specified, use those.
113     if (options.tileSizeComputationFunction) {
114       tileSizes = options.tileSizeComputationFunction(rewriter, op);
115       tileSizes.resize(numLoops, zero);
116       return {tileSizes, numThreads};
117     }
118 
119     // Compute the tile sizes from the iteration domain and number
120     // of threads as follows:
121     // - niters = ceilDiv(ub - lb, step)
122     // - tileSize = ceilDiv(niters, numThreads)
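    // For example (illustrative numbers): with an iteration domain of
    // `[0, 100)` and `numThreads = 7`, this yields
    // `tileSize = ceilDiv(100 - 0, 7) = 15`.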
123     AffineExpr s0, s1, s2;
124     bindSymbols(rewriter.getContext(), s0, s1, s2);
125     // TODO: The step here is assumed to be 1.
126     AffineExpr numItersExpr = (s1 - s0);
127     AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
128     tileSizes.resize(numLoops, zero);
129     for (auto [index, range, nt] :
130          llvm::enumerate(iterationDomain, numThreads)) {
131       if (isConstantIntValue(nt, 0))
132         continue;
133 
134       tileSizes[index] = affine::makeComposedFoldedAffineApply(
135           rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
136     }
137     tileSizes.resize(numLoops, zero);
138     return {tileSizes, numThreads};
139   }
140 
141   // Enforce the convention that "tiling by zero"
142   // skips tiling a particular dimension. This convention is significantly
143   // simpler to handle than adjusting affine maps to account for missing
144   // dimensions.
145   assert(options.tileSizeComputationFunction &&
146          "expected tile sizes to be specified");
147   tileSizes = options.tileSizeComputationFunction(rewriter, op);
148   tileSizes.resize(numLoops, zero);
149 
150   return {tileSizes, numThreads};
151 }
152 
153 /// Checks if any of the tiled loops are not parallel.
154 static void checkSafeToTileToForall(TilingInterface op,
155                                     ArrayRef<OpFoldResult> tileSizes,
156                                     ArrayRef<OpFoldResult> numThreads) {
157   auto iterators = op.getLoopIteratorTypes();
158   assert(iterators.size() == tileSizes.size() &&
159          "expected as many tile size values as number of loops");
160   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
161          "when specified, expected number of threads to use for each loop");
162 
163   for (auto [index, iterator, tileSize] :
164        llvm::enumerate(iterators, tileSizes)) {
165     // If the number of threads is specified, warn if it is greater than one
166     // for any non-parallel dimension.
167     if (!numThreads.empty()) {
168       if (std::optional<int64_t> constNumThreads =
169               getConstantIntValue(numThreads[index])) {
170         if (constNumThreads.value() > 1 &&
171             iterator != utils::IteratorType::parallel) {
172           op.emitWarning() << "tiling is not thread safe at axis #" << index;
173         }
174       }
175       continue;
176     }
177 
178     if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
179       if (constTileSize.value() > 0 &&
180           iterator != utils::IteratorType::parallel) {
181         op.emitWarning() << "tiling is not thread safe at axis #" << index;
182       }
183     }
184   }
185 }
186 
187 /// Check if `stride` evenly divides the trip count `size - offset`.
188 static bool tileDividesIterationDomain(Range loopRange) {
189   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
190   if (!offsetAsInt)
191     return false;
192   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
193   if (!sizeAsInt)
194     return false;
195   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
196   if (!strideAsInt)
197     return false;
198   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
199 }
200 
201 /// Returns the bounded tile size given the current `offset`, `loopRange` and
202 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
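/// For example (schematic; the composed and folded form may differ), tiling
/// by a constant size 16 over an upper bound `%ub` with induction variable
/// `%iv` produces roughly:
///
/// ```mlir
///   %ts = affine.min affine_map<(d0)[s0] -> (s0 - d0, 16)>(%iv)[%ub]
/// ```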
203 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
204                                        Range loopRange, OpFoldResult offset,
205                                        OpFoldResult tileSize) {
206   std::optional<int64_t> ts = getConstantIntValue(tileSize);
207   if (ts && ts.value() == 1)
208     return tileSize;
209 
210   if (tileDividesIterationDomain(
211           Range{loopRange.offset, loopRange.size, tileSize}))
212     return tileSize;
213 
214   // The tile size to use (to avoid out of bounds access) is the minimum of
215   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
216   // loop.
217   AffineExpr s0, s1, d0;
218   bindDims(b.getContext(), d0);
219   bindSymbols(b.getContext(), s0, s1);
220   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
221   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
222   return affine::makeComposedFoldedAffineMin(
223       b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
224 }
225 
226 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
227 /// less than `iterationSize`.
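/// For example (illustrative numbers): with `tileSize = 2`, `numThreads = 7`
/// and `iterationSize = 10`, the maximum tile offset is `2 * 6 = 12 >= 10`,
/// so the out-of-bounds check (the `affine.max` with 0) cannot be omitted.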
228 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
229                                            OpFoldResult numThreads,
230                                            OpFoldResult iterationSize) {
231   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
232   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
233   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
234   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
235     return false;
236   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
237 }
238 
239 /// Compute the `OpFoldResult`s that represent the multi-dimensional
240 /// `offset`s and `size`s of the tile of the iteration space that the
241 /// innermost loop body of the generated tiled loops corresponds to.
242 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
243 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
244                       ArrayRef<Range> iterationDomain,
245                       ArrayRef<OpFoldResult> tileSizes,
246                       ArrayRef<OpFoldResult> numThreads) {
247   SmallVector<OpFoldResult> offsets, sizes;
248   int materializedLoopNum = 0;
249 
250   if (!numThreads.empty()) {
251     AffineExpr d0, d1, s0, s1;
252     AffineExpr offsetExpr, residualTileSizeExpr;
253     bindDims(rewriter.getContext(), d0, d1);
254     bindSymbols(rewriter.getContext(), s0, s1);
255     offsetExpr = d0 + d1 * s0;
256     residualTileSizeExpr = s1 - (d0 + d1 * s0);
257 
258     for (auto [nt, tileSize, loopRange] :
259          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
260 
261       // Non-tiled case: set the offset and size to the
262       // `loopRange.offset/size`.
263       if (isConstantIntValue(nt, 0)) {
264         offsets.push_back(loopRange.offset);
265         sizes.push_back(loopRange.size);
266         continue;
267       }
268 
269       Value iv = ivs[materializedLoopNum++];
270       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
271           rewriter, loc, offsetExpr,
272           ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
273       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
274           rewriter, loc, residualTileSizeExpr,
275           {loopRange.offset, nt, tileSize, loopRange.size});
276 
277       OpFoldResult size = tileSize;
278       if (!isConstantIntValue(residualTileSize, 0)) {
279         OpFoldResult sizeMinusOffsetPerThread =
280             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
281                                                   {offset, loopRange.size});
282         size = affine::makeComposedFoldedAffineMin(
283             rewriter, loc,
284             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
285             {sizeMinusOffsetPerThread, tileSize});
286       }
287 
288       // Consider the case where the original loop was `[0, 10)`.
289       // If the number of threads is `7`, the tile size is computed as
290       // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
291       // - `offset = 0 + 6 * 2 = 12`
292       // - `tileSize = min(2, 10 - 12) = -2`
293       // To avoid negative tile sizes, we need a further
294       // `nonNegativeTileSize = affine.max(0, tileSize)`.
295       // This `max` can be avoided if
296       //  `tileSize * (numThreads - 1) < (ub - lb)`
297       if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
298         AffineMap maxMap =
299             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
300         size = affine::makeComposedFoldedAffineMax(
301             rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
302       }
303 
304       offsets.push_back(offset);
305       sizes.push_back(size);
306     }
307     return {offsets, sizes};
308   } else {
309     for (auto [tileSize, loopRange] :
310          llvm::zip_equal(tileSizes, iterationDomain)) {
311 
312       // Non-tiled cases, set the offset and size to the
313       // `loopRange.offset/size`.
314       if (isConstantIntValue(tileSize, 0)) {
315         offsets.push_back(loopRange.offset);
316         sizes.push_back(loopRange.size);
317         continue;
318       }
319 
320       Value iv = ivs[materializedLoopNum++];
321       OpFoldResult offset = getAsOpFoldResult(iv);
322       offsets.push_back(offset);
323       OpFoldResult size =
324           getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
325       sizes.push_back(size);
326     }
327     return {offsets, sizes};
328   }
329 }
330 
331 /// Function to return the bounds of the loops to be generated.
332 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
333                   SmallVector<OpFoldResult>>
334 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
335               ArrayRef<OpFoldResult> tileSizes) {
336   SmallVector<OpFoldResult> lbs, ubs, steps;
337   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
338     // No loop if the tile size is 0.
339     if (isConstantIntValue(tileSize, 0))
340       continue;
341     lbs.push_back(loopRange.offset);
342     ubs.push_back(loopRange.size);
343     steps.push_back(tileSize);
344   }
345   return {lbs, ubs, steps};
346 }
347 
348 /// A function that allows returning additional yielded values during
349 /// `yieldTiledValuesAndReplaceLoop`.
350 /// - `ivs` induction variable for the loop.
351 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
352 /// - `tiledValues` the tiled values to return. Must be of the same size as
353 ///   `newBbArgs`; each element of this array is inserted into the corresponding
354 ///   element in `newBbArgs`.
355 /// - `resultOffsets` is of the same size as `tiledValues` and represents
356 ///   the offsets to use when inserting corresponding element from `tiledValues`
357 ///   into the element from `newBbArgs`.
358 /// - `resultSizes` is of the same size as `tiledValues` and represents
359 ///   the size of the corresponding element from `tiledValues` inserted into
360 ///   the element from `newBbArgs`.
361 /// In case the method needs to return `failure()` the method is expected
362 /// to clean up any inserted operations.
363 using YieldTiledValuesFn = std::function<LogicalResult(
364     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
365     SmallVector<Value> &tiledValues,
366     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
367     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
368 
369 /// Clones the operation and updates the destination if the operation
370 /// implements the `DestinationStyleOpInterface`.
371 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
372                                                   Operation *op,
373                                                   ValueRange newDestArgs) {
374   Operation *clonedOp = rewriter.clone(*op);
375   if (newDestArgs.empty())
376     return clonedOp;
377   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
378     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
379   return clonedOp;
380 }
381 
382 /// Generate the tile-loop nest using `scf.for` operation.
383 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
384 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
385 /// - `destinationTensors` are the init values to use for the outermost loop.
386 /// - `yieldTiledValuesFn` is called to generate the loop body of the
387 ///   innermost loop.
389 /// - `loops` is an in-out parameter into which the generated loops are
390 ///    populated.
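/// The generated nest has roughly the following shape for two tiled loops
/// (illustrative; tensor types and slice parameters elided):
///
/// ```mlir
///   %0 = scf.for %iv0 = %lb0 to %ub0 step %s0 iter_args(%arg0 = %init) {
///     %1 = scf.for %iv1 = %lb1 to %ub1 step %s1 iter_args(%arg1 = %arg0) {
///       // ... tiled body created by `yieldTiledValuesFn`, producing %tile ...
///       %2 = tensor.insert_slice %tile into %arg1[...] [...] [1, 1]
///       scf.yield %2
///     }
///     scf.yield %1
///   }
/// ```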
391 static LogicalResult generateLoopNestUsingForOp(
392     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
393     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
394     YieldTiledValuesFn yieldTiledValuesFn,
395     SmallVector<LoopLikeOpInterface> &loops) {
396   assert(!loopRanges.empty() && "unexpected empty loop ranges");
397   assert(loopRanges.size() == tileSizes.size() &&
398          "expected as many tile sizes as loop ranges");
399   OpBuilder::InsertionGuard guard(rewriter);
400 
401   SmallVector<OpFoldResult> lbs, ubs, steps;
402   std::tie(lbs, ubs, steps) =
403       getLoopBounds(rewriter, loc, loopRanges, tileSizes);
404   SmallVector<Value> lbVals =
405       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
406   SmallVector<Value> ubVals =
407       getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
408   SmallVector<Value> stepVals =
409       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
410 
411   SmallVector<Value> ivs;
412   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
413     auto loop =
414         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
415                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
416                                        Value iv, ValueRange /*iterArgs*/) {});
417     loops.push_back(loop);
418     ivs.push_back(loop.getInductionVar());
419     rewriter.setInsertionPointToEnd(loop.getBody());
420     destinationTensors = loop.getRegionIterArgs();
421   }
422 
423   SmallVector<Value> tiledResults;
424   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
425   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
426                                 tiledResults, resultOffsets, resultSizes))) {
427     return rewriter.notifyMatchFailure(
428         loc, "failed to generate inner tile loop body");
429   }
430   if (loops.empty())
431     return success();
432 
433   assert(tiledResults.size() == destinationTensors.size() &&
434          "Number of results of body should be equal to number of iter args");
435 
436   // Yield all the results of the tiled operation.
437   SmallVector<Value> yieldedValues;
438   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
439        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
440                        resultSizes)) {
441     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
442                                            rewriter.getIndexAttr(1));
443     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
444         loc, tiledValue, destinationTensor, resultOffset, resultSize,
445         resultStride);
446     yieldedValues.push_back(insertSlice);
447   }
448   rewriter.create<scf::YieldOp>(loc, yieldedValues);
449 
450   // Add the scf.yield operations for all the outer loops.
451   for (auto [outerLoop, innerLoop] :
452        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
453                        MutableArrayRef(loops).drop_front())) {
454     rewriter.setInsertionPointToEnd(
455         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
456     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
457   }
458   return success();
459 }
460 
461 /// Generate the tile-loop nest using `scf.forall` operation.
462 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
463 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
464 /// - `destinationTensors` are the init values to use for the outermost loop.
465 /// - `mappingVector` is the mapping attributes to use for loop construction.
466 ///   Can be empty.
467 /// - `yieldTiledValuesFn` is called to generate the loop body of the
468 ///   innermost loop.
470 /// - `loops` is an in-out parameter into which the generated loops are
471 ///    populated.
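/// When `numThreads` is used, the generated loop has roughly the following
/// shape (illustrative; tensor types and slice parameters elided):
///
/// ```mlir
///   %0 = scf.forall (%iv0, %iv1) in (%nt0, %nt1) shared_outs(%out = %init) {
///     // ... tiled body created by `tiledBodyFn`, producing %tile ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %tile into %out[...] [...] [1, 1]
///     }
///   }
/// ```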
472 static LogicalResult generateLoopNestUsingForallOp(
473     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
474     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
475     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
476     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
477   assert(!loopRanges.empty() && "unexpected empty loop ranges");
478   assert(loopRanges.size() == tileSizes.size() &&
479          "expected as many tile sizes as loop ranges");
480   OpBuilder::InsertionGuard guard(rewriter);
481   SmallVector<OpFoldResult> offsets(loopRanges.size()),
482       sizes(loopRanges.size());
483 
484   std::optional<ArrayAttr> mappingAttr;
485   if (!mappingVector.empty())
486     mappingAttr = rewriter.getArrayAttr(mappingVector);
487 
488   scf::ForallOp forallOp;
489   bool useNumThreads = !numThreads.empty();
490 
491   if (useNumThreads) {
492     // Prune the zero values from numThreads.
493     SmallVector<OpFoldResult> nonZeroNumThreads;
494     for (auto nt : numThreads) {
495       if (isConstantIntValue(nt, 0))
496         continue;
497       nonZeroNumThreads.push_back(nt);
498     }
499     forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
500                                               destinationTensors, mappingAttr);
501   } else {
502     SmallVector<OpFoldResult> lbs, ubs, steps;
503     std::tie(lbs, ubs, steps) =
504         getLoopBounds(rewriter, loc, loopRanges, tileSizes);
505     forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
506                                               destinationTensors, mappingAttr);
507   }
508   loops.push_back(forallOp);
509 
510   rewriter.setInsertionPoint(forallOp.getTerminator());
511   destinationTensors = forallOp.getRegionOutArgs();
512 
513   SmallVector<Value> tiledResults;
514   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
515   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
516                          destinationTensors, tiledResults, resultOffsets,
517                          resultSizes)))
518     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
519 
520   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
521   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
522        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
523                        resultSizes)) {
524     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
525                                            rewriter.getIndexAttr(1));
526 
527     rewriter.create<tensor::ParallelInsertSliceOp>(
528         loc, tiledValue, destinationTensor, resultOffset, resultSize,
529         resultStride);
530   }
531   return success();
532 }
533 
534 /// Generate the tile-loop nest using the loop construct specified in `options`.
535 /// - `options`: Tiling options specified.
536 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
537 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
538 /// - `destinationTensors` are the init values to use for the outermost loop.
539 /// - `yieldTiledValuesFn` is called to generate the loop body of the
540 ///   innermost loop.
542 /// - `loops` is an in-out parameter into which the generated loops are
543 ///    populated.
544 static LogicalResult generateLoopNest(
545     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
546     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
547     ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
548     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
549   // If the tile sizes are all zero, no loops are generated. Just call the
550   // callback function to handle the untiled case.
551   if (llvm::all_of(tileSizes, isZeroIndex)) {
552     SmallVector<Value> tiledResults;
553     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
554     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
555                        tiledResults, resultOffsets, resultSizes);
556   }
557   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
558     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
559                                       destinationTensors, tiledBodyFn, loops);
560   }
561   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
562     return generateLoopNestUsingForallOp(
563         rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
564         destinationTensors, tiledBodyFn, loops);
565   }
566   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
567 }
568 
569 /// Append the specified additional `newInitOperands` operands to the
570 /// loop's existing `init` operands (or similar), and replace `loopOp` with
571 /// the new loop that has the additional init operands. The loop body of
572 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
573 /// is called to get the new tiled values returned, and the offset
574 /// and sizes at which the tiled value is inserted into the
575 /// new region iter_args that correspond to the newly added init operands.
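/// Schematically (illustrative, for the `scf.for` case), an existing loop
///
/// ```mlir
///   %0 = scf.for ... iter_args(%arg0 = %init0) { ... scf.yield %y0 }
/// ```
///
/// is replaced by
///
/// ```mlir
///   %0:2 = scf.for ... iter_args(%arg0 = %init0, %arg1 = %newInit) {
///     ...
///     %ins = tensor.insert_slice %tiledValue into %arg1[...] [...] [1, 1]
///     scf.yield %y0, %ins
///   }
/// ```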
576 template <typename LoopType>
577 FailureOr<LoopLikeOpInterface>
578 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
579                                ValueRange newInitOperands,
580                                YieldTiledValuesFn yieldTiledValuesFn) {
581   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
582 }
583 
584 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
585 template <>
586 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
587     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
588     YieldTiledValuesFn yieldTiledValuesFn) {
589   OpBuilder::InsertionGuard g(rewriter);
590   Location loc = loopOp.getLoc();
591   rewriter.setInsertionPoint(loopOp);
592 
593   auto inits = llvm::to_vector(loopOp.getInitArgs());
594   inits.append(newInitOperands.begin(), newInitOperands.end());
595   auto newLoop = rewriter.create<scf::ForOp>(
596       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
597       inits, [](OpBuilder &, Location, Value, ValueRange) {});
598 
599   // Move the loop body to the new op.
600   Block *loopBody = loopOp.getBody();
601   Block *newLoopBody = newLoop.getBody();
602   rewriter.mergeBlocks(
603       loopBody, newLoopBody,
604       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
605 
606   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
607   rewriter.setInsertionPoint(yieldOp);
608 
609   SmallVector<Value> tiledValues;
610   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
611   ValueRange newRegionIterArgs =
612       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
613   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
614                                 newRegionIterArgs, tiledValues, resultOffsets,
615                                 resultSizes))) {
616     rewriter.eraseOp(newLoop);
617     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
618   }
619 
620   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
621   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
622        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
623                        resultSizes)) {
624     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
625                                            rewriter.getIndexAttr(1));
626     Value insert = rewriter.create<tensor::InsertSliceOp>(
627         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
628         resultStride);
629     newYieldValues.push_back(insert);
630   }
631 
632   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
633   rewriter.replaceOp(loopOp,
634                      newLoop->getResults().take_front(loopOp.getNumResults()));
635   return cast<LoopLikeOpInterface>(newLoop.getOperation());
636 }
637 
638 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
639 template <>
640 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
641     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
642     YieldTiledValuesFn yieldTiledValuesFn) {
643   OpBuilder::InsertionGuard g(rewriter);
644   Location loc = loopOp.getLoc();
645   rewriter.setInsertionPoint(loopOp);
646   auto inits = llvm::to_vector(loopOp.getOutputs());
647   inits.append(newInitOperands.begin(), newInitOperands.end());
648   auto newLoop = rewriter.create<scf::ForallOp>(
649       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
650       loopOp.getMixedStep(), inits, loopOp.getMapping(),
651       [](OpBuilder &, Location, ValueRange) {});
652 
653   // Move the body of the current loop over to the newly created op.
654   Block *loopBody = loopOp.getBody();
655   Block *newLoopBody = newLoop.getBody();
656   rewriter.mergeBlocks(
657       loopBody, newLoopBody,
658       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
659 
660   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
661   rewriter.setInsertionPoint(terminator);
662   SmallVector<Value> tiledValues;
663   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
664   ValueRange regionIterArgs =
665       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
666   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
667                                 regionIterArgs, tiledValues, resultOffsets,
668                                 resultSizes))) {
669     rewriter.eraseOp(newLoop);
670     return rewriter.notifyMatchFailure(loopOp,
671                                        "failed to get yielded tiled values");
672   }
673 
674   // Update the terminator.
675   rewriter.setInsertionPointToEnd(terminator.getBody());
676 
677   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
678            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
679     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
680                                            rewriter.getIndexAttr(1));
681     rewriter.create<tensor::ParallelInsertSliceOp>(
682         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
683         resultStride);
684   }
685 
686   rewriter.replaceOp(loopOp,
687                      newLoop->getResults().take_front(loopOp.getNumResults()));
688   return cast<LoopLikeOpInterface>(newLoop.getOperation());
689 }
690 
691 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
692 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
693 /// supported loop type.
694 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
695     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
696     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
697   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
698              loopLikeOp.getOperation())
699       .Case<scf::ForOp, scf::ForallOp>(
700           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
701             return yieldTiledValuesAndReplaceLoop(
702                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
703           })
704       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
705         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
706       });
707 }
708 
709 /// Method to add new init values to a loop nest. Updates `loops` in-place with
710 /// new loops that use the `newInitValues`.
711 /// The outer loops are updated to yield the new result values of the inner
712 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
713 /// to get the additional values to yield from the innermost loop.
714 static LogicalResult addInitOperandsToLoopNest(
715     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
716     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
717   SmallVector<scf::ForOp> newLoops;
718   if (loops.empty())
719     return success();
720   OpBuilder::InsertionGuard g(rewriter);
721   rewriter.setInsertionPoint(loops.front());
722 
723   SmallVector<Value> ivs;
724   for (auto &loop : loops.drop_back()) {
725     rewriter.setInsertionPoint(loop);
726 
727     // If loops.size() > 1, we assume that `scf.for` is used for the outer loops.
728     auto forLoop = cast<scf::ForOp>(loop.getOperation());
729 
730     // Create a new loop with the new init values for this loop.
731     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
732     newInits.append(newInitValues.begin(), newInitValues.end());
733     auto newLoop = rewriter.create<scf::ForOp>(
734         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
735         forLoop.getStep(), newInits,
736         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
737 
738     // Merge the body of the old loop into the body of the new loop.
739     SmallVector<Value> sourceBlockArgs;
740     sourceBlockArgs.push_back(newLoop.getInductionVar());
741     auto newRegionIterArgs = newLoop.getRegionIterArgs();
742     sourceBlockArgs.append(
743         newRegionIterArgs.begin(),
744         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
745     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
746     rewriter.replaceOp(
747         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
748     loop = newLoop;
749     ivs.push_back(newLoop.getInductionVar());
750     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
751   }
752 
753   // Update the loop body of the innermost loop to get new yield values.
754   LoopLikeOpInterface innerMostLoop = loops.back();
755   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
756       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
757                                      getNewTiledYieldsFn);
758 
759   if (failed(newInnerMostLoop))
760     return innerMostLoop.emitOpError("failed to return additional yields");
761   loops.back() = newInnerMostLoop.value();
762 
763   // Make all loops except the innermost loop yield the values returned
764   // by the inner loop.
765   for (auto [outerLoop, innerLoop] :
766        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
767     // Again assume that all the outer loops are scf.for operations.
768     auto outerForLoop = cast<scf::ForOp>(outerLoop);
769     auto outerLoopYield =
770         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
771     SmallVector<Value> newYields =
772         llvm::to_vector(outerLoopYield.getOperands());
773     ValueRange additionalYields =
774         innerLoop->getResults().take_back(newInitValues.size());
775     newYields.append(additionalYields.begin(), additionalYields.end());
776     rewriter.setInsertionPoint(outerLoopYield);
777     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
778   }
779   return success();
780 }
781 
782 /// Implementation of the tiling transformation of `op` that implements the
783 /// `TilingInterface`, using `scf.for` or `scf.forall` to iterate over the tiles.
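/// For example (illustrative; types and indexing maps abbreviated), tiling
///
/// ```mlir
///   %0 = linalg.matmul ins(%a, %b : ...) outs(%c : tensor<128x128xf32>) ...
/// ```
///
/// with tile sizes `[8, 32, 0]` (reduction dimension untiled) produces roughly
///
/// ```mlir
///   %0 = scf.for %i = %c0 to %c128 step %c8 iter_args(%arg0 = %c) {
///     %1 = scf.for %j = %c0 to %c128 step %c32 iter_args(%arg1 = %arg0) {
///       %sa = tensor.extract_slice %a[%i, 0] [8, 128] [1, 1]
///       %sb = tensor.extract_slice %b[0, %j] [128, 32] [1, 1]
///       %sc = tensor.extract_slice %arg1[%i, %j] [8, 32] [1, 1]
///       %t = linalg.matmul ins(%sa, %sb : ...) outs(%sc : ...) ...
///       %2 = tensor.insert_slice %t into %arg1[%i, %j] [8, 32] [1, 1]
///       scf.yield %2
///     }
///     scf.yield %1
///   }
/// ```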
784 FailureOr<scf::SCFTilingResult>
785 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
786                         const scf::SCFTilingOptions &options) {
787   if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
788     return failure();
789   }
790 
791   OpBuilder::InsertionGuard guard(rewriter);
792   rewriter.setInsertionPointAfter(op);
793 
794   // 1. Get the range of the loops that are represented by the operation.
795   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
796 
797   // 2. Materialize the tile sizes and/or number of threads.
798   SmallVector<OpFoldResult> tileSizes, numThreads;
799   std::tie(tileSizes, numThreads) =
800       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
801 
802   // Check if it is safe to tile. This is a holdover from previous iterations
803   // of tiling to `scf.forall`. Consider dropping it.
804   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
805     checkSafeToTileToForall(op, tileSizes, numThreads);
806   }
807 
808   // 3. If there is an interchange specified, permute the iteration domain and
809   // the tile sizes.
810   SmallVector<int64_t> interchangeVector;
811   if (!options.interchangeVector.empty()) {
812     interchangeVector = fillInterchangeVector(options.interchangeVector,
813                                               iterationDomain.size());
814     assert(isPermutationVector(interchangeVector) &&
815            "expected interchange vector to be a permutation");
816 
817     applyPermutationToVector(iterationDomain, interchangeVector);
818     applyPermutationToVector(tileSizes, interchangeVector);
819     if (!numThreads.empty())
820       applyPermutationToVector(numThreads, interchangeVector);
821   }
822 
823   FailureOr<TilingResult> tilingResult;
824   // 4. Define the lambda function used later to generate the body of the
825   // innermost tiled loop.
826   YieldTiledValuesFn innerYieldTiledValuesFn =
827       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
828           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
829           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
830           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
831       -> LogicalResult {
832     // 4a. Compute the `offsets` and `sizes` to use for tiling.
833     SmallVector<OpFoldResult> offsets, sizes;
834     std::tie(offsets, sizes) = getTileOffsetAndSizes(
835         rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
836 
837     // 4b. If interchange was provided, apply inverse of the interchange
838     //     to get back the offsets/sizes in the original order.
839     if (!interchangeVector.empty()) {
840       auto inversePermutation = invertPermutationVector(interchangeVector);
841       applyPermutationToVector(offsets, inversePermutation);
842       applyPermutationToVector(sizes, inversePermutation);
843     }
844 
845     // 5. Generate the tiled implementation within the inner most loop.
846 
847     // 5a. Clone the operation within the loop body.
848     auto clonedOp = cast<TilingInterface>(
849         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
850 
851     // 5b. Early return cloned op if tiling is not happening. We cannot return
852     // the original op because it could lead to
853     // `rewriter.replaceOp(op, op->getResults())`, which would crash its users.
854     if (llvm::all_of(tileSizes, isZeroIndex)) {
855       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
856       tilingResult =
857           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
858       return success();
859     }
860 
861     // 5c. Tile the cloned operation.
862     tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
863     if (failed(tilingResult)) {
864       rewriter.eraseOp(clonedOp);
865       return op.emitOpError("failed to tile operation");
866     }
867 
868     // 5d. Delete the cloned operation.
869     rewriter.eraseOp(clonedOp);
870 
871     // 5e. Compute the offsets at which the result values are to be inserted
872     //     back into their destinations.
873     for (auto [index, tiledValue] :
874          llvm::enumerate(tilingResult->tiledValues)) {
875       tiledResults.push_back(tiledValue);
876       SmallVector<OpFoldResult> resultOffset, resultSize;
877       if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
878                                           resultOffset, resultSize))) {
879         for (auto op : tilingResult->tiledOps) {
880           rewriter.eraseOp(op);
881         }
882         return rewriter.notifyMatchFailure(
883             op, "failed to get slice of result produced");
884       }
885       resultOffsets.emplace_back(std::move(resultOffset));
886       resultSizes.emplace_back(std::move(resultSize));
887     }
888 
889     return success();
890   };
891 
892   // 6. Find the destination tensors to use for the operation.
893   SmallVector<Value> destinationTensors;
894   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
895                                              destinationTensors))) {
896     return rewriter.notifyMatchFailure(op,
897                                        "unable to create destination tensors");
898   }
899 
900   // 7. Generate the tiled loops nest using the callback defined above.
901   SmallVector<LoopLikeOpInterface> loops;
902   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
903                               tileSizes, numThreads, destinationTensors,
904                               innerYieldTiledValuesFn, loops)))
905     return op.emitOpError("failed to generate tiling loops");
906   assert(succeeded(tilingResult) &&
907          "expected tiling result to be computed after loop generation");
908 
909   // If loops are empty, the tiled op is used as the replacement for the untiled
910   // op.
911   if (loops.empty()) {
912     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
913                                 tilingResult->tiledValues};
914   }
915 
916   SmallVector<Value> replacements = llvm::map_to_vector(
917       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
918   return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
919 }
920 
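/// Tile a reduction using `scf.for`, creating partial reduction results that
/// are merged after the loop. Schematically (illustrative), for a sum
/// reduction tiled by `k` along the reduction dimension, the generated loop
/// accumulates tile-wise partial results into an init tensor whose reduction
/// dimension has size `k`, and a merge operation after the loop reduces that
/// dimension away to produce the final value that replaces the original op.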
921 FailureOr<scf::SCFReductionTilingResult>
922 mlir::scf::tileReductionUsingScf(RewriterBase &b,
923                                  PartialReductionOpInterface op,
924                                  ArrayRef<OpFoldResult> tileSizes) {
925   Location loc = op.getLoc();
926   // Ops implementing PartialReductionOpInterface are expected to implement
927   // TilingInterface.
928   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
929   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
930   auto tileSizesVector = llvm::to_vector(tileSizes);
931   if (tileSizesVector.size() < iterationDomain.size()) {
932     auto zero = b.getIndexAttr(0);
933     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
934                            zero);
935   }
936   SmallVector<utils::IteratorType> iterators =
937       tilingInterfaceOp.getLoopIteratorTypes();
938 
939   SmallVector<int> reductionDims;
940   for (auto [idx, iteratorType] : llvm::enumerate(iterators)) {
942     if (iteratorType == utils::IteratorType::reduction)
943       reductionDims.push_back(idx);
944   }
945 
946   // 2. Create the initial tensor values.
947   FailureOr<SmallVector<Value>> maybeInitTensors =
948       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
949                                                   reductionDims);
950   if (failed(maybeInitTensors)) {
951     return b.notifyMatchFailure(op, "Failed to create initial tensors.");
952   }
953   SmallVector<Value> &initTensors = maybeInitTensors.value();
954 
955   // 3. Define the callback to use for generating the innermost tile loop body.
956   SmallVector<Operation *> parallelTiledOps;
957   auto innerYieldTiledValuesFn =
958       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
959           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
960           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
961           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
962       -> LogicalResult {
963     SmallVector<OpFoldResult> offsets, sizes;
964     {
965       int materializedLoopNum = 0;
966       for (auto [tileSize, loopRange] :
967            llvm::zip_equal(tileSizesVector, iterationDomain)) {
968         if (isConstantIntValue(tileSize, 0)) {
969           offsets.push_back(loopRange.offset);
970           sizes.push_back(loopRange.size);
971           continue;
972         }
973         Value iv = ivs[materializedLoopNum++];
974         offsets.push_back(iv);
975         sizes.push_back(
976             getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
977       }
978     }
979 
980     // 4a. Clone the operation.
981     {
982       auto clonedOp = cast<PartialReductionOpInterface>(
983           cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
984 
985       // 4b. Tile the cloned operation.
986       FailureOr<TilingResult> partialTilingResult =
987           clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
988                                           sizes, reductionDims);
989       if (failed(partialTilingResult)) {
990         return failure();
991       }
992       std::swap(parallelTiledOps, partialTilingResult->tiledOps);
993       std::swap(tiledResult, partialTilingResult->tiledValues);
994 
995       // 4c. Delete the cloned operation.
996       b.eraseOp(clonedOp);
997     }
998 
999     // 4d. Compute the offsets and sizes needed to insert the tiled values
1000     // back into the destination before yielding the destination.
1001     for (auto result : tiledResult) {
1002       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
1003       resultOffsets.emplace_back(std::move(outOffsets));
1004 
1005       SmallVector<OpFoldResult> outSizes;
1006       for (size_t i = 0; i < offsets.size(); i++) {
1007         outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
1008       }
1009       resultSizes.emplace_back(std::move(outSizes));
1010     }
1011     return success();
1012   };
1013 
1014   // 5. Generate the tiled implementation using the destination tensors.
1015   SmallVector<LoopLikeOpInterface> loops;
1016   scf::SCFTilingOptions options;
1017   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1018   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
1019                               /*numThreads=*/ArrayRef<OpFoldResult>{},
1020                               initTensors, innerYieldTiledValuesFn, loops)))
1021     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
1022 
1023   SmallVector<Value> replacements = llvm::map_to_vector(
1024       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
1025 
1026   // 6. Apply the merge reduction to combine all the partial values.
1027   b.setInsertionPointAfter(*loops.begin());
1028   FailureOr<MergeResult> mergeResult =
1029       op.mergeReductions(b, loc, replacements, reductionDims);
1030   if (failed(mergeResult)) {
1031     return failure();
1032   }
1033   b.replaceOp(op, mergeResult->replacements);
1034 
1035   SCFReductionTilingResult reductionTilingResult;
1036   std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
1037   std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
1038   std::swap(reductionTilingResult.initialValues, initTensors);
1039   std::swap(reductionTilingResult.loops, loops);
1040   std::swap(reductionTilingResult.replacements, mergeResult->replacements);
1041 
1042   return reductionTilingResult;
1043 }
1044 
1045 //===----------------------------------------------------------------------===//
1046 // tileConsumerAndFuseProducersUsingSCF implementation.
1047 //===----------------------------------------------------------------------===//
1048 
1049 /// Return the untiled producer whose slice is used in a tiled consumer. The
1050 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1051 /// `iter_args` of the outermost loop that is encountered. Traversing the iter_args
1052 /// indicates that this is a destination operand of the consumer. If there was
1053 /// no loop traversal needed, the second value of the returned tuple is empty.
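/// For example (illustrative), if the slice source is `%arg1` in
///
/// ```mlir
///   %1 = linalg.fill ... -> tensor<?x?xf32>
///   %2 = scf.for ... iter_args(%arg0 = %1) {
///     %3 = scf.for ... iter_args(%arg1 = %arg0) {
///       %4 = tensor.extract_slice %arg1 [...]
///       ...
///     }
///   }
/// ```
///
/// the traversal walks `%arg1` -> `%arg0` -> `%1` and returns the `OpResult`
/// of the `linalg.fill` together with the outermost loop's init operand.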
1054 static std::tuple<OpResult, std::optional<OpOperand *>>
1055 getUntiledProducerFromSliceSource(OpOperand *source,
1056                                   ArrayRef<LoopLikeOpInterface> loops) {
1057   std::optional<OpOperand *> destinationIterArg;
1058   auto loopIt = loops.rbegin();
1059   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1060     auto loop = *loopIt;
1061     if (iterArg.getOwner()->getParentOp() != loop)
1062       break;
1063     source = loop.getTiedLoopInit(iterArg);
1064     loopIt++;
1065   }
1066   if (loopIt == loops.rend())
1067     destinationIterArg = source;
1068   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1069 }
1070 
1071 /// Implementation of fusing producer of a single slice by computing the
1072 /// slice of the producer in-place.
1073 std::optional<scf::SCFFuseProducerOfSliceResult>
1074 mlir::scf::tileAndFuseProducerOfSlice(
1075     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1076     MutableArrayRef<LoopLikeOpInterface> loops) {
1077   // 1. Get the producer of the source (potentially walking through
1078   // `iter_args` of nested `scf.for`)
1079   auto [fusableProducer, destinationInitArg] =
1080       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1081                                         loops);
1082   if (!fusableProducer)
1083     return std::nullopt;
1084   unsigned resultNumber = fusableProducer.getResultNumber();
1085 
1086   OpBuilder::InsertionGuard g(rewriter);
1087   rewriter.setInsertionPoint(candidateSliceOp);
1088 
1089   // 2. Clone the fused producer
1090   // 2a. Compute the destination operands to use for the cloned operation.
1091   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1092   Operation *fusableProducerOp = fusableProducer.getOwner();
1093   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1094       failed(tensor::getOrCreateDestinations(
1095           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1096           origDestinationTensors)))
1097     return std::nullopt;
1098 
1099   clonedOpDestinationTensors = origDestinationTensors;
1100   if (destinationInitArg &&
1101       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1102     // 2b. If the producer is also destination style, then to maintain the
1103     // destination passing style, update the destination of the producer to be
1104     // the source of the slice.
1105     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1106   }
1107   // 2c. Clone the fused producer.
1108   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1109       rewriter, fusableProducerOp, clonedOpDestinationTensors);
1110   // 2d. Update the source of the candidateSlice to be the cloned producer.
1111   //     It is easier to just clone the slice with a different source, since
1112   //     replacement and DCE of the cloned ops become easier.
1113   SmallVector<Value> candidateSliceOpOperands =
1114       llvm::to_vector(candidateSliceOp->getOperands());
1115   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1116   tensor::ExtractSliceOp clonedCandidateSliceOp =
1117       mlir::clone(rewriter, candidateSliceOp,
1118                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1119 
1120   // 3. Generate the tiled implementation of the producer of the source
1121   FailureOr<TilingResult> tileAndFuseResult =
1122       tensor::replaceExtractSliceWithTiledProducer(
1123           rewriter, clonedCandidateSliceOp,
1124           clonedProducerOp->getResult(resultNumber));
1125   if (failed(tileAndFuseResult))
1126     return std::nullopt;
1127   // Note: Do not delete the candidateSliceOp, since it is passed in from the
1128   // caller.
1129   rewriter.replaceAllUsesWith(candidateSliceOp,
1130                               tileAndFuseResult->tiledValues[0]);
1131   rewriter.eraseOp(clonedCandidateSliceOp);
1132   rewriter.eraseOp(clonedProducerOp);
1133 
1134   // 4. If the slice is for a destination operand, for example,
1135   //
1136   // ```mlir
1137   // %0 = linalg.init
1138   // %1 = linalg.fill .. outs(%0 : )
1139   // %2 = scf.for .. iter_args(%arg0 = %1) {
1140   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1141   //     %4 = tensor.extract_slice %arg1 [..]
1142   //     .. = linalg.matmul .. outs(%4 : )
1143   //   }
1144   // }
1145   // ```
1146   //
1147   // the IR is currently
1148   //
1149   // ```
1150   // %0 = linalg.init
1151   // %1 = linalg.fill
1152   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1153   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1154   //     %4 = tensor.extract_slice %arg1[..]
1155   //     %5 = linalg.fill .. outs(%4 : )
1156   //     .. = linalg.matmul .. outs(%5 : )
1157   //   }
1158   // }
1159   // ```
1160   //
1161   // The untiled `linalg.fill` is still used as the `init_value` since it
1162   // was originally a destination operand of the untiled `linalg.matmul`.
1163   // When fusing an operand that is a destination operand, the iter_arg of
1164   // the outermost loop should be changed to use the destination of the
1165   // fused operation. With this, the IR will be:
1166   //
1167   // ```
1168   // %0 = linalg.init
1169   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1170   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
1171   //     %3 = tensor.extract_slice %arg1[..]
1172   //     %4 = linalg.fill .. outs(%3 : )
1173   //     .. = linalg.matmul .. outs(%4 : )
1174   //   }
1175   // }
1176   // ```
1177   if (destinationInitArg &&
1178       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1179     loops.front()
1180         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1181         .set(origDestinationTensors[resultNumber]);
1182   }
1183   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
1184                                            tileAndFuseResult->tiledValues[0],
1185                                            tileAndFuseResult->tiledOps};
1186 }
1187 
1188 /// Reconstruct the fused producer from within the tiled-and-fused code.
1189 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
1190     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1191     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1192     MutableArrayRef<LoopLikeOpInterface> loops,
1193     ArrayRef<unsigned> yieldResultNumber) {
1194   if (loops.empty())
1195     return success();
1196 
1197   Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1198             *tiledOwner = fusedProducerInfo.tiledOps[0];
1199 
1200   Location loc = originalOwner->getLoc();
1201   // a. Collect all init values to be appended.
1202   SmallVector<unsigned> initNumberList =
1203       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1204                                       0, originalOwner->getNumResults()))
1205                                 : llvm::to_vector(yieldResultNumber);
1206   SmallVector<Value> initValueList;
1207   for (const auto &resultNumber : initNumberList) {
1208     FailureOr<Value> initValue = tensor::getOrCreateDestination(
1209         rewriter, loc, originalOwner->getResult(resultNumber));
1210     if (succeeded(initValue)) {
1211       initValueList.push_back(initValue.value());
1212     } else {
1213       return failure();
1214     }
1215   }
1216 
1217   YieldTiledValuesFn newYieldValuesFn =
1218       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1219           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1220           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1221           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1222     OpBuilder::InsertionGuard g(innerRewriter);
1223 
1224     // Get the tile information from the sliceOp.
1225     SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1226                               sliceSizes = sliceOp.getMixedSizes();
1227 
1228     // Expect all strides of the sliceOp to be 1.
1229     if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1230           return !isConstantIntValue(ofr, 1);
1231         }))
1232       return failure();
1233 
1234     unsigned sliceResultNumber =
1235         fusedProducerInfo.origProducer.getResultNumber();
1236 
1237     auto tilableOp = cast<TilingInterface>(originalOwner);
1238     // b. Get the iteration domain offsets and sizes based on the sliceOp tile.
1239     SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1240     // Skip tensor.pack/unpack/pad, which expect a single opResult.
1241     if (tilableOp->getNumResults() > 1 &&
1242         failed(tilableOp.getIterationDomainTileFromResultTile(
1243             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1244             iterDomainOffset, iterDomainSizes))) {
1245       // In theory, it is unnecessary to raise an error here. Even though the
1246       // result tensor cannot be reconstructed, that should not break the
1247       // current fusion. The reason we must return failure for now is that the
1248       // callback function `newYieldValuesFn` is called after the new init
1249       // operand(s) have already been appended. More refactoring is needed to
1250       // make sure the init operands are added consistently in the future. For
1251       // more details, please refer to:
1252       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1253       return failure();
1254     }
1255 
1256     // c. Calculate the offsets and sizes of all OpResults based on the
1257     // iteration domain tile.
1258     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1259     for (const auto &resultNumber : initNumberList) {
1260       if (resultNumber == sliceResultNumber) {
1261         offsetList.push_back(sliceOffset);
1262         sizesList.push_back(sliceSizes);
1263       } else {
1264         assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1265         // Infer the result tile from the iteration domain tile.
1266         SmallVector<OpFoldResult> offset, sizes;
1267         if (failed(tilableOp.getResultTilePosition(
1268                 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1269                 offset, sizes))) {
1270           return failure();
1271         }
1272         offsetList.push_back(offset);
1273         sizesList.push_back(sizes);
1274       }
1275     }
1276 
1277     // d. Create `extract_slice` for the `iter_args` of the DPS operation if necessary.
1278     if (auto tiledDestStyleOp =
1279             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1280       rewriter.setInsertionPoint(tiledDestStyleOp);
1281       for (const auto &&[index, newRegionArg] :
1282            llvm::enumerate(newRegionIterArgs)) {
1283         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1284             loc, newRegionArg, offsetList[index], sizesList[index],
1285             SmallVector<OpFoldResult>(offsetList[index].size(),
1286                                       rewriter.getIndexAttr(1)));
1287         unsigned resultNumber = initNumberList[index];
1288         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1289           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1290         });
1291       }
1292     }
1293 
1294     // e. Prepare the tiled offsets and sizes for the `insert_slice` ops created
1295     // later by the caller.
1296     Block *block = rewriter.getInsertionPoint()->getBlock();
1297     rewriter.setInsertionPoint(block->getTerminator());
1298     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1299       tiledResult.push_back(tiledOwner->getResult(resultNumber));
1300       tiledOffset.emplace_back(offsetList[index]);
1301       tiledSizes.emplace_back(sizesList[index]);
1302     }
1303     return success();
1304   };
1305 
1306   return addInitOperandsToLoopNest(rewriter, loops, initValueList,
1307                                    newYieldValuesFn);
1308 }
1309 
1310 /// Implementation of tiling the consumer and greedily fusing its producers.
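/// As a minimal illustrative sketch (the ops, shapes, and tile sizes are
/// hypothetical), tiling a `linalg.matmul` consumer and greedily fusing its
/// `linalg.fill` producer through the `tensor.extract_slice` of the
/// destination yields IR of roughly the following shape:
///
/// ```mlir
/// // Untiled program.
/// %f = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %m = linalg.matmul ins(%a, %b : ...) outs(%f : tensor<?x?xf32>) -> tensor<?x?xf32>
///
/// // After tiling the matmul and fusing the fill into the loop.
/// %r = scf.for ... iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
///   %s  = tensor.extract_slice %arg0 [...]
///   %tf = linalg.fill ins(%cst : f32) outs(%s : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %tm = linalg.matmul ins(...) outs(%tf : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %i  = tensor.insert_slice %tm into %arg0 [...]
///   scf.yield %i : tensor<?x?xf32>
/// }
/// ```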
1311 FailureOr<scf::SCFTileAndFuseResult>
1312 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1313     RewriterBase &rewriter, TilingInterface consumer,
1314     const scf::SCFTileAndFuseOptions &options) {
1315   // This transformation is only valid for ops that return values (i.e. not
1316   // valid to use with operations that have memref operands).
1317   if (!consumer->getNumResults()) {
1318     return rewriter.notifyMatchFailure(
1319         consumer, "invalid pattern for op with no results");
1320   }
1321 
1322   // 1. First tile the consumer.
1323   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1324   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1325 
1326   FailureOr<scf::SCFTilingResult> tilingResult =
1327       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1328 
1329   if (failed(tilingResult))
1330     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1331   for (auto *tiledOp : tilingResult->tiledOps)
1332     tiledAndFusedOps.insert(tiledOp);
1333 
1334   // If there are no loops generated, fusion is immaterial.
1335   auto &loops = tilingResult->loops;
1336   if (loops.empty()) {
1337     DenseMap<Value, Value> replacements;
1338     for (auto [origVal, replacement] :
1339          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1340       replacements[origVal] = replacement;
1341     }
1342     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1343                                      replacements};
1344   }
1345 
1346   // To keep track of replacements, for now just record the map from the
1347   // original untiled value to the result number of the for loop. Since the loop
1348   // potentially gets replaced during fusion, keeping the value directly won't work.
1349   DenseMap<Value, size_t> origValToResultNumber;
1350   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1351     origValToResultNumber[result] = index;
1352   }
1353 
1354   // 2. Typically, the operands of the tiled operation are slices of the
1355   //    operands of the untiled operation. These are expressed in IR using
1356   //    `tensor.extract_slice` operations with source being the operands of the
1357   //    untiled operation. Create a worklist of these `tensor.extract_slice`
1358   //    operations. If the producers of the source of the `tensor.extract_slice`
1359   //    can be tiled such that the tiled value is generated in-place, that
1360   //    effectively tiles + fuses the operations.
1361   auto addCandidateSlices = [](Operation *fusedOp,
1362                                std::deque<tensor::ExtractSliceOp> &candidates) {
1363     for (Value operand : fusedOp->getOperands())
1364       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1365         candidates.push_back(sliceOp);
1366   };
1367 
1368   std::deque<tensor::ExtractSliceOp> candidates;
1369   addCandidateSlices(tiledAndFusedOps.back(), candidates);
1370   OpBuilder::InsertionGuard g(rewriter);
1371   while (!candidates.empty()) {
1372     // Traverse the slices in BFS fashion.
1373     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1374     candidates.pop_front();
1375 
1376     // Find the original producer of the slice.
1377     auto [fusableProducer, destinationInitArg] =
1378         getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1379                                           loops);
1380     if (!fusableProducer)
1381       continue;
1382 
1383     auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
1384         candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1385     if (!fuseSlice)
1386       continue;
1387 
1388     // The operands of the fused producer might themselved be slices of
1389     // values produced by operations that implement the `TilingInterface`.
1390     // Add these operations to the worklist.
1391     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1392         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
1393     if (!fusedResult)
1394       continue;
1395 
1396     if (yieldReplacement) {
1397       // Reconstruct and yield all opResults of fusableProducerOp by default. The
1398       // caller can specify which ones to yield via the optional argument
1399       // `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1400       Operation *fusableProducerOp = fusableProducer.getOwner();
1401       if (failed(yieldReplacementForFusedProducer(
1402               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1403         return rewriter.notifyMatchFailure(
1404             fusableProducerOp, "failed to yield a replacement value for this "
1405                                "operation from within the tiled loop");
1406       }
1407       for (auto [index, result] :
1408            llvm::enumerate(fusableProducerOp->getResults())) {
1409         origValToResultNumber[result] = loops.front()->getNumResults() -
1410                                         fusableProducerOp->getNumResults() +
1411                                         index;
1412       }
1413     }
1414 
1415     if (Operation *tiledAndFusedOp =
1416             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1417       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1418       tiledAndFusedOps.insert(tiledAndFusedOp);
1419       addCandidateSlices(tiledAndFusedOp, candidates);
1420     }
1421   }
1422 
1423   DenseMap<Value, Value> replacements;
1424   for (auto [origVal, resultNumber] : origValToResultNumber) {
1425     replacements[origVal] = loops.front()->getResult(resultNumber);
1426   }
1427 
1428   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1429                                    replacements};
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // tileAndFuseConsumerUsingSCF implementation.
1434 //===----------------------------------------------------------------------===//
1435 
1436 /// A utility function that checks whether the only use of the result of a
1437 /// tensor.insert_slice op is in a scf.yield op.
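/// A minimal sketch of the pattern this check accepts (offsets, sizes, and the
/// tiled computation are elided/hypothetical):
///
/// ```mlir
/// %for = scf.for ... iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
///   %tile = ...
///   %ins  = tensor.insert_slice %tile into %arg0 [...]
///   // The only use of %ins is the scf.yield in the same block.
///   scf.yield %ins : tensor<?x?xf32>
/// }
/// ```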
1438 static LogicalResult
1439 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1440   Value result = candidateSliceOp.getResult();
1441   Value::use_range uses = result.getUses();
1442   if (!llvm::hasSingleElement(uses)) {
1443     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1444     return failure();
1445   }
1446   OpOperand &operandUse = (*uses.begin());
1447   Operation *userOp = operandUse.getOwner();
1448   if (!isa<scf::YieldOp>(userOp)) {
1449     LLVM_DEBUG(llvm::dbgs()
1450                << "Expected scf.yield to be the only user, but got -> "
1451                << (*userOp));
1452     return failure();
1453   }
1454   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1455     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1456                                "be in the same block\n");
1457     return failure();
1458   }
1459   return success();
1460 }
1461 
1462 /// Fetches the OpOperand of the only user (and use) of the value `val` which
1463 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1464 /// failure otherwise.
1465 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1466                                                   Block *containingOpBlock) {
1467   // Step 1. Check that the value has exactly one use.
1468   if (!llvm::hasSingleElement(val.getUses()))
1469     return failure();
1470   // Step 2. Get uses.
1471   OpOperand &operand = (*val.getUses().begin());
1472   Operation *consumerOp = operand.getOwner();
1473   // TODO: We have to initialize the result of the consumer before the scf.for;
1474   //       for now, use DestinationStyleOpInterface to get the result shape from
1475   //       the init. Add support for other ops, e.g. ops with InferTypeOpInterface.
1476   if (!isa<TilingInterface>(consumerOp) ||
1477       !isa<DestinationStyleOpInterface>(consumerOp))
1478     return failure();
1479   if (containingOpBlock != consumerOp->getBlock())
1480     return failure();
1481   return &operand;
1482 }
1483 
1484 /// Find the perfectly nested loops outside of the given loop (inclusive),
1485 /// sorted from outer to inner.
1486 ///
1487 /// E.g.
1488 ///
1489 /// ```
1490 ///  %0 = scf.for()
1491 ///    %1 = scf.for()
1492 ///      %2 = scf.for()
1493 ///         %3 = ...
1494 ///         yield %3
1495 ///      yield %2
1496 ///    yield %1
1497 /// ```
1498 ///
1499 /// This function will return the three perfectly nested loops %0, %1 and %2
1500 /// when the target inner loop is %2.
1501 static SmallVector<scf::ForOp>
1502 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1503   SmallVector<scf::ForOp> nestLoops = {loop};
1504   auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1505 
1506   // Check whether the ForOp yields the result of its inner loop.
1507   auto isForOpYieldResultOfInnerLoop =
1508       [](scf::ForOp outerLoop) -> LogicalResult {
1509     Block *body = outerLoop.getBody();
1510     if (!llvm::hasSingleElement(body->without_terminator()))
1511       return failure();
1512     auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1513     auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1514     if (!innerForOp)
1515       return failure();
1516     // All of innerForOp's results should be yielded.
1517     return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1518   };
1519 
1520   while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1521     nestLoops.push_back(outerLoop);
1522     outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1523   }
1524   // sorted from outer to inner
1525   return {nestLoops.rbegin(), nestLoops.rend()};
1526 }
1527 
1528 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1529 /// tensor.insert_slice. This function makes the following assumptions:
1530 /// 1.  tensor.insert_slice has scf.yield as its only user.
1531 /// 2.  scf.for's corresponding result has only one use.
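/// A minimal sketch of the structure this function expects (details elided;
/// the consumer op shown is hypothetical):
///
/// ```mlir
/// %0 = scf.for ... iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
///   %tile = ...
///   %ins  = tensor.insert_slice %tile into %arg0 [...]
///   scf.yield %ins : tensor<?x?xf32>
/// }
/// // %0 has a single use: the untiled consumer to be fused.
/// %1 = linalg.exp ins(%0 : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```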
1532 static FailureOr<OpOperand *>
1533 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1534   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1535     return failure();
1536   Value sliceResult = candidateSliceOp.getResult();
1537   // Step 1. Fetch the corresponding output.
1538   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1539   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1540   // Step 2. Check containing op is scf.for.
1541   Operation *containingOp = candidateSliceOp->getParentOp();
1542   auto forOp = dyn_cast<scf::ForOp>(containingOp);
1543   if (!forOp)
1544     return failure();
1545   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1546   Value resultingValue = topLevelForOp->getResult(resultNumber);
1547 
1548   return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
1549 }
1550 
1551 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1552 /// by a tensor.parallel_insert_slice.
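/// A minimal sketch of the structure this function expects (details elided;
/// the consumer op shown is hypothetical):
///
/// ```mlir
/// %0 = scf.forall ... shared_outs(%arg0 = %init) -> (tensor<?x?xf32>) {
///   %tile = ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %arg0 [...]
///   }
/// }
/// // The result tied to %arg0 is consumed by the untiled consumer to be fused.
/// %1 = linalg.exp ins(%0 : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```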
1553 static FailureOr<OpOperand *>
1554 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1555   // Step 1. Fetch the corresponding output
1556   Value sliceDest = candidateSliceOp.getDest();
1557   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1558   if (!iterArg)
1559     return failure();
1560   Operation *containingOp = iterArg.getOwner()->getParentOp();
1561   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1562     return failure();
1563   // Step 2. Check that the containing op is scf.forall.
1564   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1565   if (!forallOp)
1566     return failure();
1567   Value resultingValue =
1568       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1569 
1570   return getConsumerFromUses(resultingValue, containingOp->getBlock());
1571 }
1572 
1573 /// This utility currently checks whether the loop either:
1574 /// 1. yields exactly one result, or
1575 /// 2. has the consumer op as its first user, with all other users located in
1576 /// the same block as the consumer op.
1577 /// The loop op is currently moved right before the consumer op in order to
1578 /// maintain a valid def-use chain. This utility thus helps ensure that no
1579 /// invalid IR is formed as a result.
1580 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
1581                                             Operation *consumerOp) {
1582   // Check if the loop op yields one result.
1583   if (loopOp->getNumResults() == 1)
1584     return success();
1585   // Check if the consumerOp is the first user of the loopOp and if the other
1586   // users are in the same block as the consumerOp.
1587   Block *parentBlock = consumerOp->getBlock();
1588   for (Operation *userOp : loopOp->getUsers()) {
1589     if (userOp == consumerOp)
1590       continue;
1591     if (parentBlock != userOp->getBlock() ||
1592         !consumerOp->isBeforeInBlock(userOp))
1593       return failure();
1594   }
1595   return success();
1596 }
1597 
1598 /// A utility to fetch an untiled consumer of
1599 /// tensor.insert_slice/tensor.parallel_insert_slice.
1600 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
1601   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1602     return getUntiledConsumerFromSlice(insertSlice);
1603   } else if (auto parallelInsertSlice =
1604                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1605     return getUntiledConsumerFromSlice(parallelInsertSlice);
1606   } else {
1607     return failure();
1608   }
1609 }
1610 
1611 /// Implementation of fusing the consumer of a single slice by computing the
1612 /// slice of the consumer in place within the scf loop.
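/// A minimal before/after sketch (ops, shapes, and offsets are hypothetical):
///
/// ```mlir
/// // Before: the consumer operates on the full loop result.
/// %0 = scf.for ... iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
///   %tile = ...
///   %ins  = tensor.insert_slice %tile into %arg0 [...]
///   scf.yield %ins : tensor<?x?xf32>
/// }
/// %1 = linalg.exp ins(%0 : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
///
/// // After: the consumer is tiled and fused into the loop, with an extra
/// // iter_arg carrying its destination.
/// %0:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %out)
///     -> (tensor<?x?xf32>, tensor<?x?xf32>) {
///   %tile  = ...
///   %slice = tensor.extract_slice %arg1 [...]
///   %texp  = linalg.exp ins(%tile : tensor<?x?xf32>) outs(%slice : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %i0 = tensor.insert_slice %tile into %arg0 [...]
///   %i1 = tensor.insert_slice %texp into %arg1 [...]
///   scf.yield %i0, %i1 : tensor<?x?xf32>, tensor<?x?xf32>
/// }
/// // Uses of %1 are then replaced with the trailing result of the new loop.
/// ```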
1613 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1614 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1615                                       Operation *candidateSliceOp) {
1616   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1617           candidateSliceOp))
1618     return failure();
1619 
1620   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1621 
1622   // 1. Get the consumer of scf.for for the result yielded by
1623   // tensor.insert_slice/parallel_insert_slice.
1624   FailureOr<OpOperand *> maybeConsumerOpOperand =
1625       getUntiledConsumerFromSlice(candidateSliceOp);
1626   if (failed(maybeConsumerOpOperand)) {
1627     return rewriter.notifyMatchFailure(candidateSliceOp,
1628                                        "could not fetch consumer to fuse");
1629   }
1630   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1631   Operation *consumerOp = consumerOpOperand->getOwner();
1632   unsigned operandNumber = consumerOpOperand->getOperandNumber();
1633   unsigned resultNumber = 0;
1634   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1635     resultNumber = producerResult.getResultNumber();
1636   } else {
1637     return rewriter.notifyMatchFailure(
1638         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1639   }
1640 
1641   // There are two possible cases for the loop containing the candidate slice:
1642   // 1. a single `scf.forall` or `scf.for`, or
1643   // 2. the inner-most `scf.for` of a nested `scf.for` structure, where the
1644   // top-level loop is the outer-most one of these nested loops.
1645   LoopLikeOpInterface innerMostLoop =
1646       candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1647   SmallVector<LoopLikeOpInterface> nestedLoops;
1648   if (isInsertSliceOp) {
1649     nestedLoops = llvm::map_to_vector(
1650         getPerfectlyNestedLoopsOutsideOf(
1651             cast<scf::ForOp>(innerMostLoop.getOperation())),
1652         [](scf::ForOp forOp) {
1653           return cast<LoopLikeOpInterface>(forOp.getOperation());
1654         });
1655   } else {
1656     nestedLoops = {innerMostLoop};
1657   }
1658 
1659   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1660 
1661   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
1662     return rewriter.notifyMatchFailure(
1663         outerMostLoop,
1664         "containing loop op should either yield just one value or "
1665         "have the consumer op as its first user");
1666   }
1667 
1668   OpBuilder::InsertionGuard g(rewriter);
1669 
1670   // 2. Check that the consumer is not using the scf loop's output as init.
1671   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1672   if (!dstOp)
1673     return rewriter.notifyMatchFailure(consumerOp,
1674                                        "consumer op is not DPS operation");
1675   SmallVector<Value> dpsInits =
1676       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1677   if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1678     return rewriter.notifyMatchFailure(
1679         consumerOp,
1680         "consumer op taking the result of scf.for as init is not supported");
1681   }
1682   SmallVector<Value> newInits = dpsInits;
1683 
1684   Location loc = outerMostLoop->getLoc();
1685 
1686   // 3. Move the whole loop structure right before the consumer op; dominance
1687   // is already ensured by `checkAssumptionForLoop`.
1688   rewriter.moveOpBefore(outerMostLoop, consumerOp);
1689 
1690   // 4. Set insertion point before terminator op of the loop and create a new
1691   // tensor.insert_slice. In the scf.for case this is a clone of the
1692   // candidateSliceOp whereas in the scf.forall case this is created from the
1693   // operands of tensor.parallel_insert_slice.
1694   tensor::InsertSliceOp clonedInsertSliceOp;
1695   if (auto sliceOp =
1696           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1697     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1698     rewriter.setInsertionPoint(newForallOp.getTerminator());
1699     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1700         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1701         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1702   } else {
1703     rewriter.setInsertionPoint(candidateSliceOp);
1704     clonedInsertSliceOp =
1705         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1706   }
1707 
1708   // 5.a. Clone consumer op.
1709   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
1710 
1711   // 5.b. Replace the use of the loop result in the cloned consumer with the
1712   // result of the cloned tensor.insert_slice.
1713   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1714   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1715     operandToReplace.set(clonedInsertSliceOp.getResult());
1716   });
1717 
1718   // 6. Perform tiling of the cloned consumer and replace the operand at
1719   // `operandNumber` with the source of the cloned tensor.insert_slice op.
1720   auto ossSliceOp =
1721       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1722   FailureOr<TilingResult> tileAndFuseResult =
1723       tensor::replaceInsertSliceWithTiledConsumer(
1724           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1725   if (failed(tileAndFuseResult)) {
1726     return failure();
1727   }
1728   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1729   rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
1730                               clonedInsertSliceOp.getSource());
1731 
1732   // 7. Reconstruct [nested] loop with new inits.
1733   YieldTiledValuesFn newYieldValuesFn =
1734       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1735           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1736           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1737           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1738     OpBuilder::InsertionGuard g(innerRewriter);
1739     // 8. Set the inner insertion point right before the tiled consumer op.
1740     innerRewriter.setInsertionPoint(tiledConsumerOp);
1741 
1742     SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1743     SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1744     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1745 
1746     // 9. Check that all insert strides are 1.
1747     if (llvm::any_of(strides, [](OpFoldResult stride) {
1748           return !isConstantIntValue(stride, 1);
1749         })) {
1750       return rewriter.notifyMatchFailure(
1751           candidateSliceOp, "containing op's result is yielded with a non-unit stride");
1752     }
1753 
1754     // 10. Try to get iter domain position from input position.
1755     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1756     if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
1757             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1758             iterDomainSizes))) {
1759       return rewriter.notifyMatchFailure(
1760           tiledConsumerOp,
1761           "can't get iter domain position from input position");
1762     }
1763 
1764     // 11. Try to fetch the offset and size for all results of the cloned
1765     // consumer. This would then be used to form the corresponding
1766     // tensor.insert_slice/parallel_insert_slice later.
1767     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
1768     SmallVector<SmallVector<OpFoldResult>> resultOffsets(
1769         totalNumResultsOfConsumer);
1770     SmallVector<SmallVector<OpFoldResult>> resultSizes(
1771         totalNumResultsOfConsumer);
1772     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
1773       if (failed(tiledConsumerOp.getResultTilePosition(
1774               rewriter, idx, iterDomainOffsets, iterDomainSizes,
1775               resultOffsets[idx], resultSizes[idx]))) {
1776         return rewriter.notifyMatchFailure(
1777             tiledConsumerOp,
1778             "can't get result domain position from iter domain position");
1779       }
1780     }
1781 
1782     // 12. Create `extract_slice` for `iter_args` for DPS operation if
1783     // necessary.
1784     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
1785             tiledConsumerOp.getOperation())) {
1786       rewriter.setInsertionPoint(tiledDestStyleOp);
1787       for (const auto &&[index, newRegionArg] :
1788            llvm::enumerate(newRegionIterArgs)) {
1789         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1790             loc, newRegionArg, resultOffsets[index], resultSizes[index],
1791             SmallVector<OpFoldResult>(resultOffsets[index].size(),
1792                                       rewriter.getIndexAttr(1)));
1793         // Make a copy of index to avoid a capturing structured binding, which
1794         // is a C++20 extension.
1795         auto dstNumber = index;
1796         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1797           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
1798         });
1799       }
1800     }
1801 
1802     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
1803     // caller.
1804     Block *block = rewriter.getInsertionPoint()->getBlock();
1805     rewriter.setInsertionPoint(block->getTerminator());
1806     for (const auto &&[index, result] :
1807          llvm::enumerate(tiledConsumerOp->getResults())) {
1808       tiledResult.push_back(result);
1809       tiledOffset.emplace_back(resultOffsets[index]);
1810       tiledSizes.emplace_back(resultSizes[index]);
1811     }
1812     return success();
1813   };
1814   // 14. Add new inits to [nested] loops.
1815   if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
1816                                        newYieldValuesFn))) {
1817     return rewriter.notifyMatchFailure(tiledConsumerOp,
1818                                        "unable to add new inits to nest loop");
1819   }
1820 
1821   // 15. Replace the results of the scf loop and the consumer op with the new loop's results.
1822 
1823   for (auto &&[oldResult, newResult] : llvm::zip(
1824            consumerOp->getResults(),
1825            nestedLoops.front()->getResults().take_back(newInits.size()))) {
1826     rewriter.replaceAllUsesWith(oldResult, newResult);
1827   }
1828 
1829   // 16. Need to erase the old scf loop and the cloned consumer op.
1830   rewriter.eraseOp(clonedConsumerOp);
1831 
1832   return scf::SCFFuseConsumerOfSliceResult{
1833       consumerOpOperand,
1834       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1835       tileAndFuseResult->tiledOps};
1836 }
1837 
1838 //===----------------------------------------------------------------------===//
1839 // lowerToLoopsUsingSCFForOp implementation.
1840 //===----------------------------------------------------------------------===//
1841 
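/// As a minimal illustrative sketch (the op, shapes, and body are
/// hypothetical), a memref-based op implementing `TilingInterface` is lowered
/// to an `scf.for` nest over its iteration domain, with the op's scalar
/// implementation generated in the innermost body:
///
/// ```mlir
/// // Before.
/// linalg.generic { ... } ins(%A : memref<?x?xf32>) outs(%B : memref<?x?xf32>) { ... }
///
/// // After.
/// scf.for %i = %c0 to %d0 step %c1 {
///   scf.for %j = %c0 to %d1 step %c1 {
///     %v = memref.load %A[%i, %j] : memref<?x?xf32>
///     ...
///     memref.store %r, %B[%i, %j] : memref<?x?xf32>
///   }
/// }
/// ```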
1842 FailureOr<SmallVector<scf::ForOp>>
1843 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
1844                                      TilingInterface op) {
1845   // TODO: Handle cases where the op has results if needed.
1846   if (op->getNumResults() > 0) {
1847     return rewriter.notifyMatchFailure(
1848         op, "unable to lower to loops operations with return values");
1849   }
1850 
1851   SmallVector<Range> domain = op.getIterationDomain(rewriter);
1852   SmallVector<Value> ivs;
1853   SmallVector<scf::ForOp> loops;
1854   Location loc = op.getLoc();
1855   for (auto loopRange : domain) {
1856     Value offsetVal =
1857         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1858     Value sizeVal =
1859         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1860     Value strideVal =
1861         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1862     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1863                                             strideVal, ValueRange{});
1864     loops.push_back(loop);
1865     ivs.push_back(loop.getInductionVar());
1866     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1867   }
1868   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1869     return failure();
1870   }
1871   return loops;
1872 }
1873