xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision d5f0969c96224a44062715751da3c369ce5ea3f8)
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements tiling using the `TilingInterface`.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/TilingInterface.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include <optional>
30 
31 #define DEBUG_TYPE "tile-using-interface"
32 
33 using namespace mlir;
34 
35 scf::SCFTilingOptions &
36 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
37   assert(!tileSizeComputationFunction && "tile sizes already set");
38   auto tileSizes = llvm::to_vector(ts);
39   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
40     return tileSizes;
41   };
42   return *this;
43 }
44 
45 scf::SCFTilingOptions &
46 scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
47   assert(!numThreadsComputationFunction && "num tiles already set");
48   auto numThreads = llvm::to_vector(nt);
49   numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
50     return numThreads;
51   };
52   return *this;
53 }
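
// Illustrative usage sketch (not part of this file; the names of the driven
// op and the rewriter are assumptions): a caller typically populates the
// options via the setters above and then invokes `tileUsingSCF`:
//
//   scf::SCFTilingOptions tilingOptions;
//   tilingOptions.setTileSizes(
//       {rewriter.getIndexAttr(32), rewriter.getIndexAttr(64)});
//   tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
//   FailureOr<scf::SCFTilingResult> tilingResult =
//       scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
//   if (failed(tilingResult))
//     return failure();
//   rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);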
54 
55 /// Helper method to adjust the interchange vector to match the iteration
56 /// domain.
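///
/// For example (illustrative):
///   fillInterchangeVector({1, 0}, 4)    -> [1, 0, 2, 3]
///   fillInterchangeVector({1, 0, 2}, 2) -> [1, 0]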
57 static SmallVector<int64_t>
58 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
59                       size_t iterationDomainSize) {
60   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
61   if (filledVector.size() < iterationDomainSize) {
62     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
63     filledVector.append(range.begin(), range.end());
64   }
65   if (filledVector.size() > iterationDomainSize)
66     filledVector.resize(iterationDomainSize);
67   return filledVector;
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // tileUsingSCF implementation.
72 //===----------------------------------------------------------------------===//
73 
74 /// Verify the tile size options are set in a consistent manner.
75 static LogicalResult
76 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
77                       const scf::SCFTilingOptions &options) {
78   // Specifying number of threads is only supported on `scf.forall` op.
79   if (options.numThreadsComputationFunction &&
80       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
81     return rewriter.notifyMatchFailure(
82         loc, "number of threads can only be specified when loop type is "
83              "set to use `scf.forall`");
84   }
85 
86   // If specified, check that the interchange vector is a permutation.
87   if (!options.interchangeVector.empty()) {
88     if (!isPermutationVector(options.interchangeVector)) {
89       return rewriter.notifyMatchFailure(
90           loc, "invalid interchange vector, not a permutation of the entire "
91                "iteration space");
92     }
93   }
94   return success();
95 }
96 
97 /// Method to instantiate the tile sizes and/or number of threads specified
98 /// by the user.
99 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
100 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
101                               ArrayRef<Range> iterationDomain,
102                               const scf::SCFTilingOptions &options) {
103   OpFoldResult zero = rewriter.getIndexAttr(0);
104   SmallVector<OpFoldResult> tileSizes, numThreads;
105   size_t numLoops = iterationDomain.size();
106 
107   // Check whether the number of threads to use is specified.
108   if (options.numThreadsComputationFunction) {
109     numThreads = options.numThreadsComputationFunction(rewriter, op);
110     numThreads.resize(numLoops, zero);
111 
112     // If the tile sizes are also specified, use those.
113     if (options.tileSizeComputationFunction) {
114       tileSizes = options.tileSizeComputationFunction(rewriter, op);
115       tileSizes.resize(numLoops, zero);
116       return {tileSizes, numThreads};
117     }
118 
119     // Compute the tile sizes from the iteration domain and the number
120     // of threads as follows:
121     // - niters = ceilDiv(ub - lb, step)
122     // - tileSize = ceilDiv(niters, numThreads)
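    // For example (illustrative): with lb = 0, ub = 100 and numThreads = 7,
    // the tile size is ceilDiv(100 - 0, 7) = 15.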
123     AffineExpr s0, s1, s2;
124     bindSymbols(rewriter.getContext(), s0, s1, s2);
125     // TODO: The step here is assumed to be 1.
126     AffineExpr numItersExpr = (s1 - s0);
127     AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
128     tileSizes.resize(numLoops, zero);
129     for (auto [index, range, nt] :
130          llvm::enumerate(iterationDomain, numThreads)) {
131       if (isConstantIntValue(nt, 0))
132         continue;
133 
134       tileSizes[index] = affine::makeComposedFoldedAffineApply(
135           rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
136     }
137     tileSizes.resize(numLoops, zero);
138     return {tileSizes, numThreads};
139   }
140 
141   // Enforce the convention that "tiling by zero"
142   // skips tiling a particular dimension. This convention is significantly
143   // simpler to handle than adjusting affine maps to account for missing
144   // dimensions.
145   assert(options.tileSizeComputationFunction &&
146          "expected tile sizes to be specified");
147   tileSizes = options.tileSizeComputationFunction(rewriter, op);
148   tileSizes.resize(numLoops, zero);
149 
150   return {tileSizes, numThreads};
151 }
152 
153 /// Checks if any of the tiled loops are not parallel.
154 static void checkSafeToTileToForall(TilingInterface op,
155                                     ArrayRef<OpFoldResult> tileSizes,
156                                     ArrayRef<OpFoldResult> numThreads) {
157   auto iterators = op.getLoopIteratorTypes();
158   assert(iterators.size() == tileSizes.size() &&
159          "expected as many tile size values as number of loops");
160   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
161          "when specified, expected number of threads to use for each loop");
162 
163   for (auto [index, iterator, tileSize] :
164        llvm::enumerate(iterators, tileSizes)) {
165     // If the number of threads is specified, warn if it is greater than one
166     // for any non-parallel dimension.
167     if (!numThreads.empty()) {
168       if (std::optional<int64_t> constNumThreads =
169               getConstantIntValue(numThreads[index])) {
170         if (constNumThreads.value() > 1 &&
171             iterator != utils::IteratorType::parallel) {
172           op.emitWarning() << "tiling is not thread safe at axis #" << index;
173         }
174       }
175       continue;
176     }
177 
178     if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
179       if (constTileSize.value() > 0 &&
180           iterator != utils::IteratorType::parallel) {
181         op.emitWarning() << "tiling is not thread safe at axis #" << index;
182       }
183     }
184   }
185 }
186 
187 /// Check if `stride` evenly divides the trip count `size - offset`.
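///
/// For example (illustrative): offset = 0, size = 128, stride = 32 divides
/// evenly ((128 - 0) % 32 == 0), while offset = 0, size = 100, stride = 32
/// does not.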
188 static bool tileDividesIterationDomain(Range loopRange) {
189   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
190   if (!offsetAsInt)
191     return false;
192   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
193   if (!sizeAsInt)
194     return false;
195   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
196   if (!strideAsInt)
197     return false;
198   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
199 }
200 
201 /// Returns the bounded tile size given the current `offset`, `loopRange` and
202 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
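///
/// For example (illustrative): with `loopRange = [0, 100)`, a constant
/// `tileSize` of 32, and a dynamic `offset` %iv, this builds roughly
/// `affine.min affine_map<(d0) -> (-d0 + 100, 32)>(%iv)`.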
203 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
204                                        Range loopRange, OpFoldResult offset,
205                                        OpFoldResult tileSize) {
206   std::optional<int64_t> ts = getConstantIntValue(tileSize);
207   if (ts && ts.value() == 1)
208     return tileSize;
209 
210   if (tileDividesIterationDomain(
211           Range{loopRange.offset, loopRange.size, tileSize}))
212     return tileSize;
213 
214   // The tile size to use (to avoid out of bounds access) is the minimum of
215   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
216   // loop.
217   AffineExpr s0, s1, d0;
218   bindDims(b.getContext(), d0);
219   bindSymbols(b.getContext(), s0, s1);
220   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
221   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
222   return affine::makeComposedFoldedAffineMin(
223       b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
224 }
225 
226 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
227 /// less than `iterationSize`.
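///
/// For example (illustrative): tileSize = 8, numThreads = 12 and
/// iterationSize = 100 gives 8 * 11 = 88 < 100, so the bounds check can be
/// omitted; with numThreads = 14, 8 * 13 = 104 >= 100 and it cannot.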
228 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
229                                            OpFoldResult numThreads,
230                                            OpFoldResult iterationSize) {
231   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
232   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
233   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
234   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
235     return false;
236   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
237 }
238 
239 /// Compute the `OpFoldResult`s that represent the multi-dimensional
240 /// `offset`s and `size`s of the tile of the iteration space that the
241 /// innermost loop body of the generated tiled loops corresponds to.
242 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
243 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
244                       ArrayRef<Range> iterationDomain,
245                       ArrayRef<OpFoldResult> tileSizes,
246                       ArrayRef<OpFoldResult> numThreads) {
247   SmallVector<OpFoldResult> offsets, sizes;
248   int materializedLoopNum = 0;
249 
250   if (!numThreads.empty()) {
251     AffineExpr d0, d1, s0, s1;
252     AffineExpr offsetExpr, residualTileSizeExpr;
253     bindDims(rewriter.getContext(), d0, d1);
254     bindSymbols(rewriter.getContext(), s0, s1);
255     offsetExpr = d0 + d1 * s0;
256     residualTileSizeExpr = s1 - (d0 + d1 * s0);
257 
258     for (auto [nt, tileSize, loopRange] :
259          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
260 
261       // Non-tiled case: set the offset and size to
262       // `loopRange.offset`/`loopRange.size`.
263       if (isConstantIntValue(nt, 0)) {
264         offsets.push_back(loopRange.offset);
265         sizes.push_back(loopRange.size);
266         continue;
267       }
268 
269       Value iv = ivs[materializedLoopNum++];
270       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
271           rewriter, loc, offsetExpr,
272           ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
273       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
274           rewriter, loc, residualTileSizeExpr,
275           {loopRange.offset, nt, tileSize, loopRange.size});
276 
277       OpFoldResult size = tileSize;
278       if (!isConstantIntValue(residualTileSize, 0)) {
279         OpFoldResult sizeMinusOffsetPerThread =
280             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
281                                                   {offset, loopRange.size});
282         size = affine::makeComposedFoldedAffineMin(
283             rewriter, loc,
284             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
285             {sizeMinusOffsetPerThread, tileSize});
286       }
287 
288       // Consider the case where the original loop was `[0, 100)`.
289       // If the number of threads is `14`, the tile size would be computed as
290       // `ceilDiv(100, 14) = 8`. For the last thread (thread_id = 13)
291       // - `offset = 0 + 13 * 8 = 104`
292       // - `tileSize = min(8, 100 - 104) = -4`
293       // To avoid negative tile sizes, we need to do a further
294       // `nonNegativeTileSize = affine.max(0, tileSize)`.
295       // This `max` can be avoided if
296       //  `tileSize * (numThreads - 1) < (ub - lb)`
297       if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
298         AffineMap maxMap =
299             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
300         size = affine::makeComposedFoldedAffineMax(
301             rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
302       }
303 
304       offsets.push_back(offset);
305       sizes.push_back(size);
306     }
307     return {offsets, sizes};
308   } else {
309     for (auto [tileSize, loopRange] :
310          llvm::zip_equal(tileSizes, iterationDomain)) {
311 
312       // Non-tiled case: set the offset and size to
313       // `loopRange.offset`/`loopRange.size`.
314       if (isConstantIntValue(tileSize, 0)) {
315         offsets.push_back(loopRange.offset);
316         sizes.push_back(loopRange.size);
317         continue;
318       }
319 
320       Value iv = ivs[materializedLoopNum++];
321       OpFoldResult offset = getAsOpFoldResult(iv);
322       offsets.push_back(offset);
323       OpFoldResult size =
324           getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
325       sizes.push_back(size);
326     }
327     return {offsets, sizes};
328   }
329 }
330 
331 /// Function to return the bounds of the loops to be generated.
332 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
333                   SmallVector<OpFoldResult>>
334 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
335               ArrayRef<OpFoldResult> tileSizes) {
336   SmallVector<OpFoldResult> lbs, ubs, steps;
337   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
338     // No loop if the tile size is 0.
339     if (isConstantIntValue(tileSize, 0))
340       continue;
341     lbs.push_back(loopRange.offset);
342     ubs.push_back(loopRange.size);
343     steps.push_back(tileSize);
344   }
345   return {lbs, ubs, steps};
346 }
347 
348 /// A function that allows returning additional yielded values during
349 /// `yieldTiledValuesAndReplaceLoop`.
350 /// - `ivs` induction variables for the generated loops.
351 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
352 /// - `tiledValues` the tiled values to return. Must be of the same size as
353 ///   `newBbArgs`; each element of this array is inserted into the corresponding
354 ///   element in `newBbArgs`.
355 /// - `resultOffsets` is of the same size as `tiledValues` and represents
356 ///   the offsets to use when inserting corresponding element from `tiledValues`
357 ///   into the element from `newBbArgs`.
358 /// - `resultSizes` is of the same size as `tiledValues` and represents
359 ///   the size of the corresponding element from `tiledValues` inserted into
360 ///   the element from `newBbArgs`.
361 /// In case the method needs to return `failure()` the method is expected
362 /// to clean up any inserted operations.
363 using YieldTiledValuesFn = std::function<LogicalResult(
364     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
365     SmallVector<Value> &tiledValues,
366     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
367     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
368 
369 /// Clones the operation and updates the destination if the operation
370 /// implements the `DestinationStyleOpInterface`.
371 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
372                                                   Operation *op,
373                                                   ValueRange newDestArgs) {
374   Operation *clonedOp = rewriter.clone(*op);
375   if (newDestArgs.empty())
376     return clonedOp;
377   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
378     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
379   return clonedOp;
380 }
381 
382 /// Generate the tile-loop nest using the `scf.for` operation.
383 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
384 /// - `tileSizes` is the tile sizes to use. A zero tile size represents an
385 ///    untiled loop.
386 /// - `destinationTensors` are the init values to use for the outermost loop.
387 /// - `yieldTiledValuesFn` is called to generate the loop body of the innermost
388 ///    loop.
389 /// - `loops` is an in-out parameter into which the generated loops are
390 ///    populated.
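///
/// A rough sketch of the generated structure for two tiled loops with tensor
/// semantics (the shapes and 32x32 tile sizes are illustrative assumptions):
///
/// ```mlir
/// %0 = scf.for %i = %c0 to %c128 step %c32 iter_args(%arg0 = %init) -> (tensor<128x128xf32>) {
///   %1 = scf.for %j = %c0 to %c128 step %c32 iter_args(%arg1 = %arg0) -> (tensor<128x128xf32>) {
///     // ... body created by `yieldTiledValuesFn`, producing %tile ...
///     %2 = tensor.insert_slice %tile into %arg1[%i, %j] [32, 32] [1, 1]
///         : tensor<32x32xf32> into tensor<128x128xf32>
///     scf.yield %2 : tensor<128x128xf32>
///   }
///   scf.yield %1 : tensor<128x128xf32>
/// }
/// ```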
391 static LogicalResult generateLoopNestUsingForOp(
392     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
393     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
394     YieldTiledValuesFn yieldTiledValuesFn,
395     SmallVector<LoopLikeOpInterface> &loops) {
396   assert(!loopRanges.empty() && "unexpected empty loop ranges");
397   assert(loopRanges.size() == tileSizes.size() &&
398          "expected as many tile sizes as loop ranges");
399   OpBuilder::InsertionGuard guard(rewriter);
400 
401   SmallVector<OpFoldResult> lbs, ubs, steps;
402   std::tie(lbs, ubs, steps) =
403       getLoopBounds(rewriter, loc, loopRanges, tileSizes);
404   SmallVector<Value> lbVals =
405       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
406   SmallVector<Value> ubVals =
407       getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
408   SmallVector<Value> stepVals =
409       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
410 
411   SmallVector<Value> ivs;
412   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
413     auto loop =
414         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
415                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
416                                        Value iv, ValueRange /*iterArgs*/) {});
417     loops.push_back(loop);
418     ivs.push_back(loop.getInductionVar());
419     rewriter.setInsertionPointToEnd(loop.getBody());
420     destinationTensors = loop.getRegionIterArgs();
421   }
422 
423   SmallVector<Value> tiledResults;
424   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
425   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
426                                 tiledResults, resultOffsets, resultSizes))) {
427     return rewriter.notifyMatchFailure(
428         loc, "failed to generate inner tile loop body");
429   }
430   if (loops.empty())
431     return success();
432 
433   assert(tiledResults.size() == destinationTensors.size() &&
434          "Number of results of body should be equal to number of iter args");
435 
436   // Yield all the results of the tiled operation.
437   SmallVector<Value> yieldedValues;
438   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
439        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
440                        resultSizes)) {
441     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
442                                            rewriter.getIndexAttr(1));
443     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
444         loc, tiledValue, destinationTensor, resultOffset, resultSize,
445         resultStride);
446     yieldedValues.push_back(insertSlice);
447   }
448   rewriter.create<scf::YieldOp>(loc, yieldedValues);
449 
450   // Add the scf.yield operations for all the outer loops.
451   for (auto [outerLoop, innerLoop] :
452        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
453                        MutableArrayRef(loops).drop_front())) {
454     rewriter.setInsertionPointToEnd(
455         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
456     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
457   }
458   return success();
459 }
460 
461 /// Generate the tile-loop nest using the `scf.forall` operation.
462 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
463 /// - `tileSizes` is the tile sizes to use. A zero tile size represents an
464 ///    untiled loop.
465 /// - `destinationTensors` are the init values to use for the outermost loop.
466 /// - `mappingVector` is the mapping attributes to use for loop construction.
467 ///   Can be empty.
468 /// - `tiledBodyFn` is called to generate the loop body of the innermost
469 ///    loop.
470 /// - `loops` is an in-out parameter into which the generated loops are
471 ///    populated.
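///
/// A rough sketch of the generated structure when the number of threads is
/// specified (the shapes, thread counts and tile sizes are illustrative
/// assumptions):
///
/// ```mlir
/// %0 = scf.forall (%i, %j) in (4, 4) shared_outs(%out = %init) -> (tensor<128x128xf32>) {
///   // ... body created by `tiledBodyFn`, computing the tile offsets
///   //     %off0/%off1 from (%i, %j) and producing %tile ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %out[%off0, %off1] [32, 32] [1, 1]
///         : tensor<32x32xf32> into tensor<128x128xf32>
///   }
/// }
/// ```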
472 static LogicalResult generateLoopNestUsingForallOp(
473     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
474     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
475     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
476     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
477   assert(!loopRanges.empty() && "unexpected empty loop ranges");
478   assert(loopRanges.size() == tileSizes.size() &&
479          "expected as many tile sizes as loop ranges");
480   OpBuilder::InsertionGuard guard(rewriter);
481   SmallVector<OpFoldResult> offsets(loopRanges.size()),
482       sizes(loopRanges.size());
483 
484   std::optional<ArrayAttr> mappingAttr;
485   if (!mappingVector.empty())
486     mappingAttr = rewriter.getArrayAttr(mappingVector);
487 
488   scf::ForallOp forallOp;
489   bool useNumThreads = !numThreads.empty();
490 
491   if (useNumThreads) {
492     // Prune zero entries from `numThreads`.
493     SmallVector<OpFoldResult> nonZeroNumThreads;
494     for (auto nt : numThreads) {
495       if (isConstantIntValue(nt, 0))
496         continue;
497       nonZeroNumThreads.push_back(nt);
498     }
499     forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
500                                               destinationTensors, mappingAttr);
501   } else {
502     SmallVector<OpFoldResult> lbs, ubs, steps;
503     std::tie(lbs, ubs, steps) =
504         getLoopBounds(rewriter, loc, loopRanges, tileSizes);
505     forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
506                                               destinationTensors, mappingAttr);
507   }
508   loops.push_back(forallOp);
509 
510   rewriter.setInsertionPoint(forallOp.getTerminator());
511   destinationTensors = forallOp.getRegionOutArgs();
512 
513   SmallVector<Value> tiledResults;
514   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
515   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
516                          destinationTensors, tiledResults, resultOffsets,
517                          resultSizes)))
518     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
519 
520   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
521   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
522        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
523                        resultSizes)) {
524     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
525                                            rewriter.getIndexAttr(1));
526 
527     rewriter.create<tensor::ParallelInsertSliceOp>(
528         loc, tiledValue, destinationTensor, resultOffset, resultSize,
529         resultStride);
530   }
531   return success();
532 }
533 
534 /// Generate the tile-loop nest using the loop construct specified in `options`.
535 /// - `options`: Tiling options specified.
536 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
537 /// - `tileSizes` is the tile sizes to use. A zero tile size represents an
538 ///    untiled loop.
539 /// - `destinationTensors` are the init values to use for the outermost loop.
540 /// - `tiledBodyFn` is called to generate the loop body of the innermost
541 ///    loop.
542 /// - `loops` is an in-out parameter into which the generated loops are
543 ///    populated.
544 static LogicalResult generateLoopNest(
545     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
546     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
547     ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
548     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
549   // If the tile sizes are all zero, no loops are generated. Just call the
550   // callback function to handle untiled case.
551   if (llvm::all_of(tileSizes, isZeroIndex)) {
552     SmallVector<Value> tiledResults;
553     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
554     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
555                        tiledResults, resultOffsets, resultSizes);
556   }
557   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
558     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
559                                       destinationTensors, tiledBodyFn, loops);
560   }
561   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
562     return generateLoopNestUsingForallOp(
563         rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
564         destinationTensors, tiledBodyFn, loops);
565   }
566   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
567 }
568 
569 /// Append the specified additional `newInitOperands` operands to the
570 /// loops existing `init` operands (or similar), and replace `loopOp` with
571 /// the new loop that has the additional init operands. The loop body of
572 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
573 /// is called to get the new tiled values returned, and the offset
574 /// and sizes at which the tiled value is inserted into the
575 /// new region iter_args that correspond to the newly added init operands.
576 template <typename LoopType>
577 FailureOr<LoopLikeOpInterface>
578 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
579                                ValueRange newInitOperands,
580                                YieldTiledValuesFn yieldTiledValuesFn) {
581   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
582 }
583 
584 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
585 template <>
586 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
587     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
588     YieldTiledValuesFn yieldTiledValuesFn) {
589   OpBuilder::InsertionGuard g(rewriter);
590   Location loc = loopOp.getLoc();
591   rewriter.setInsertionPoint(loopOp);
592 
593   auto inits = llvm::to_vector(loopOp.getInitArgs());
594   inits.append(newInitOperands.begin(), newInitOperands.end());
595   auto newLoop = rewriter.create<scf::ForOp>(
596       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
597       inits, [](OpBuilder &, Location, Value, ValueRange) {});
598 
599   // Move the loop body to the new op.
600   Block *loopBody = loopOp.getBody();
601   Block *newLoopBody = newLoop.getBody();
602   rewriter.mergeBlocks(
603       loopBody, newLoopBody,
604       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
605 
606   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
607   rewriter.setInsertionPoint(yieldOp);
608 
609   SmallVector<Value> tiledValues;
610   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
611   ValueRange newRegionIterArgs =
612       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
613   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
614                                 newRegionIterArgs, tiledValues, resultOffsets,
615                                 resultSizes))) {
616     rewriter.eraseOp(newLoop);
617     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
618   }
619 
620   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
621   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
622        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
623                        resultSizes)) {
624     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
625                                            rewriter.getIndexAttr(1));
626     Value insert = rewriter.create<tensor::InsertSliceOp>(
627         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
628         resultStride);
629     newYieldValues.push_back(insert);
630   }
631 
632   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
633   rewriter.replaceOp(loopOp,
634                      newLoop->getResults().take_front(loopOp.getNumResults()));
635   return cast<LoopLikeOpInterface>(newLoop.getOperation());
636 }
637 
638 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
639 template <>
640 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
641     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
642     YieldTiledValuesFn yieldTiledValuesFn) {
643   OpBuilder::InsertionGuard g(rewriter);
644   Location loc = loopOp.getLoc();
645   rewriter.setInsertionPoint(loopOp);
646   auto inits = llvm::to_vector(loopOp.getOutputs());
647   inits.append(newInitOperands.begin(), newInitOperands.end());
648   auto newLoop = rewriter.create<scf::ForallOp>(
649       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
650       loopOp.getMixedStep(), inits, loopOp.getMapping(),
651       [](OpBuilder &, Location, ValueRange) {});
652 
653   // Move the region of the current block to the newly created op.
654   Block *loopBody = loopOp.getBody();
655   Block *newLoopBody = newLoop.getBody();
656   rewriter.mergeBlocks(
657       loopBody, newLoopBody,
658       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
659 
660   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
661   rewriter.setInsertionPoint(terminator);
662   SmallVector<Value> tiledValues;
663   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
664   ValueRange regionIterArgs =
665       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
666   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
667                                 regionIterArgs, tiledValues, resultOffsets,
668                                 resultSizes))) {
669     rewriter.eraseOp(newLoop);
670     return rewriter.notifyMatchFailure(loopOp,
671                                        "failed to get yielded tiled values");
672   }
673 
674   // Update the terminator.
675   rewriter.setInsertionPointToEnd(terminator.getBody());
676 
677   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
678            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
679     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
680                                            rewriter.getIndexAttr(1));
681     rewriter.create<tensor::ParallelInsertSliceOp>(
682         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
683         resultStride);
684   }
685 
686   rewriter.replaceOp(loopOp,
687                      newLoop->getResults().take_front(loopOp.getNumResults()));
688   return cast<LoopLikeOpInterface>(newLoop.getOperation());
689 }
690 
691 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
692 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
693 /// supported loop type.
694 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
695     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
696     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
697   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
698              loopLikeOp.getOperation())
699       .Case<scf::ForOp, scf::ForallOp>(
700           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
701             return yieldTiledValuesAndReplaceLoop(
702                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
703           })
704       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
705         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
706       });
707 }
708 
709 /// Method to add new init values to a loop nest. Updates `loops` in-place with
710 /// new loops that use the `newInitValues`.
711 /// The outer-loops are updated to yield the new result values of the inner
712 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
713 /// to get the additional values to yield from the innermost loop.
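///
/// Illustrative sketch (the loop shapes are assumptions): given a two-deep
/// `scf.for` nest carrying one iter_arg,
///
/// ```mlir
/// %0 = scf.for ... iter_args(%a = %init) -> (tensor<?xf32>) {
///   %1 = scf.for ... iter_args(%b = %a) -> (tensor<?xf32>) { ... }
///   scf.yield %1 : tensor<?xf32>
/// }
/// ```
///
/// adding one new init value makes both loops carry `(%init, %new)`, the
/// innermost loop additionally yield the tiled value produced by
/// `getNewTiledYieldsFn`, and the outer loop forward the extra result of the
/// inner loop.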
714 static LogicalResult addInitOperandsToLoopNest(
715     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
716     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
717   SmallVector<scf::ForOp> newLoops;
718   if (loops.empty())
719     return success();
720   OpBuilder::InsertionGuard g(rewriter);
721   rewriter.setInsertionPoint(loops.front());
722 
723   SmallVector<Value> ivs;
724   for (auto &loop : loops.drop_back()) {
725     rewriter.setInsertionPoint(loop);
726 
727     // If loops.size() > 1, we assume that `scf.for` is used for the loops.
728     auto forLoop = cast<scf::ForOp>(loop.getOperation());
729 
730     // Create a new loop with the new init values for this loop.
731     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
732     newInits.append(newInitValues.begin(), newInitValues.end());
733     auto newLoop = rewriter.create<scf::ForOp>(
734         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
735         forLoop.getStep(), newInits,
736         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
737 
738     // Merge the body of the new loop with the body of the old loops.
739     SmallVector<Value> sourceBlockArgs;
740     sourceBlockArgs.push_back(newLoop.getInductionVar());
741     auto newRegionIterArgs = newLoop.getRegionIterArgs();
742     sourceBlockArgs.append(
743         newRegionIterArgs.begin(),
744         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
745     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
746     rewriter.replaceOp(
747         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
748     loop = newLoop;
749     ivs.push_back(newLoop.getInductionVar());
750     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
751   }
752 
753   // Update the loop body of the innermost loop to get new yield values.
754   LoopLikeOpInterface innerMostLoop = loops.back();
755   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
756       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
757                                      getNewTiledYieldsFn);
758 
759   if (failed(newInnerMostLoop))
760     return innerMostLoop.emitOpError("failed to return additional yields");
761   loops.back() = newInnerMostLoop.value();
762 
763   // Make all other loops except the innermost loops yield the values returned
764   // by the inner loop.
765   for (auto [outerLoop, innerLoop] :
766        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
767     // Again assume that all the outer loops are scf.for operations.
768     auto outerForLoop = cast<scf::ForOp>(outerLoop);
769     auto outerLoopYield =
770         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
771     SmallVector<Value> newYields =
772         llvm::to_vector(outerLoopYield.getOperands());
773     ValueRange additionalYields =
774         innerLoop->getResults().take_back(newInitValues.size());
775     newYields.append(additionalYields.begin(), additionalYields.end());
776     rewriter.setInsertionPoint(outerLoopYield);
777     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
778   }
779   return success();
780 }
781 
782 /// Implementation of the tiling transformation of `op` (which implements the
783 /// `TilingInterface`) using the loop construct specified in `options`.
784 FailureOr<scf::SCFTilingResult>
785 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
786                         const scf::SCFTilingOptions &options) {
787   if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
788     return failure();
789   }
790 
791   OpBuilder::InsertionGuard guard(rewriter);
792   rewriter.setInsertionPointAfter(op);
793 
794   // 1. Get the range of the loops that are represented by the operation.
795   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
796 
797   // 2. Materialize the tile sizes and/or number of threads.
798   SmallVector<OpFoldResult> tileSizes, numThreads;
799   std::tie(tileSizes, numThreads) =
800       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
801 
802   // Check if it is safe to tile. This is a holdover from previous iterations
803   // of tiling to `scf.forall`. Consider dropping it.
804   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
805     checkSafeToTileToForall(op, tileSizes, numThreads);
806   }
807 
808   // 3. If there is an interchange specified, permute the iteration domain and
809   // the tile sizes.
810   SmallVector<int64_t> interchangeVector;
811   if (!options.interchangeVector.empty()) {
812     interchangeVector = fillInterchangeVector(options.interchangeVector,
813                                               iterationDomain.size());
814     assert(isPermutationVector(interchangeVector) &&
815            "expected interchange vector to be a permutation");
816 
817     applyPermutationToVector(iterationDomain, interchangeVector);
818     applyPermutationToVector(tileSizes, interchangeVector);
819     if (!numThreads.empty())
820       applyPermutationToVector(numThreads, interchangeVector);
821   }
822 
823   FailureOr<TilingResult> tilingResult;
824   // 4. Define the lambda function used later to generate the body of the
825   // innermost tiled loop.
826   YieldTiledValuesFn innerYieldTiledValuesFn =
827       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
828           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
829           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
830           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
831       -> LogicalResult {
832     // 4a. Compute the `offsets` and `sizes` to use for tiling.
833     SmallVector<OpFoldResult> offsets, sizes;
834     std::tie(offsets, sizes) = getTileOffsetAndSizes(
835         rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
836 
837     // 4b. If interchange was provided, apply inverse of the interchange
838     //     to get back the offsets/sizes in the original order.
839     if (!interchangeVector.empty()) {
840       auto inversePermutation = invertPermutationVector(interchangeVector);
841       applyPermutationToVector(offsets, inversePermutation);
842       applyPermutationToVector(sizes, inversePermutation);
843     }
844 
845     // 5. Generate the tiled implementation within the inner most loop.
846 
847     // 5a. Clone the operation within the loop body.
848     auto clonedOp = cast<TilingInterface>(
849         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
850 
851     // 5b. Early return the cloned op if no tiling is happening. We cannot
852     // return the original op because that could lead to
853     // `rewriter.replaceOp(op, op->getResults())`, which would crash users.
854     if (llvm::all_of(tileSizes, isZeroIndex)) {
855       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
856       tilingResult =
857           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
858                        /*generatedSlices=*/{}};
859       return success();
860     }
861 
862     // 5c. Tile the cloned operation.
863     tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
864     if (failed(tilingResult)) {
865       rewriter.eraseOp(clonedOp);
866       return op.emitOpError("failed to tile operation");
867     }
868 
869     // 5d. Delete the cloned operation.
870     rewriter.eraseOp(clonedOp);
871 
872     // 5e. Compute the offsets at which the result values are to be inserted
873     //     back into their destinations.
874     for (auto [index, tiledValue] :
875          llvm::enumerate(tilingResult->tiledValues)) {
876       tiledResults.push_back(tiledValue);
877       SmallVector<OpFoldResult> resultOffset, resultSize;
878       if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
879                                           resultOffset, resultSize))) {
880         for (auto op : tilingResult->tiledOps) {
881           rewriter.eraseOp(op);
882         }
883         return rewriter.notifyMatchFailure(
884             op, "failed to get slice of result produced");
885       }
886       resultOffsets.emplace_back(std::move(resultOffset));
887       resultSizes.emplace_back(std::move(resultSize));
888     }
889 
890     return success();
891   };
892 
893   // 6. Find the destination tensors to use for the operation.
894   SmallVector<Value> destinationTensors;
895   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
896                                              destinationTensors))) {
897     return rewriter.notifyMatchFailure(op,
898                                        "unable to create destination tensors");
899   }
900 
901   // 7. Generate the tiled loop nest using the callback defined above.
902   SmallVector<LoopLikeOpInterface> loops;
903   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
904                               tileSizes, numThreads, destinationTensors,
905                               innerYieldTiledValuesFn, loops)))
906     return op.emitOpError("failed to generate tiling loops");
907   assert(succeeded(tilingResult) &&
908          "expected tiling result to be computed after loop generation");
909 
910   // If loops are empty, the tiled op is used as the replacement for the untiled
911   // op.
912   if (loops.empty()) {
913     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
914                                 tilingResult->tiledValues,
915                                 tilingResult->generatedSlices};
916   }
917 
918   SmallVector<Value> replacements = llvm::map_to_vector(
919       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
920   return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
921                               tilingResult->generatedSlices};
922 }
923 
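/// Illustrative sketch of partial-reduction tiling (the shapes and tile size
/// are assumptions, not taken from this file): a 1-D reduction over 64
/// elements tiled by 8 roughly becomes
///
/// ```mlir
/// // partial-result tensor created by the PartialReductionOpInterface
/// %init = ... : tensor<8xf32>
/// %partial = scf.for %iv = %c0 to %c64 step %c8 iter_args(%arg = %init) -> (tensor<8xf32>) {
///   // ... partial reduction of the current tile into %arg ...
///   scf.yield %updated : tensor<8xf32>
/// }
/// // merge of the partial results into the final reduced value
/// %result = ...
/// ```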
924 FailureOr<scf::SCFReductionTilingResult>
925 mlir::scf::tileReductionUsingScf(RewriterBase &b,
926                                  PartialReductionOpInterface op,
927                                  ArrayRef<OpFoldResult> tileSizes) {
928   Location loc = op.getLoc();
929   // Ops implementing PartialReductionOpInterface are expected to implement
930   // TilingInterface.
931   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
932   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
933   auto tileSizesVector = llvm::to_vector(tileSizes);
934   if (tileSizesVector.size() < iterationDomain.size()) {
935     auto zero = b.getIndexAttr(0);
936     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
937                            zero);
938   }
939   SmallVector<utils::IteratorType> iterators =
940       tilingInterfaceOp.getLoopIteratorTypes();
941 
942   SmallVector<int> reductionDims;
943   for (auto [idx, iteratorType] :
944        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
945     if (iteratorType == utils::IteratorType::reduction)
946       reductionDims.push_back(idx);
947   }
948 
949   // 2. Create the initial tensor values.
950   FailureOr<SmallVector<Value>> maybeInitTensors =
951       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
952                                                   reductionDims);
953   if (failed(maybeInitTensors)) {
954     return b.notifyMatchFailure(op, "Failed to create initial tensors.");
955   }
956   SmallVector<Value> &initTensors = maybeInitTensors.value();
957 
958   // 3. Define the callback to use for generating the innermost tile loop body.
959   SmallVector<Operation *> parallelTiledOps;
960   auto innerYieldTiledValuesFn =
961       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
962           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
963           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
964           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
965       -> LogicalResult {
966     SmallVector<OpFoldResult> offsets, sizes;
967     {
968       int materializedLoopNum = 0;
969       for (auto [tileSize, loopRange] :
970            llvm::zip_equal(tileSizesVector, iterationDomain)) {
971         if (isConstantIntValue(tileSize, 0)) {
972           offsets.push_back(loopRange.offset);
973           sizes.push_back(loopRange.size);
974           continue;
975         }
976         Value iv = ivs[materializedLoopNum++];
977         offsets.push_back(iv);
978         sizes.push_back(
979             getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
980       }
981     }
982 
983     // 4a. Clone the operation.
984     {
985       auto clonedOp = cast<PartialReductionOpInterface>(
986           cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
987 
988       // 4b. Tile the cloned operation.
989       FailureOr<TilingResult> partialTilingResult =
990           clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
991                                           sizes, reductionDims);
992       if (failed(partialTilingResult)) {
993         return failure();
994       }
995       std::swap(parallelTiledOps, partialTilingResult->tiledOps);
996       std::swap(tiledResult, partialTilingResult->tiledValues);
997 
998       // 4c. Delete the cloned operation.
999       b.eraseOp(clonedOp);
1000     }
1001 
1002     // 4d. Compute the offsets and sizes needed to insert the tiled values
1003     // back into the destination before yielding the destination.
1004     for (auto result : tiledResult) {
1005       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
1006       resultOffsets.emplace_back(std::move(outOffsets));
1007 
1008       SmallVector<OpFoldResult> outSizes;
1009       for (size_t i = 0; i < offsets.size(); i++) {
1010         outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
1011       }
1012       resultSizes.emplace_back(std::move(outSizes));
1013     }
1014     return success();
1015   };
1016 
1017   // 5. Generate the tiled implementation using the destination tensors.
1018   SmallVector<LoopLikeOpInterface> loops;
1019   scf::SCFTilingOptions options;
1020   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1021   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
1022                               /*numThreads=*/ArrayRef<OpFoldResult>{},
1023                               initTensors, innerYieldTiledValuesFn, loops)))
1024     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
1025 
1026   SmallVector<Value> replacements = llvm::map_to_vector(
1027       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
1028 
1029   // 6. Apply the merge reduction to combine all the partial values.
1030   b.setInsertionPointAfter(*loops.begin());
1031   FailureOr<MergeResult> mergeResult =
1032       op.mergeReductions(b, loc, replacements, reductionDims);
1033   if (failed(mergeResult)) {
1034     return failure();
1035   }
1036   b.replaceOp(op, mergeResult->replacements);
1037 
1038   SCFReductionTilingResult reductionTilingResult;
1039   std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
1040   std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
1041   std::swap(reductionTilingResult.initialValues, initTensors);
1042   std::swap(reductionTilingResult.loops, loops);
1043   std::swap(reductionTilingResult.replacements, mergeResult->replacements);
1044 
1045   return reductionTilingResult;
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // tileConsumerAndFuseProducersUsingSCF implementation.
1050 //===----------------------------------------------------------------------===//
1051 
1052 /// Return the untiled producer whose slice is used in a tiled consumer. The
1053 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1054 /// `iter_args` of the outermost loop encountered. Traversing the iter_args
1055 /// indicates that this is a destination operand of the consumer. If there was
1056 /// no loop traversal needed, the second value of the returned tuple is empty.
1057 static std::tuple<OpResult, std::optional<OpOperand *>>
1058 getUntiledProducerFromSliceSource(OpOperand *source,
1059                                   ArrayRef<LoopLikeOpInterface> loops) {
1060   std::optional<OpOperand *> destinationIterArg;
1061   auto loopIt = loops.rbegin();
1062   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1063     auto loop = *loopIt;
1064     if (iterArg.getOwner()->getParentOp() != loop)
1065       break;
1066     source = loop.getTiedLoopInit(iterArg);
1067     loopIt++;
1068   }
1069   if (loopIt == loops.rend())
1070     destinationIterArg = source;
1071   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1072 }
1073 
1074 /// Implementation of fusing the producer of a single slice by computing the
1075 /// slice of the producer in-place.
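///
/// Illustrative sketch (op names and shapes are assumptions): given
///
/// ```mlir
/// %0 = linalg.fill ... -> tensor<128xf32>
/// scf.for %iv = ... {
///   %1 = tensor.extract_slice %0[%iv] [32] [1]
///       : tensor<128xf32> to tensor<32xf32>
///   // ... uses of %1 ...
/// }
/// ```
///
/// fusing the producer of `%1` materializes a `linalg.fill` of just the
/// 32-element slice inside the loop and replaces the uses of the
/// `tensor.extract_slice` result with it.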
1076 std::optional<scf::SCFFuseProducerOfSliceResult>
1077 mlir::scf::tileAndFuseProducerOfSlice(
1078     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1079     MutableArrayRef<LoopLikeOpInterface> loops) {
1080   // 1. Get the producer of the source (potentially walking through
1081   // `iter_args` of nested `scf.for`)
1082   auto [fusableProducer, destinationInitArg] =
1083       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1084                                         loops);
1085   if (!fusableProducer)
1086     return std::nullopt;
1087   unsigned resultNumber = fusableProducer.getResultNumber();
1088 
1089   OpBuilder::InsertionGuard g(rewriter);
1090   rewriter.setInsertionPoint(candidateSliceOp);
1091 
1092   // 2. Clone the fused producer
1093   // 2a. Compute the destination operands to use for the cloned operation.
1094   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1095   Operation *fusableProducerOp = fusableProducer.getOwner();
1096   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1097       failed(tensor::getOrCreateDestinations(
1098           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1099           origDestinationTensors)))
1100     return std::nullopt;
1101 
1102   clonedOpDestinationTensors = origDestinationTensors;
1103   if (destinationInitArg &&
1104       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1105     // 2b. If the producer is also destination style, then to maintain the
1106     // destination passing style, update the destination of the producer to be
1107     // the source of the slice.
1108     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1109   }
1110   // 2c. Clone the fused producer.
1111   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1112       rewriter, fusableProducerOp, clonedOpDestinationTensors);
1113   // 2d. Update the source of the candidateSlice to be the cloned producer.
1114   //     It is easier to just clone the slice with a different source, since
1115   //     that simplifies replacement and DCE of the cloned ops.
1116   SmallVector<Value> candidateSliceOpOperands =
1117       llvm::to_vector(candidateSliceOp->getOperands());
1118   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1119   tensor::ExtractSliceOp clonedCandidateSliceOp =
1120       mlir::clone(rewriter, candidateSliceOp,
1121                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1122 
1123   // 3. Generate the tiled implementation of the producer of the source
1124   FailureOr<TilingResult> tileAndFuseResult =
1125       tensor::replaceExtractSliceWithTiledProducer(
1126           rewriter, clonedCandidateSliceOp,
1127           clonedProducerOp->getResult(resultNumber));
1128   if (failed(tileAndFuseResult))
1129     return std::nullopt;
1130   // Note: Do not delete the candidateSliceOp, since it is passed in from the
1131   // caller.
1132   rewriter.replaceAllUsesWith(candidateSliceOp,
1133                               tileAndFuseResult->tiledValues[0]);
1134   rewriter.eraseOp(clonedCandidateSliceOp);
1135   rewriter.eraseOp(clonedProducerOp);
1136 
1137   // 4. If the slice is for a destination operand, for example,
1138   //
1139   // ```mlir
1140   // %0 = linalg.init
1141   // %1 = linalg.fill .. outs(%0 : )
1142   // %2 = scf.for .. iter_args(%arg0 = %1) {
1143   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1144   //     %4 = tensor.extract_slice %arg1 [..]
1145   //     .. = linalg.matmul .. outs(%4 : )
1146   //   }
1147   // }
1148   // ```
1149   //
1150   // the IR is currently
1151   //
1152   // ```
1153   // %0 = linalg.init
1154   // %1 = linalg.fill
1155   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1156   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1157   //     %4 = tensor.extract_slice %arg1[..]
1158   //     %5 = linalg.fill .. outs(%4 : )
1159   //     .. = linalg.matmul .. outs(%5 : )
1160   //   }
1161   // }
1162   // ```
1163   //
1164   // The untiled `linalg.fill` is still used as the `init_value` since it
1165   // was originally a destination operand of the untiled `linalg.matmul`.
1166   // When fusing an operand that is a destination operand, the iter_arg of
1167   // the outer most loop should be changed to use the destination of the
1168   // fused operation. With this, the IR will be:
1169   //
1170   // ```
1171   // %0 = linalg.init
1172   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1173   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
1174   //     %3 = tensor.extract_slice %arg1[..]
1175   //     %4 = linalg.fill .. outs(%3 : )
1176   //     .. = linalg.matmul .. outs(%4 : )
1177   //   }
1178   // }
1179   // ```
1180   if (destinationInitArg &&
1181       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1182     loops.front()
1183         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1184         .set(origDestinationTensors[resultNumber]);
1185   }
1186   return scf::SCFFuseProducerOfSliceResult{
1187       fusableProducer, tileAndFuseResult->tiledValues[0],
1188       tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1189 }
1190 
1191 /// Reconstruct the fused producer from within the tiled-and-fused code.
1192 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1193     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1194     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1195     MutableArrayRef<LoopLikeOpInterface> loops,
1196     ArrayRef<unsigned> yieldResultNumber) {
1197   if (loops.empty())
1198     return SmallVector<Operation *>{};
1199 
1200   Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1201             *tiledOwner = fusedProducerInfo.tiledOps[0];
1202 
1203   Location loc = originalOwner->getLoc();
1204   // a. Collect all init values to be appended.
1205   SmallVector<unsigned> initNumberList =
1206       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1207                                       0, originalOwner->getNumResults()))
1208                                 : llvm::to_vector(yieldResultNumber);
1209   SmallVector<Value> initValueList;
1210   for (const auto &resultNumber : initNumberList) {
1211     FailureOr<Value> initValue = tensor::getOrCreateDestination(
1212         rewriter, loc, originalOwner->getResult(resultNumber));
1213     if (succeeded(initValue)) {
1214       initValueList.push_back(initValue.value());
1215     } else {
1216       return failure();
1217     }
1218   }
1219 
1220   SmallVector<Operation *> generatedSlices;
1221   YieldTiledValuesFn newYieldValuesFn =
1222       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1223           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1224           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1225           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1226     OpBuilder::InsertionGuard g(innerRewriter);
1227 
1228     // Get the sliceOp's tile information.
1229     SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1230                               sliceSizes = sliceOp.getMixedSizes();
1231 
1232     // Expect all strides of the sliceOp to be 1.
1233     if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1234           return !isConstantIntValue(ofr, 1);
1235         }))
1236       return failure();
1237 
1238     unsigned sliceResultNumber =
1239         fusedProducerInfo.origProducer.getResultNumber();
1240 
1241     auto tilableOp = cast<TilingInterface>(originalOwner);
1242     // b. Get the iteration domain offsets and sizes based on the sliceOp tile.
1243     SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1244     // Skip ops such as tensor.pack/unpack/pad, which have a single OpResult.
1245     if (tilableOp->getNumResults() > 1 &&
1246         failed(tilableOp.getIterationDomainTileFromResultTile(
1247             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1248             iterDomainOffset, iterDomainSizes))) {
1249       // In theory, it is unnecessary to raise an error here. Even if the
1250       // result tensor cannot be reconstructed, the fusion itself would still
1251       // be valid. The reason we must return failure for now is that the
1252       // callback `newYieldValuesFn` is invoked after the new init operand(s)
1253       // have already been appended, and more refactoring is needed to ensure
1254       // the init operands are added consistently in the future. For more
1255       // details, please refer to:
1256       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1257       return failure();
1258     }
1259 
1260     // c. Calculate the offsets and sizes of all OpResults based on the
1261     // iteration domain tile.
1262     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1263     for (const auto &resultNumber : initNumberList) {
1264       if (resultNumber == sliceResultNumber) {
1265         offsetList.push_back(sliceOffset);
1266         sizesList.push_back(sliceSizes);
1267       } else {
1268         assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1269         // infer result tile according to the iteration domain tile
1270         SmallVector<OpFoldResult> offset, sizes;
1271         if (failed(tilableOp.getResultTilePosition(
1272                 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1273                 offset, sizes))) {
1274           return failure();
1275         }
1276         offsetList.push_back(offset);
1277         sizesList.push_back(sizes);
1278       }
1279     }
1280 
1281     // d. Create `extract_slice` for the `iter_args` of the DPS operation if needed.
1282     if (auto tiledDestStyleOp =
1283             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1284       rewriter.setInsertionPoint(tiledDestStyleOp);
1285       for (const auto &&[index, newRegionArg] :
1286            llvm::enumerate(newRegionIterArgs)) {
1287         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1288             loc, newRegionArg, offsetList[index], sizesList[index],
1289             SmallVector<OpFoldResult>(offsetList[index].size(),
1290                                       rewriter.getIndexAttr(1)));
1291         generatedSlices.push_back(destSlice);
1292         unsigned resultNumber = initNumberList[index];
1293         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1294           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1295         });
1296       }
1297     }
1298 
1299     // e. Prepare the tiled offsets and sizes for the later `insert_slice`
1300     // creation by the caller.
1301     Block *block = rewriter.getInsertionPoint()->getBlock();
1302     rewriter.setInsertionPoint(block->getTerminator());
1303     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1304       tiledResult.push_back(tiledOwner->getResult(resultNumber));
1305       tiledOffset.emplace_back(offsetList[index]);
1306       tiledSizes.emplace_back(sizesList[index]);
1307     }
1308     return success();
1309   };
1310 
1311   if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1312                                        newYieldValuesFn))) {
1313     return failure();
1314   }
1315   return generatedSlices;
1316 }
1317 
1318 /// Implementation of tile consumer and fuse producer greedily.
1319 FailureOr<scf::SCFTileAndFuseResult>
1320 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1321     RewriterBase &rewriter, TilingInterface consumer,
1322     const scf::SCFTileAndFuseOptions &options) {
1323   // This transformation is only valid for ops that return values (i.e. not
1324   // valid to use with operations that have memref operands).
1325   if (!consumer->getNumResults()) {
1326     return rewriter.notifyMatchFailure(
1327         consumer, "invalid pattern for op with no results");
1328   }
1329 
1330   // 1. First tile the consumer.
1331   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1332   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1333 
1334   FailureOr<scf::SCFTilingResult> tilingResult =
1335       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1336 
1337   if (failed(tilingResult))
1338     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1339   for (auto *tiledOp : tilingResult->tiledOps)
1340     tiledAndFusedOps.insert(tiledOp);
1341 
1342   // If there are no loops generated, fusion is immaterial.
1343   auto &loops = tilingResult->loops;
1344   if (loops.empty()) {
1345     DenseMap<Value, Value> replacements;
1346     for (auto [origVal, replacement] :
1347          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1348       replacements[origVal] = replacement;
1349     }
1350     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1351                                      replacements};
1352   }
1353 
1354   // To keep track of replacements, for now just record the map from the
1355   // original untiled value to the result number of the loop. Since the loop
1356   // may get replaced during fusion, keeping the value directly won't work.
1357   DenseMap<Value, size_t> origValToResultNumber;
1358   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1359     origValToResultNumber[result] = index;
1360   }
1361 
1362   // 2. Typically, the operands of the tiled operation are slices of the
1363   //    operands of the untiled operation. These are expressed in IR using
1364   //    `tensor.extract_slice` operations with source being the operands of the
1365   //    untiled operation. Create a worklist of these `tensor.extract_slice`
1366   //    operations. If the producers of the source of the `tensor.extract_slice`
1367   //    can be tiled such that the tiled value is generated in-place, that
1368   //    effectively tiles + fuses the operations.
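  //    For illustration (a hypothetical sketch, with made-up value names),
  //    after tiling the consumer the loop body typically contains:
  //
  //      %slice = tensor.extract_slice %untiled_producer_result [..] [..] [..]
  //      %tile  = linalg.matmul ... ins(%slice, ...) outs(...)
  //
  //    and `%untiled_producer_result`'s producer is the candidate to tile and
  //    fuse in place of the `tensor.extract_slice`.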
1369   struct WorklistItem {
1370     tensor::ExtractSliceOp candidateSlice;
1371     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1372   };
1373   std::deque<WorklistItem> worklist;
1374   auto addCandidateSlices = [&worklist, &options,
1375                              &loops](ArrayRef<Operation *> candidates) {
1376     for (auto candidate : candidates) {
1377       auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378       if (!sliceOp || sliceOp.use_empty())
1379         continue;
1380 
1381       auto [fusableProducer, destinationInitArg] =
1382           getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
1383       if (!fusableProducer)
1384         continue;
1385       std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386           options.fusionControlFn(sliceOp, fusableProducer,
1387                                   destinationInitArg.has_value());
1388       if (!controlFnResult)
1389         continue;
1390       worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
1391     }
1392   };
1393 
1394   addCandidateSlices(tilingResult->generatedSlices);
1395   OpBuilder::InsertionGuard g(rewriter);
1396   while (!worklist.empty()) {
1397     // Traverse the slices in BFS fashion.
1398     WorklistItem worklistItem = worklist.front();
1399     worklist.pop_front();
1400 
1401     // The operands of the fused producer might themselves be slices of
1402     // values produced by operations that implement the `TilingInterface`.
1403     // Add these operations to the worklist.
1404     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1405         tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1406                                    loops);
1407     if (!fusedResult)
1408       continue;
1409 
1410     if (worklistItem.controlFnResult.yieldProducerReplacement) {
1411       // Reconstruct and yield all opResults of fusableProducerOp by default.
1412       // The caller can specify which ones to yield via the optional
1413       // `yieldResultNumber` argument of `yieldReplacementForFusedProducer`.
1414       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1415       FailureOr<SmallVector<Operation *>> newSlices =
1416           yieldReplacementForFusedProducer(rewriter,
1417                                            worklistItem.candidateSlice,
1418                                            fusedResult.value(), loops);
1419       if (failed(newSlices)) {
1420         return rewriter.notifyMatchFailure(
1421             fusableProducerOp, "failed to yield a replacement value for this "
1422                                "operation from within the tiled loop");
1423       }
1424       addCandidateSlices(newSlices.value());
1425       for (auto [index, result] :
1426            llvm::enumerate(fusableProducerOp->getResults())) {
1427         origValToResultNumber[result] = loops.front()->getNumResults() -
1428                                         fusableProducerOp->getNumResults() +
1429                                         index;
1430       }
1431     }
1432     addCandidateSlices(fusedResult->generatedSlices);
1433     if (Operation *tiledAndFusedOp =
1434             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1435       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1436       tiledAndFusedOps.insert(tiledAndFusedOp);
1437     }
1438   }
1439 
1440   DenseMap<Value, Value> replacements;
1441   for (auto [origVal, resultNumber] : origValToResultNumber) {
1442     replacements[origVal] = loops.front()->getResult(resultNumber);
1443   }
1444 
1445   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1446                                    replacements};
1447 }
1448 
1449 //===----------------------------------------------------------------------===//
1450 // tileAndFuseConsumerUsingSCF implementation.
1451 //===----------------------------------------------------------------------===//
1452 
1453 /// A utility function that checks whether the only use of the result of a
1454 /// tensor.insert_slice op is in a scf.yield op.
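/// For illustration, a sketch of the expected pattern (value names made up):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
///   ...
///   %1 = tensor.insert_slice %tile into %arg0[..] [..] [1, 1]
///   scf.yield %1
/// }
/// ```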
1455 static LogicalResult
1456 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1457   Value result = candidateSliceOp.getResult();
1458   Value::use_range uses = result.getUses();
1459   if (!llvm::hasSingleElement(uses)) {
1460     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1461     return failure();
1462   }
1463   OpOperand &operandUse = (*uses.begin());
1464   Operation *userOp = operandUse.getOwner();
1465   if (!isa<scf::YieldOp>(userOp)) {
1466     LLVM_DEBUG(llvm::dbgs()
1467                << "Expected scf.yield to be the only user, but got -> "
1468                << (*userOp));
1469     return failure();
1470   }
1471   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1472     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1473                                "be in the same block\n");
1474     return failure();
1475   }
1476   return success();
1477 }
1478 
1479 /// Fetches the OpOperand of the only user (and use) of the value `val` which
1480 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1481 /// failure otherwise.
1482 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1483                                                   Block *containingOpBlock) {
1484   // Step 1. Check that the value has exactly one use.
1485   if (!llvm::hasSingleElement(val.getUses()))
1486     return failure();
1487   // Step 2. Get uses.
1488   OpOperand &operand = (*val.getUses().begin());
1489   Operation *consumerOp = operand.getOwner();
1490   // TODO: We have to initialize the consumer's result before the scf.for; use
1491   //       DestinationStyleOpInterface to get the result shape from the init for
1492   //       now. Add support for other ops, e.g. ops with InferTypeOpInterface.
1493   if (!isa<TilingInterface>(consumerOp) ||
1494       !isa<DestinationStyleOpInterface>(consumerOp))
1495     return failure();
1496   if (containingOpBlock != consumerOp->getBlock())
1497     return failure();
1498   return &operand;
1499 }
1500 
1501 /// Find the perfectly nested loops outside of the given loop (inclusive),
1502 /// sorted from outer to inner.
1503 ///
1504 /// E.g.
1505 ///
1506 /// ```
1507 ///  %0 = scf.for()
1508 ///    %1 = scf.for()
1509 ///      %2 = scf.for()
1510 ///         %3 = ...
1511 ///         yield %3
1512 ///      yield %2
1513 ///    yield %1
1514 /// ```
1515 ///
1516 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1517 /// the target inner loop is %2.
1518 static SmallVector<scf::ForOp>
1519 getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1520   SmallVector<scf::ForOp> nestLoops = {loop};
1521   auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1522 
1523   // Check if it is the ForOp that yields the result of the inner loop.
1524   auto isForOpYieldResultOfInnerLoop =
1525       [](scf::ForOp outerLoop) -> LogicalResult {
1526     Block *body = outerLoop.getBody();
1527     if (!llvm::hasSingleElement(body->without_terminator()))
1528       return failure();
1529     auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1530     auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1531     if (!innerForOp)
1532       return failure();
1533     // All of the innerForOp's results should be yielded.
1534     return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1535   };
1536 
1537   while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1538     nestLoops.push_back(outerLoop);
1539     outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1540   }
1541   // sorted from outer to inner
1542   return {nestLoops.rbegin(), nestLoops.rend()};
1543 }
1544 
1545 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1546 /// tensor.insert_slice. This function makes the following assumptions:
1547 /// 1.  tensor.insert_slice has scf.yield as its only user.
1548 /// 2.  scf.for's corresponding result has only one use.
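/// For illustration (a sketch with made-up op names):
///
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
///   %1 = tensor.insert_slice %tile into %arg0[..] [..] [..]
///   scf.yield %1
/// }
/// %2 = linalg.exp ins(%0 : ...) outs(%out : ...)  // the untiled consumer
/// ```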
1549 static FailureOr<OpOperand *>
1550 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1551   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1552     return failure();
1553   Value sliceResult = candidateSliceOp.getResult();
1554   // Step 1. Fetch the corresponding output.
1555   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1556   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1557   // Step 2. Check that the containing op is scf.for.
1558   Operation *containingOp = candidateSliceOp->getParentOp();
1559   auto forOp = dyn_cast<scf::ForOp>(containingOp);
1560   if (!forOp)
1561     return failure();
1562   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1563   Value resultingValue = topLevelForOp->getResult(resultNumber);
1564 
1565   return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
1566 }
1567 
1568 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1569 /// by a tensor.parallel_insert_slice.
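/// For illustration (a sketch with made-up op names):
///
/// ```
/// %0 = scf.forall ... shared_outs(%arg0 = %init) {
///   ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %arg0[..] [..] [..]
///   }
/// }
/// %1 = linalg.exp ins(%0 : ...) outs(%out : ...)  // the untiled consumer
/// ```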
1570 static FailureOr<OpOperand *>
1571 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1572   // Step 1. Fetch the corresponding output
1573   Value sliceDest = candidateSliceOp.getDest();
1574   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1575   if (!iterArg)
1576     return failure();
1577   Operation *containingOp = iterArg.getOwner()->getParentOp();
1578   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1579     return failure();
1580   // Step 2. Check that the containing op is scf.forall.
1581   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1582   if (!forallOp)
1583     return failure();
1584   Value resultingValue =
1585       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1586 
1587   return getConsumerFromUses(resultingValue, containingOp->getBlock());
1588 }
1589 
1590 /// This utility currently checks whether the loop either:
1591 /// 1. Yields exactly one result.
1592 /// 2. Has the consumer op as its first user, with all other users located in
1593 /// the same containing block as the consumer op. Currently the loop op is
1594 /// moved right before the consumer op in order to maintain a valid def-use
1595 /// chain. This utility thus helps ensure that no invalid IR is formed as a
1596 /// result of that move.
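///
/// For illustration, a sketch of the second case (names made up):
///
/// ```
/// %0:2 = scf.for ... { ... }
/// %1 = consumer_op(... %0#0 ...)   // first user of the loop
/// %2 = other_op(... %0#1 ...)      // later user in the same block
/// ```
///
/// Here it is safe to move the loop right before `consumer_op`, since all
/// other users come after it in the same block.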
1597 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
1598                                             Operation *consumerOp) {
1599   // Check if the loop op yields one result.
1600   if (loopOp->getNumResults() == 1)
1601     return success();
1602   // Check if the consumerOp is the first user of the loopOp and if the other
1603   // users are in the same containing block as the consumerOp.
1604   Block *parentBlock = consumerOp->getBlock();
1605   for (Operation *userOp : loopOp->getUsers()) {
1606     if (userOp == consumerOp)
1607       continue;
1608     if (parentBlock != userOp->getBlock() ||
1609         !consumerOp->isBeforeInBlock(userOp))
1610       return failure();
1611   }
1612   return success();
1613 }
1614 
1615 /// A utility to fetch an untiled consumer of
1616 /// tensor.insert_slice/tensor.parallel_insert_slice.
1617 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
1618   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1619     return getUntiledConsumerFromSlice(insertSlice);
1620   } else if (auto parallelInsertSlice =
1621                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1622     return getUntiledConsumerFromSlice(parallelInsertSlice);
1623   } else {
1624     return failure();
1625   }
1626 }
1627 
1628 /// Implementation of fusing the consumer of a single slice by computing the
1629 /// slice of the consumer in-place for the scf loop.
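///
/// For illustration, a rough before/after sketch (op names made up and details
/// elided):
///
/// ```
/// // Before:
/// %0 = scf.for ... iter_args(%arg0 = ...) {
///   %1 = tensor.insert_slice %tile into %arg0[..] [..] [..]
///   scf.yield %1
/// }
/// %2 = linalg.exp ins(%0 : ...) outs(%init : ...)
///
/// // After: the consumer is tiled and fused into the loop, with an extra
/// // iter_arg/result added for its destination.
/// %0:2 = scf.for ... iter_args(%arg0 = ..., %arg1 = %init) {
///   %1 = tensor.insert_slice %tile into %arg0[..] [..] [..]
///   %2 = tensor.extract_slice %arg1[..] [..] [..]
///   %3 = linalg.exp ins(%tile : ...) outs(%2 : ...)
///   %4 = tensor.insert_slice %3 into %arg1[..] [..] [..]
///   scf.yield %1, %4
/// }
/// ```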
1630 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1631 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1632                                       Operation *candidateSliceOp) {
1633   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1634           candidateSliceOp))
1635     return failure();
1636 
1637   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1638 
1639   // 1. Get the consumer of scf.for for the result yielded by
1640   // tensor.insert_slice/parallel_insert_slice.
1641   FailureOr<OpOperand *> maybeConsumerOpOperand =
1642       getUntiledConsumerFromSlice(candidateSliceOp);
1643   if (failed(maybeConsumerOpOperand)) {
1644     return rewriter.notifyMatchFailure(candidateSliceOp,
1645                                        "could not fetch consumer to fuse");
1646   }
1647   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1648   Operation *consumerOp = consumerOpOperand->getOwner();
1649   unsigned operandNumber = consumerOpOperand->getOperandNumber();
1650   unsigned resultNumber = 0;
1651   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1652     resultNumber = producerResult.getResultNumber();
1653   } else {
1654     return rewriter.notifyMatchFailure(
1655         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1656   }
1657 
1658   // There are two possible cases for the loop containing the candidate slice:
1659   // 1. A single `scf.forall` or `scf.for`.
1660   // 2. The innermost `scf.for` inside a nest of `scf.for` ops, where the
1661   // top-level loop is the outermost one of these nested loops.
1662   LoopLikeOpInterface innerMostLoop =
1663       candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1664   SmallVector<LoopLikeOpInterface> nestedLoops;
1665   if (isInsertSliceOp) {
1666     nestedLoops = llvm::map_to_vector(
1667         getPerfectlyNestedLoopsOutsideOf(
1668             cast<scf::ForOp>(innerMostLoop.getOperation())),
1669         [](scf::ForOp forOp) {
1670           return cast<LoopLikeOpInterface>(forOp.getOperation());
1671         });
1672   } else {
1673     nestedLoops = {innerMostLoop};
1674   }
1675 
1676   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1677 
1678   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
1679     return rewriter.notifyMatchFailure(
1680         outerMostLoop,
1681         "containing loop op should either yield just one value or "
1682         "have the consumer op as its first user");
1683   }
1684 
1685   OpBuilder::InsertionGuard g(rewriter);
1686 
1687   // 2. Check that the consumer is not using the scf loop's output as an init.
1688   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1689   if (!dstOp)
1690     return rewriter.notifyMatchFailure(consumerOp,
1691                                        "consumer op is not DPS operation");
1692   SmallVector<Value> dpsInits =
1693       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1694   if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1695     return rewriter.notifyMatchFailure(
1696         consumerOp,
1697         "consumer op taking the result of scf.for as init is not supported");
1698   }
1699   SmallVector<Value> newInits = dpsInits;
1700 
1701   Location loc = outerMostLoop->getLoc();
1702 
1703   // 3. Move the whole loop structure right before the consumer op; dominance
1704   // should already be ensured by `checkAssumptionForLoop`.
1705   rewriter.moveOpBefore(outerMostLoop, consumerOp);
1706 
1707   // 4. Set insertion point before terminator op of the loop and create a new
1708   // tensor.insert_slice. In the scf.for case this is a clone of the
1709   // candidateSliceOp whereas in the scf.forall case this is created from the
1710   // operands of tensor.parallel_insert_slice.
1711   tensor::InsertSliceOp clonedInsertSliceOp;
1712   if (auto sliceOp =
1713           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1714     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1715     rewriter.setInsertionPoint(newForallOp.getTerminator());
1716     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1717         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1718         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1719   } else {
1720     rewriter.setInsertionPoint(candidateSliceOp);
1721     clonedInsertSliceOp =
1722         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1723   }
1724 
1725   // 5.a. Clone consumer op.
1726   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
1727 
1728   // 5.b. Replace the use of the loop result in the cloned consumer with the
1729   // result of the cloned tensor.insert_slice.
1730   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1731   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1732     operandToReplace.set(clonedInsertSliceOp.getResult());
1733   });
1734 
1735   // 6. Perform tiling of the cloned consumer and replace the operand at
1736   // `operandNumber` with the source of the cloned tensor.insert_slice op.
1737   auto ossSliceOp =
1738       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1739   FailureOr<TilingResult> tileAndFuseResult =
1740       tensor::replaceInsertSliceWithTiledConsumer(
1741           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1742   if (failed(tileAndFuseResult)) {
1743     return failure();
1744   }
1745   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1746   rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
1747                               clonedInsertSliceOp.getSource());
1748 
1749   // 7. Reconstruct [nested] loop with new inits.
1750   YieldTiledValuesFn newYieldValuesFn =
1751       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1752           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1753           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1754           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1755     OpBuilder::InsertionGuard g(innerRewriter);
1756     // 8. Set inner insertPoint right before tiled consumer op.
1757     innerRewriter.setInsertionPoint(tiledConsumerOp);
1758 
1759     SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1760     SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1761     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1762 
1763     // 9. Check that all insert strides are 1.
1764     if (llvm::any_of(strides, [](OpFoldResult stride) {
1765           return !isConstantIntValue(stride, 1);
1766         })) {
1767       return rewriter.notifyMatchFailure(
1768           candidateSliceOp, "containingOp's result yield with stride");
1769     }
1770 
1771     // 10. Try to get iter domain position from input position.
1772     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1773     if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
1774             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1775             iterDomainSizes))) {
1776       return rewriter.notifyMatchFailure(
1777           tiledConsumerOp,
1778           "can't get iter domain position from input position");
1779     }
1780 
1781     // 11. Try to fetch the offset and size for all results of the cloned
1782     // consumer. This would then be used to form the corresponding
1783     // tensor.insert_slice/parallel_insert_slice later.
1784     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
1785     SmallVector<SmallVector<OpFoldResult>> resultOffsets(
1786         totalNumResultsOfConsumer);
1787     SmallVector<SmallVector<OpFoldResult>> resultSizes(
1788         totalNumResultsOfConsumer);
1789     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
1790       if (failed(tiledConsumerOp.getResultTilePosition(
1791               rewriter, idx, iterDomainOffsets, iterDomainSizes,
1792               resultOffsets[idx], resultSizes[idx]))) {
1793         return rewriter.notifyMatchFailure(
1794             tiledConsumerOp,
1795             "can't get result domain position from iter domain position");
1796       }
1797     }
1798 
1799     // 12. Create `extract_slice` for `iter_args` for DPS operation if
1800     // necessary.
1801     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
1802             tiledConsumerOp.getOperation())) {
1803       rewriter.setInsertionPoint(tiledDestStyleOp);
1804       for (const auto &&[index, newRegionArg] :
1805            llvm::enumerate(newRegionIterArgs)) {
1806         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1807             loc, newRegionArg, resultOffsets[index], resultSizes[index],
1808             SmallVector<OpFoldResult>(resultOffsets[index].size(),
1809                                       rewriter.getIndexAttr(1)));
1810         // Make a copy of index to avoid capturing a structured binding in the
1811         // lambda, which is a C++20 extension.
1812         auto dstNumber = index;
1813         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1814           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
1815         });
1816       }
1817     }
1818 
1819     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
1820     // caller.
1821     Block *block = rewriter.getInsertionPoint()->getBlock();
1822     rewriter.setInsertionPoint(block->getTerminator());
1823     for (const auto &&[index, result] :
1824          llvm::enumerate(tiledConsumerOp->getResults())) {
1825       tiledResult.push_back(result);
1826       tiledOffset.emplace_back(resultOffsets[index]);
1827       tiledSizes.emplace_back(resultSizes[index]);
1828     }
1829     return success();
1830   };
1831   // 14. Add new inits to [nested] loops.
1832   if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
1833                                        newYieldValuesFn))) {
1834     return rewriter.notifyMatchFailure(
1835         tiledConsumerOp, "unable to add new inits to the loop nest");
1836   }
1837 
1838   // 15. Replace the results of the scf loop and the consumer op with the new
1839   // loop's results.
1840   for (auto &&[oldResult, newResult] : llvm::zip(
1841            consumerOp->getResults(),
1842            nestedLoops.front()->getResults().take_back(newInits.size()))) {
1843     rewriter.replaceAllUsesWith(oldResult, newResult);
1844   }
1845 
1846   // 16. Need to erase the old scf loop and the cloned consumer op.
1847   rewriter.eraseOp(clonedConsumerOp);
1848 
1849   return scf::SCFFuseConsumerOfSliceResult{
1850       consumerOpOperand,
1851       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1852       tileAndFuseResult->tiledOps};
1853 }
1854 
1855 //===----------------------------------------------------------------------===//
1856 // lowerToLoopsUsingSCFForOp implementation.
1857 //===----------------------------------------------------------------------===//
1858 
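/// Lower `op` to a nest of `scf.for` loops, one per dimension of its iteration
/// domain, and emit the op's scalar implementation in the innermost body. For
/// illustration (a sketch for a 2-d iteration domain):
///
/// ```
/// scf.for %i = %lb0 to %ub0 step %step0 {
///   scf.for %j = %lb1 to %ub1 step %step1 {
///     <scalar implementation of `op` at (%i, %j)>
///   }
/// }
/// ```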
1859 FailureOr<SmallVector<scf::ForOp>>
1860 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
1861                                      TilingInterface op) {
1862   // TODO: Handle cases where the op has results if needed.
1863   if (op->getNumResults() > 0) {
1864     return rewriter.notifyMatchFailure(
1865         op, "unable to lower operations with return values to loops");
1866   }
1867 
1868   SmallVector<Range> domain = op.getIterationDomain(rewriter);
1869   SmallVector<Value> ivs;
1870   SmallVector<scf::ForOp> loops;
1871   Location loc = op.getLoc();
1872   for (auto loopRange : domain) {
1873     Value offsetVal =
1874         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1875     Value sizeVal =
1876         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1877     Value strideVal =
1878         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1879     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1880                                             strideVal, ValueRange{});
1881     loops.push_back(loop);
1882     ivs.push_back(loop.getInductionVar());
1883     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1884   }
1885   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1886     return failure();
1887   }
1888   return loops;
1889 }
1890