xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 2b2ce50fe843b5b550806a0ab15b06cd5c405d48)
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements tiling using the TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/SCF/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Utils/IndexingUtils.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
27 #include "mlir/Interfaces/TilingInterface.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
30 #include <optional>
31 
32 #define DEBUG_TYPE "tile-using-interface"
33 
34 using namespace mlir;
35 
36 scf::SCFTilingOptions &
37 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
38   assert(!tileSizeComputationFunction && "tile sizes already set");
39   auto tileSizes = llvm::to_vector(ts);
40   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
41     return tileSizes;
42   };
43   return *this;
44 }
45 
46 /// Helper method to adjust the interchange vector to match the iteration
47 /// domain.
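/// For example, with `interchangeVector = [1, 0]` and an iteration domain of
/// size 3, the filled vector is `[1, 0, 2]`; with `interchangeVector =
/// [2, 1, 0, 3]` and size 3, the vector is truncated to `[2, 1, 0]`.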
48 static SmallVector<int64_t>
49 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
50                       size_t iterationDomainSize) {
51   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
52   if (filledVector.size() < iterationDomainSize) {
53     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
54     filledVector.append(range.begin(), range.end());
55   }
56   if (filledVector.size() > iterationDomainSize)
57     filledVector.resize(iterationDomainSize);
58   return filledVector;
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // tileUsingSCF implementation.
63 //===----------------------------------------------------------------------===//
64 
65 // Check if `stride` evenly divides the trip count `size - offset`.
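// For example, offset = 0, size = 128, stride = 32 divides evenly, whereas
// offset = 0, size = 130, stride = 32 does not.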
66 static bool tileDividesIterationDomain(Range loopRange) {
67   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
68   if (!offsetAsInt)
69     return false;
70   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
71   if (!sizeAsInt)
72     return false;
73   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
74   if (!strideAsInt)
75     return false;
76   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
77 }
78 
79 /// Returns the bounded tile size given the current `iv`, `loopRange` and
80 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
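/// When the bound is required, this materializes IR roughly of the form below
/// (an illustrative sketch, where `%iv`, `%ts` and `%ub` stand for the
/// induction variable, the tile size and the loop upper bound):
///
/// ```mlir
/// %sz = affine.min affine_map<(d0)[s0, s1] -> (s0, s1 - d0)>(%iv)[%ts, %ub]
/// ```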
81 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
82                                        Range loopRange, Value iv,
83                                        OpFoldResult tileSize) {
84   std::optional<int64_t> ts = getConstantIntValue(tileSize);
85   if (ts && ts.value() == 1)
86     return tileSize;
87 
88   if (tileDividesIterationDomain(
89           Range{loopRange.offset, loopRange.size, tileSize}))
90     return tileSize;
91 
92   // The tile size to use (to avoid out-of-bounds accesses) is the minimum of
93   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
94   // tiled loop.
95   AffineExpr s0, s1, d0;
96   bindDims(b.getContext(), d0);
97   bindSymbols(b.getContext(), s0, s1);
98   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
99   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
100   return affine::makeComposedFoldedAffineMin(
101       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
102 }
103 
104 /// A function that allows returning additional yielded values during
105 /// `yieldTiledValuesAndReplaceLoop`.
106 /// - `ivs` induction variables for the loops.
107 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
108 /// - `tiledValues` the tiled values to return. Must be of the same size as
109 ///   `newBbArgs`; each element of this array is inserted into the corresponding
110 ///   element in `newBbArgs`.
111 /// - `resultOffsets` is of the same size as `tiledValues` and represents
112 ///   the offsets to use when inserting corresponding element from `tiledValues`
113 ///   into the element from `newBbArgs`.
114 /// - `resultSizes` is of the same size as `tiledValues` and represents
115 ///   the size of the corresponding element from `tiledValues` inserted into
116 ///   the element from `newBbArgs`.
117 /// If the method returns `failure()`, it is expected to clean up any
118 /// operations it inserted.
119 using YieldTiledValuesFn = std::function<LogicalResult(
120     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
121     SmallVector<Value> &tiledValues,
122     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
123     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
124 
125 /// Clones the operation and updates the destination if the operation
126 /// implements the `DestinationStyleOpInterface`.
127 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
128                                                   Operation *op,
129                                                   ValueRange newDestArgs) {
130   Operation *clonedOp = rewriter.clone(*op);
131   if (newDestArgs.empty())
132     return clonedOp;
133   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
134     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
135   return clonedOp;
136 }
137 
138 /// Generate the tile-loop nest using the `scf.for` operation.
139 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
140 /// - `tileSizes` is the tile sizes to use. A zero tile size means the
141 ///   corresponding loop is not tiled.
142 /// - `destinationTensors` are the init values to use for the outermost loop.
143 /// - `yieldTiledValuesFn` is called to generate the body of the innermost
144 ///    loop.
145 /// - `loops` is an in-out parameter into which the generated loops are
146 ///    populated.
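/// As an illustrative sketch, tiling a single dimension produces IR of the
/// form below (names and types are placeholders):
///
/// ```mlir
/// %0 = scf.for %iv = %lb to %ub step %ts iter_args(%arg = %init) -> (tensor<?x?xf32>) {
///   // ... tiled body producing %tile ...
///   %1 = tensor.insert_slice %tile into %arg[%iv, 0] [%sz, %n] [1, 1]
///       : tensor<?x?xf32> into tensor<?x?xf32>
///   scf.yield %1 : tensor<?x?xf32>
/// }
/// ```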
147 static LogicalResult generateLoopNestUsingForOp(
148     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
149     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
150     YieldTiledValuesFn yieldTiledValuesFn,
151     SmallVector<LoopLikeOpInterface> &loops) {
152   assert(!loopRanges.empty() && "unexpected empty loop ranges");
153   assert(loopRanges.size() == tileSizes.size() &&
154          "expected as many tile sizes as loop ranges");
155   OpBuilder::InsertionGuard guard(rewriter);
156   SmallVector<Value> ivs;
157 
158   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
159     // No loop is generated if the tile size is zero; that dimension is left
160     // untiled.
161     if (isConstantIntValue(tileSize, 0))
162       continue;
163 
164     Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
165     Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
166     Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
167     auto loop =
168         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
169                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
170                                        Value iv, ValueRange /*iterArgs*/) {});
171     loops.push_back(loop);
172     ivs.push_back(loop.getInductionVar());
173     rewriter.setInsertionPointToEnd(loop.getBody());
174     destinationTensors = loop.getRegionIterArgs();
175   }
176 
177   SmallVector<Value> tiledResults;
178   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
179   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
180                                 tiledResults, resultOffsets, resultSizes))) {
181     return rewriter.notifyMatchFailure(
182         loc, "failed to generate inner tile loop body");
183   }
184   if (loops.empty())
185     return success();
186 
187   assert(tiledResults.size() == destinationTensors.size() &&
188          "Number of results of body should be equal to number of iter args");
189 
190   // Yield all the results of the tiled operation.
191   SmallVector<Value> yieldedValues;
192   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
193        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
194                        resultSizes)) {
195     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
196                                            rewriter.getIndexAttr(1));
197     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
198         loc, tiledValue, destinationTensor, resultOffset, resultSize,
199         resultStride);
200     yieldedValues.push_back(insertSlice);
201   }
202   rewriter.create<scf::YieldOp>(loc, yieldedValues);
203 
204   // Add the scf.yield operations for all the outer loops.
205   for (auto [outerLoop, innerLoop] :
206        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
207                        MutableArrayRef(loops).drop_front())) {
208     rewriter.setInsertionPointToEnd(
209         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
210     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
211   }
212   return success();
213 }
214 
215 /// Generate the tile-loop nest using the `scf.forall` operation.
216 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
217 /// - `tileSizes` is the tile sizes to use. A zero tile size means the
218 ///   corresponding loop is not tiled.
219 /// - `destinationTensors` are the init values to use for the outermost loop.
220 /// - `mappingVector` is the mapping attributes to use for loop construction.
221 ///   Can be empty.
222 /// - `tiledBodyFn` is called to generate the body of the innermost
223 ///    loop.
224 /// - `loops` is an in-out parameter into which the generated loops are
225 ///    populated.
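/// As an illustrative sketch, the generated loop has the form below (names and
/// types are placeholders):
///
/// ```mlir
/// %0 = scf.forall (%iv) = (%lb) to (%ub) step (%ts)
///     shared_outs(%arg = %init) -> (tensor<?x?xf32>) {
///   // ... tiled body producing %tile ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %arg[%iv, 0] [%sz, %n] [1, 1]
///         : tensor<?x?xf32> into tensor<?x?xf32>
///   }
/// }
/// ```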
226 static LogicalResult generateLoopNestUsingForallOp(
227     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
228     ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
229     ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
230     SmallVector<LoopLikeOpInterface> &loops) {
231   SmallVector<OpFoldResult> lbs, ubs, steps;
232   assert(!loopRanges.empty() && "unexpected empty loop ranges");
233   assert(loopRanges.size() == tileSizes.size() &&
234          "expected as many tile sizes as loop ranges");
235   OpBuilder::InsertionGuard guard(rewriter);
236   SmallVector<OpFoldResult> offsets(loopRanges.size()),
237       sizes(loopRanges.size());
238 
239   for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
240     if (isConstantIntValue(tileSize, 0))
241       continue;
242     lbs.push_back(loopRange.offset);
243     ubs.push_back(loopRange.size);
244     steps.push_back(tileSize);
245   }
246   assert(!lbs.empty() && "Expected at least one loop range");
247 
248   std::optional<ArrayAttr> mappingAttr;
249   if (!mappingVector.empty())
250     mappingAttr = rewriter.getArrayAttr(mappingVector);
251 
252   auto forallOp = rewriter.create<scf::ForallOp>(
253       loc, lbs, ubs, steps, destinationTensors, mappingAttr);
254   loops.push_back(forallOp);
255 
256   rewriter.setInsertionPoint(forallOp.getTerminator());
257   destinationTensors = forallOp.getRegionOutArgs();
258 
259   SmallVector<Value> tiledResults;
260   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
261   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
262                          destinationTensors, tiledResults, resultOffsets,
263                          resultSizes)))
264     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
265 
266   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
267   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
268        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
269                        resultSizes)) {
270     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
271                                            rewriter.getIndexAttr(1));
272 
273     rewriter.create<tensor::ParallelInsertSliceOp>(
274         loc, tiledValue, destinationTensor, resultOffset, resultSize,
275         resultStride);
276   }
277   return success();
278 }
279 
280 /// Generate the tile-loop nest using the loop construct specified in `options`.
281 /// - `options`: Tiling options specified.
282 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
283 /// - `tileSizes` is the tile sizes to use. A zero tile size means the
284 ///   corresponding loop is not tiled.
285 /// - `destinationTensors` are the init values to use for the outermost loop.
286 /// - `tiledBodyFn` is called to generate the body of the innermost
287 ///    loop.
288 /// - `loops` is an in-out parameter into which the generated loops are
289 ///    populated.
290 static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
291                                       const scf::SCFTilingOptions &options,
292                                       ArrayRef<Range> loopRanges,
293                                       ArrayRef<OpFoldResult> tileSizes,
294                                       ValueRange destinationTensors,
295                                       YieldTiledValuesFn tiledBodyFn,
296                                       SmallVector<LoopLikeOpInterface> &loops) {
297   // If the tile sizes are all zero, no loops are generated. Just call the
298   // callback function to handle the untiled case.
299   if (llvm::all_of(tileSizes, isZeroIndex)) {
300     SmallVector<Value> tiledResults;
301     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
302     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
303                        tiledResults, resultOffsets, resultSizes);
304   }
305   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
306     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
307                                       destinationTensors, tiledBodyFn, loops);
308   }
309   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
310     return generateLoopNestUsingForallOp(
311         rewriter, loc, loopRanges, tileSizes, options.mappingVector,
312         destinationTensors, tiledBodyFn, loops);
313   }
314   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
315 }
316 
317 /// Append the specified additional `newInitOperands` to the loop's
318 /// existing `init` operands (or similar), and replace `loopOp` with
319 /// the new loop that has the additional init operands. The loop body of
320 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
321 /// is called to get the new tiled values returned, and the offset
322 /// and sizes at which the tiled value is inserted into the
323 /// new region iter_args that correspond to the newly added init operands.
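/// As an illustrative sketch, appending one new init operand to an `scf.for`
/// turns
///
/// ```mlir
/// %0 = scf.for ... iter_args(%arg0 = %init0) -> (tensor<...>) { ... }
/// ```
///
/// into
///
/// ```mlir
/// %0:2 = scf.for ... iter_args(%arg0 = %init0, %arg1 = %new_init)
///     -> (tensor<...>, tensor<...>) {
///   ...
///   %t = tensor.insert_slice %tiled into %arg1[...] [...] [...] : ...
///   scf.yield %yielded, %t : tensor<...>, tensor<...>
/// }
/// ```
///
/// where uses of the original loop are replaced by the leading results of the
/// new loop.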
324 template <typename LoopType>
325 FailureOr<LoopLikeOpInterface>
326 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
327                                ValueRange newInitOperands,
328                                YieldTiledValuesFn yieldTiledValuesFn) {
329   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
330 }
331 
332 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
333 template <>
334 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
335     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
336     YieldTiledValuesFn yieldTiledValuesFn) {
337   OpBuilder::InsertionGuard g(rewriter);
338   Location loc = loopOp.getLoc();
339   rewriter.setInsertionPoint(loopOp);
340 
341   auto inits = llvm::to_vector(loopOp.getInitArgs());
342   inits.append(newInitOperands.begin(), newInitOperands.end());
343   auto newLoop = rewriter.create<scf::ForOp>(
344       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
345       inits, [](OpBuilder &, Location, Value, ValueRange) {});
346 
347   // Move the loop body to the new op.
348   Block *loopBody = loopOp.getBody();
349   Block *newLoopBody = newLoop.getBody();
350   rewriter.mergeBlocks(
351       loopBody, newLoopBody,
352       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
353 
354   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
355   rewriter.setInsertionPoint(yieldOp);
356 
357   SmallVector<Value> tiledValues;
358   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
359   ValueRange newRegionIterArgs =
360       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
361   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
362                                 newRegionIterArgs, tiledValues, resultOffsets,
363                                 resultSizes))) {
364     rewriter.eraseOp(newLoop);
365     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
366   }
367 
368   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
369   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
370        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
371                        resultSizes)) {
372     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
373                                            rewriter.getIndexAttr(1));
374     Value insert = rewriter.create<tensor::InsertSliceOp>(
375         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
376         resultStride);
377     newYieldValues.push_back(insert);
378   }
379 
380   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
381   rewriter.replaceOp(loopOp,
382                      newLoop->getResults().take_front(loopOp.getNumResults()));
383   return cast<LoopLikeOpInterface>(newLoop.getOperation());
384 }
385 
386 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
387 template <>
388 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
389     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
390     YieldTiledValuesFn yieldTiledValuesFn) {
391   OpBuilder::InsertionGuard g(rewriter);
392   Location loc = loopOp.getLoc();
393   rewriter.setInsertionPoint(loopOp);
394   auto inits = llvm::to_vector(loopOp.getOutputs());
395   inits.append(newInitOperands.begin(), newInitOperands.end());
396   auto newLoop = rewriter.create<scf::ForallOp>(
397       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
398       loopOp.getMixedStep(), inits, loopOp.getMapping(),
399       [](OpBuilder &, Location, ValueRange) {});
400 
401   // Move the body of the old loop to the newly created op.
402   Block *loopBody = loopOp.getBody();
403   Block *newLoopBody = newLoop.getBody();
404   rewriter.mergeBlocks(
405       loopBody, newLoopBody,
406       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
407 
408   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
409   rewriter.setInsertionPoint(terminator);
410   SmallVector<Value> tiledValues;
411   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
412   ValueRange regionIterArgs =
413       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
414   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
415                                 regionIterArgs, tiledValues, resultOffsets,
416                                 resultSizes))) {
417     rewriter.eraseOp(newLoop);
418     return rewriter.notifyMatchFailure(loopOp,
419                                        "failed to get yielded tiled values");
420   }
421 
422   // Update the terminator.
423   rewriter.setInsertionPointToEnd(terminator.getBody());
424 
425   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
426            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
427     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
428                                            rewriter.getIndexAttr(1));
429     rewriter.create<tensor::ParallelInsertSliceOp>(
430         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
431         resultStride);
432   }
433 
434   rewriter.replaceOp(loopOp,
435                      newLoop->getResults().take_front(loopOp.getNumResults()));
436   return cast<LoopLikeOpInterface>(newLoop.getOperation());
437 }
438 
439 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
440 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
441 /// supported loop type.
442 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
443     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
444     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
445   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
446              loopLikeOp.getOperation())
447       .Case<scf::ForOp, scf::ForallOp>(
448           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
449             return yieldTiledValuesAndReplaceLoop(
450                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
451           })
452       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
453         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
454       });
455 }
456 
457 /// Method to add new init values to a loop nest. Updates `loops` in-place with
458 /// new loops that use the `newInitValues`.
459 /// The outer loops are updated to yield the new result values of the inner
460 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
461 /// to get the additional values to yield from the innermost loop.
462 static LogicalResult addInitOperandsToLoopNest(
463     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
464     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
465   SmallVector<scf::ForOp> newLoops;
466   if (loops.empty())
467     return success();
468   OpBuilder::InsertionGuard g(rewriter);
469   rewriter.setInsertionPoint(loops.front());
470 
471   SmallVector<Value> ivs;
472   for (auto &loop : loops.drop_back()) {
473     rewriter.setInsertionPoint(loop);
474 
475     // If `loops.size() > 1`, we assume that `scf.for` is used for the loops.
476     auto forLoop = cast<scf::ForOp>(loop.getOperation());
477 
478     // Create a new loop with the new init values for this loop.
479     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
480     newInits.append(newInitValues.begin(), newInitValues.end());
481     auto newLoop = rewriter.create<scf::ForOp>(
482         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
483         forLoop.getStep(), newInits,
484         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
485 
486     // Merge the body of the new loop with the body of the old loops.
487     SmallVector<Value> sourceBlockArgs;
488     sourceBlockArgs.push_back(newLoop.getInductionVar());
489     auto newRegionIterArgs = newLoop.getRegionIterArgs();
490     sourceBlockArgs.append(
491         newRegionIterArgs.begin(),
492         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
493     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
494     rewriter.replaceOp(
495         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
496     loop = newLoop;
497     ivs.push_back(newLoop.getInductionVar());
498     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
499   }
500 
501   // Update the loop body of the innermost loop to get new yield values.
502   LoopLikeOpInterface innerMostLoop = loops.back();
503   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
504       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
505                                      getNewTiledYieldsFn);
506 
507   if (failed(newInnerMostLoop))
508     return innerMostLoop.emitOpError("failed to return additional yields");
509   loops.back() = newInnerMostLoop.value();
510 
511   // Make all loops except the innermost yield the additional values returned
512   // by their inner loop.
513   for (auto [outerLoop, innerLoop] :
514        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
515     // Again assume that all the outer loops are scf.for operations.
516     auto outerForLoop = cast<scf::ForOp>(outerLoop);
517     auto outerLoopYield =
518         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
519     SmallVector<Value> newYields =
520         llvm::to_vector(outerLoopYield.getOperands());
521     ValueRange additionalYields =
522         innerLoop->getResults().take_back(newInitValues.size());
523     newYields.append(additionalYields.begin(), additionalYields.end());
524     rewriter.setInsertionPoint(outerLoopYield);
525     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
526   }
527   return success();
528 }
529 
530 /// Implementation of tiling transformation of `op` that implements the
531 /// `TilingInterface` using `scf.for` to iterate over the tiles.
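/// As an illustrative sketch, tiling a `linalg.matmul` with tile sizes
/// `[10, 20, 0]` produces IR along the lines of (types, bounds and exact
/// slice parameters are placeholders):
///
/// ```mlir
/// %0 = scf.for %i = %c0 to %M step %c10 iter_args(%arg0 = %out) -> (tensor<?x?xf32>) {
///   %1 = scf.for %j = %c0 to %N step %c20 iter_args(%arg1 = %arg0) -> (tensor<?x?xf32>) {
///     %lhs = tensor.extract_slice %a[%i, 0] [%ti, %K] [1, 1] : ...
///     %rhs = tensor.extract_slice %b[0, %j] [%K, %tj] [1, 1] : ...
///     %init = tensor.extract_slice %arg1[%i, %j] [%ti, %tj] [1, 1] : ...
///     %t = linalg.matmul ins(%lhs, %rhs : ...) outs(%init : ...) -> tensor<?x?xf32>
///     %2 = tensor.insert_slice %t into %arg1[%i, %j] [%ti, %tj] [1, 1] : ...
///     scf.yield %2 : tensor<?x?xf32>
///   }
///   scf.yield %1 : tensor<?x?xf32>
/// }
/// ```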
532 FailureOr<scf::SCFTilingResult>
533 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
534                         const scf::SCFTilingOptions &options) {
535   OpBuilder::InsertionGuard guard(rewriter);
536   rewriter.setInsertionPointAfter(op);
537 
538   if (!options.tileSizeComputationFunction) {
539     return rewriter.notifyMatchFailure(
540         op, "missing tile size computation function");
541   }
542 
543   // 1. Get the range of the loops that are represented by the operation.
544   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
545   size_t numLoops = iterationDomain.size();
546 
547   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
548   // skips tiling a particular dimension. This convention is significantly
549   // simpler to handle than adjusting affine maps to account for missing
550   // dimensions.
551   SmallVector<OpFoldResult> tileSizes =
552       options.tileSizeComputationFunction(rewriter, op);
553   if (tileSizes.size() < iterationDomain.size()) {
554     auto zero = rewriter.getIndexAttr(0);
555     tileSizes.append(numLoops - tileSizes.size(), zero);
556   }
557 
558   // 3. If there is an interchange specified, permute the iteration domain and
559   // the tile sizes.
560   SmallVector<int64_t> interchangeVector;
561   if (!options.interchangeVector.empty()) {
562     interchangeVector = fillInterchangeVector(options.interchangeVector,
563                                               iterationDomain.size());
564   }
565   if (!interchangeVector.empty()) {
566     if (!isPermutationVector(interchangeVector)) {
567       return rewriter.notifyMatchFailure(
568           op, "invalid interchange vector, not a permutation of the entire "
569               "iteration space");
570     }
571 
572     applyPermutationToVector(iterationDomain, interchangeVector);
573     applyPermutationToVector(tileSizes, interchangeVector);
574   }
575 
576   FailureOr<TilingResult> tilingResult;
577   // 4. Define the lambda function used later to generate the body of the
578   // innermost tiled loop.
579   YieldTiledValuesFn innerYieldTiledValuesFn =
580       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
581           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
582           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
583           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
584       -> LogicalResult {
585     // 4a. Compute the `offsets` and `sizes` to use for tiling.
586     SmallVector<OpFoldResult> offsets, sizes;
587     {
588       int materializedLoopNum = 0;
589       for (auto [tileSize, loopRange] :
590            llvm::zip_equal(tileSizes, iterationDomain)) {
591         if (isConstantIntValue(tileSize, 0)) {
592           offsets.push_back(loopRange.offset);
593           sizes.push_back(loopRange.size);
594           continue;
595         }
596         Value iv = ivs[materializedLoopNum++];
597         offsets.push_back(iv);
598         sizes.push_back(
599             getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
600       }
601     }
602 
603     // 4b. If interchange was provided, apply the inverse of the interchange
604     //     to get back the offsets/sizes in the original (untiled) order.
605     if (!interchangeVector.empty()) {
606       auto inversePermutation = invertPermutationVector(interchangeVector);
607       applyPermutationToVector(offsets, inversePermutation);
608       applyPermutationToVector(sizes, inversePermutation);
609     }
610 
611     // 5. Generate the tiled implementation within the innermost loop.
612 
613     // 5a. Clone the operation within the loop body.
614     auto clonedOp = cast<TilingInterface>(
615         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
616 
617     // 5b. Early return the cloned op if tiling is not happening. We cannot
618     // return the original op because it could lead to
619     // `rewriter.replaceOp(op, op->getResults())`, which would crash its users.
620     if (llvm::all_of(tileSizes, isZeroIndex)) {
621       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
622       tilingResult =
623           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
624       return success();
625     }
626 
627     // 5c. Tile the cloned operation.
628     tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
629     if (failed(tilingResult)) {
630       rewriter.eraseOp(clonedOp);
631       return op.emitOpError("failed to tile operation");
632     }
633 
634     // 5d. Delete the cloned operation.
635     rewriter.eraseOp(clonedOp);
636 
637     // 5e. Compute the offsets at which the result values are to be inserted
638     //     back into their destinations.
639     for (auto [index, tiledValue] :
640          llvm::enumerate(tilingResult->tiledValues)) {
641       tiledResults.push_back(tiledValue);
642       SmallVector<OpFoldResult> resultOffset, resultSize;
643       if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
644                                           resultOffset, resultSize))) {
645         for (auto op : tilingResult->tiledOps) {
646           rewriter.eraseOp(op);
647         }
648         return rewriter.notifyMatchFailure(
649             op, "failed to get slice of result produced");
650       }
651       resultOffsets.emplace_back(std::move(resultOffset));
652       resultSizes.emplace_back(std::move(resultSize));
653     }
654 
655     return success();
656   };
657 
658   // 6. Find the destination tensors to use for the operation.
659   SmallVector<Value> destinationTensors;
660   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
661                                              destinationTensors))) {
662     return rewriter.notifyMatchFailure(op,
663                                        "unable to create destination tensors");
664   }
665 
666   // 7. Generate the tiled loop nest using the callback defined above.
667   SmallVector<LoopLikeOpInterface> loops;
668   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
669                               tileSizes, destinationTensors,
670                               innerYieldTiledValuesFn, loops)))
671     return op.emitOpError("failed to generate tiling loops");
672   assert(succeeded(tilingResult) &&
673          "expected tiling result to be computed after loop generation");
674 
675   // If loops are empty, the tiled op is used as the replacement for the untiled
676   // op.
677   if (loops.empty()) {
678     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
679                                 tilingResult->tiledValues};
680   }
681 
682   SmallVector<Value> replacements = llvm::map_to_vector(
683       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
684   return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
685 }
686 
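/// Tile a reduction by tiling its reduction dimension into a serial loop that
/// accumulates partial results into an expanded init tensor, then merge the
/// partial results after the loop. As an illustrative sketch only (a 1-D
/// reduction over `%in` tiled by 8; names, shapes and the exact partial
/// computation are placeholders):
///
/// ```mlir
/// %init = ...  // expanded init holding one partial result per lane
/// %partial = scf.for %iv = %c0 to %K step %c8 iter_args(%arg = %init) -> (tensor<8xf32>) {
///   %slice = tensor.extract_slice %in[%iv] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
///   %acc = ...  // elementwise partial reduction of %slice into %arg
///   scf.yield %acc : tensor<8xf32>
/// }
/// %result = ...  // merge: final reduction of %partial
/// ```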
687 FailureOr<scf::SCFReductionTilingResult>
688 mlir::scf::tileReductionUsingScf(RewriterBase &b,
689                                  PartialReductionOpInterface op,
690                                  ArrayRef<OpFoldResult> tileSizes) {
691   Location loc = op.getLoc();
692   // Ops implementing PartialReductionOpInterface are expected to implement
693   // TilingInterface.
694   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
695   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
696   auto tileSizesVector = llvm::to_vector(tileSizes);
697   if (tileSizesVector.size() < iterationDomain.size()) {
698     auto zero = b.getIndexAttr(0);
699     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
700                            zero);
701   }
702   SmallVector<utils::IteratorType> iterators =
703       tilingInterfaceOp.getLoopIteratorTypes();
704 
705   SmallVector<int> reductionDims;
706   for (auto [idx, iteratorType] :
707        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
708     if (iteratorType == utils::IteratorType::reduction)
709       reductionDims.push_back(idx);
710   }
711 
712   // 2. Create the initial tensor values.
713   FailureOr<SmallVector<Value>> maybeInitTensors =
714       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
715                                                   reductionDims);
716   if (failed(maybeInitTensors)) {
717     return b.notifyMatchFailure(op, "Failed to create initial tensors.");
718   }
719   SmallVector<Value> &initTensors = maybeInitTensors.value();
720 
721   // 3. Define the callback to use for generating the innermost tile loop body.
722   Operation *parallelOp = nullptr;
723   auto innerYieldTiledValuesFn =
724       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
725           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
726           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
727           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
728       -> LogicalResult {
729     SmallVector<OpFoldResult> offsets, sizes;
730     {
731       int materializedLoopNum = 0;
732       for (auto [tileSize, loopRange] :
733            llvm::zip_equal(tileSizesVector, iterationDomain)) {
734         if (isConstantIntValue(tileSize, 0)) {
735           offsets.push_back(loopRange.offset);
736           sizes.push_back(loopRange.size);
737           continue;
738         }
739         Value iv = ivs[materializedLoopNum++];
740         offsets.push_back(iv);
741         sizes.push_back(
742             getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
743       }
744     }
745 
746     // 4a. Clone the operation.
747     auto clonedOp = cast<PartialReductionOpInterface>(
748         cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
749 
750     // 4b. Tile the cloned operation.
751     parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
752                                                  offsets, sizes, reductionDims);
753     // 4c. Delete the cloned operation.
754     b.eraseOp(clonedOp);
755 
756     tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
757     // 4d. Compute the offsets and sizes needed to insert the tiled result
758     //     back into the destination before yielding the destination.
759     for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
760       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
761       resultOffsets.emplace_back(std::move(outOffsets));
762 
763       SmallVector<OpFoldResult> outSizes;
764       for (size_t i = 0; i < offsets.size(); i++) {
765         outSizes.push_back(
766             tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
767       }
768       resultSizes.emplace_back(std::move(outSizes));
769     }
770     return success();
771   };
772 
773   // 5. Generate the tiled implementation using the destination tensors.
774   SmallVector<LoopLikeOpInterface> loops;
775   scf::SCFTilingOptions options;
776   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
777   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
778                               initTensors, innerYieldTiledValuesFn, loops)))
779     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
780 
781   SmallVector<Value> replacements = llvm::map_to_vector(
782       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
783 
784   // 6. Apply the merge reduction to combine all the partial values.
785   b.setInsertionPointAfter(*loops.begin());
786   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
787   b.replaceOp(op, mergeOp->getResults());
788 
789   SCFReductionTilingResult results;
790   results.initialValues = initTensors;
791   results.loops = loops;
792   results.parallelTiledOp = parallelOp;
793   results.mergeOp = mergeOp;
794   return results;
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // tileConsumerAndFuseProducersUsingSCF implementation.
799 //===----------------------------------------------------------------------===//
800 
801 /// Return the untiled producer whose slice is used in a tiled consumer. The
802 /// method traverses the tile loop nest (`loops`) if needed, and returns the
803 /// `iter_args` of the outermost loop that is encountered. Traversing the iter_args
804 /// indicates that this is a destination operand of the consumer. If there was
805 /// no loop traversal needed, the second value of the returned tuple is empty.
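/// For example (sketch), starting from the source `%arg1` of the slice below,
/// the traversal walks `%arg1 -> %arg0 -> %0` and returns the `OpResult` `%0`
/// together with the `iter_args` operand of the outermost loop:
///
/// ```mlir
/// %0 = linalg.fill ...
/// scf.for ... iter_args(%arg0 = %0) {
///   scf.for ... iter_args(%arg1 = %arg0) {
///     %s = tensor.extract_slice %arg1 [...] [...] [...]
///     ...
///   }
/// }
/// ```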
806 static std::tuple<OpResult, std::optional<OpOperand *>>
807 getUntiledProducerFromSliceSource(OpOperand *source,
808                                   ArrayRef<LoopLikeOpInterface> loops) {
809   std::optional<OpOperand *> destinationIterArg;
810   auto loopIt = loops.rbegin();
811   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
812     auto loop = *loopIt;
813     if (iterArg.getOwner()->getParentOp() != loop)
814       break;
815     source = loop.getTiedLoopInit(iterArg);
816     loopIt++;
817   }
818   if (loopIt == loops.rend())
819     destinationIterArg = source;
820   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
821 }
822 
823 /// Implementation of fusing producer of a single slice by computing the
824 /// slice of the producer in-place.
825 std::optional<scf::SCFFuseProducerOfSliceResult>
826 mlir::scf::tileAndFuseProducerOfSlice(
827     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
828     MutableArrayRef<LoopLikeOpInterface> loops) {
829   // 1. Get the producer of the source (potentially walking through
830   // `iter_args` of nested `scf.for`)
831   auto [fusableProducer, destinationInitArg] =
832       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
833                                         loops);
834   if (!fusableProducer)
835     return std::nullopt;
836   unsigned resultNumber = fusableProducer.getResultNumber();
837 
838   OpBuilder::InsertionGuard g(rewriter);
839   rewriter.setInsertionPoint(candidateSliceOp);
840 
841   // 2. Clone the fused producer
842   // 2a. Compute the destination operands to use for the cloned operation.
843   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
844   Operation *fusableProducerOp = fusableProducer.getOwner();
845   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
846       failed(tensor::getOrCreateDestinations(
847           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
848           origDestinationTensors)))
849     return std::nullopt;
850 
851   clonedOpDestinationTensors = origDestinationTensors;
852   if (destinationInitArg &&
853       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
854     // 2b. If the producer is also destination style, then to maintain the
855     // destination passing style, update the destination of the producer to be
856     // the source of the slice.
857     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
858   }
859   // 2c. Clone the fused producer.
860   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
861       rewriter, fusableProducerOp, clonedOpDestinationTensors);
862   // 2d. Update the source of the candidateSlice to be the cloned producer.
863   //     It is easier to clone the slice with a different source, since
864   //     replacement and DCE of the cloned ops become simpler.
865   SmallVector<Value> candidateSliceOpOperands =
866       llvm::to_vector(candidateSliceOp->getOperands());
867   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
868   tensor::ExtractSliceOp clonedCandidateSliceOp =
869       mlir::clone(rewriter, candidateSliceOp,
870                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
871 
872   // 3. Generate the tiled implementation of the producer of the source
873   FailureOr<TilingResult> tileAndFuseResult =
874       tensor::replaceExtractSliceWithTiledProducer(
875           rewriter, clonedCandidateSliceOp,
876           clonedProducerOp->getResult(resultNumber));
877   if (failed(tileAndFuseResult))
878     return std::nullopt;
879   // Note: Do not delete the candidateSliceOp, since it's passed in from the
880   // caller.
881   rewriter.replaceAllUsesWith(candidateSliceOp,
882                               tileAndFuseResult->tiledValues[0]);
883   rewriter.eraseOp(clonedCandidateSliceOp);
884   rewriter.eraseOp(clonedProducerOp);
885 
886   // 4. If the slice is for a destination operand, for example,
887   //
888   // ```mlir
889   // %0 = linalg.init
890   // %1 = linalg.fill .. outs(%0 : )
891   // %2 = scf.for .. iter_args(%arg0 = %1) {
892   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
893   //     %4 = tensor.extract_slice %arg1 [..]
894   //     .. = linalg.matmul .. outs(%4 : )
895   //   }
896   // }
897   // ```
898   //
899   // the IR is currently
900   //
901   // ```
902   // %0 = linalg.init
903   // %1 = linalg.fill
904   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
905   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
906   //     %4 = tensor.extract_slice %arg1[..]
907   //     %5 = linalg.fill .. outs(%4 : )
908   //     .. = linalg.matmul .. outs(%5 : )
909   //   }
910   // }
911   // ```
912   //
913   // The untiled `linalg.fill` is still used as the `init_value` since it
914   // was originally a destination operand of the untiled `linalg.matmul`.
915   // When fusing an operand that is a destination operand, the iter_arg of
916   // the outermost loop should be changed to use the destination of the
917   // fused operation. With this, the IR will be:
918   //
919   // ```
920   // %0 = linalg.init
921   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
922   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
923   //     %3 = tensor.extract_slice %arg1[..]
924   //     %4 = linalg.fill .. outs(%3 : )
925   //     .. = linalg.matmul .. outs(%4 : )
926   //   }
927   // }
928   // ```
929   if (destinationInitArg &&
930       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
931     loops.front()
932         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
933         .set(origDestinationTensors[resultNumber]);
934   }
935   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
936                                            tileAndFuseResult->tiledValues[0],
937                                            tileAndFuseResult->tiledOps};
938 }
939 
940 /// Reconstruct the fused producer from within the tiled-and-fused code.
941 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
942     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
943     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
944     MutableArrayRef<LoopLikeOpInterface> loops) {
945   if (loops.empty())
946     return success();
947 
948   OpResult fusableProducer = fusedProducerInfo.origProducer;
949   Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
950   FailureOr<Value> initValue = tensor::getOrCreateDestination(
951       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
952   if (succeeded(initValue)) {
953 
954     YieldTiledValuesFn newYieldValuesFn =
955         [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
956             ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
957             SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
958             SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
959         -> LogicalResult {
960       OpBuilder::InsertionGuard g(innerRewriter);
961       if (auto tiledDestStyleOp =
962               tiledAndFusedProducer
963                   .getDefiningOp<DestinationStyleOpInterface>()) {
964         rewriter.setInsertionPoint(tiledDestStyleOp);
965         Value newRegionArg = newRegionIterArgs.back();
966         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
967             sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
968             sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
969         unsigned resultNumber = fusableProducer.getResultNumber();
970         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
971           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
972         });
973       }
974       Block *block = rewriter.getInsertionPoint()->getBlock();
975       rewriter.setInsertionPoint(block->getTerminator());
976       tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
977       tiledOffset.emplace_back(sliceOp.getMixedOffsets());
978       tiledSizes.emplace_back(sliceOp.getMixedSizes());
979       return success();
980     };
981 
982     return addInitOperandsToLoopNest(rewriter, loops,
983                                      SmallVector<Value>{initValue.value()},
984                                      newYieldValuesFn);
985   }
986   return success();
987 }
988 
989 /// Implementation of tiling the consumer and greedily fusing producers.
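/// As an illustrative sketch, tiling the `linalg.matmul` below and greedily
/// fusing its producers pulls the `linalg.fill` into the tiled loop nest
/// (names, types and slice parameters are placeholders):
///
/// ```mlir
/// // Before:
/// %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %mm = linalg.matmul ins(%a, %b : ...) outs(%fill : ...) -> tensor<?x?xf32>
///
/// // After (sketch):
/// %0 = scf.for ... iter_args(%arg = %empty) -> (tensor<?x?xf32>) {
///   %s = tensor.extract_slice %arg[...] [...] [1, 1] : ...
///   %f = linalg.fill ins(%cst : f32) outs(%s : ...) -> ...
///   %m = linalg.matmul ins(...) outs(%f : ...) -> ...
///   %i = tensor.insert_slice %m into %arg[...] [...] [1, 1] : ...
///   scf.yield %i : tensor<?x?xf32>
/// }
/// ```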
990 FailureOr<scf::SCFTileAndFuseResult>
991 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
992     RewriterBase &rewriter, TilingInterface consumer,
993     const scf::SCFTileAndFuseOptions &options) {
994   // This transformation is only valid for ops that return values (i.e. not
995   // valid to use with operations that have memref operands).
996   if (!consumer->getNumResults()) {
997     return rewriter.notifyMatchFailure(
998         consumer, "invalid pattern for op with no results");
999   }
1000 
1001   // 1. First tile the consumer.
1002   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1003   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1004 
1005   FailureOr<scf::SCFTilingResult> tilingResult =
1006       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1007 
1008   if (failed(tilingResult))
1009     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1010   for (auto *tiledOp : tilingResult->tiledOps)
1011     tiledAndFusedOps.insert(tiledOp);
1012 
1013   // If there are no loops generated, fusion is immaterial.
1014   auto &loops = tilingResult->loops;
1015   if (loops.empty()) {
1016     DenseMap<Value, Value> replacements;
1017     for (auto [origVal, replacement] :
1018          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1019       replacements[origVal] = replacement;
1020     }
1021     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1022                                      replacements};
1023   }
1024 
1025   // To keep track of replacements, for now just record the map from the
1026   // original untiled value to the result number of the for loop. Since the
1027   // loop may be replaced during fusion, keeping the value directly won't work.
1028   DenseMap<Value, size_t> origValToResultNumber;
1029   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1030     origValToResultNumber[result] = index;
1031   }
1032 
1033   // 2. Typically, the operands of the tiled operation are slices of the
1034   //    operands of the untiled operation. These are expressed in IR using
1035   //    `tensor.extract_slice` operations with source being the operands of the
1036   //    untiled operation. Create a worklist of these `tensor.extract_slice`
1037   //    operations. If the producers of the source of the `tensor.extract_slice`
1038   //    can be tiled such that the tiled value is generated in-place, that
1039   //    effectively tiles + fuses the operations.
1040   auto addCandidateSlices = [](Operation *fusedOp,
1041                                std::deque<tensor::ExtractSliceOp> &candidates) {
1042     for (Value operand : fusedOp->getOperands())
1043       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1044         candidates.push_back(sliceOp);
1045   };
1046 
1047   std::deque<tensor::ExtractSliceOp> candidates;
1048   addCandidateSlices(tiledAndFusedOps.back(), candidates);
1049   OpBuilder::InsertionGuard g(rewriter);
1050   while (!candidates.empty()) {
1051     // Traverse the slices in BFS fashion.
1052     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1053     candidates.pop_front();
1054 
1055     // Find the original producer of the slice.
1056     auto [fusableProducer, destinationInitArg] =
1057         getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1058                                           loops);
1059     if (!fusableProducer)
1060       continue;
1061 
1062     auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
1063         candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1064     if (!fuseSlice)
1065       continue;
1066 
1067     // The operands of the fused producer might themselves be slices of
1068     // values produced by operations that implement the `TilingInterface`.
1069     // Add these operations to the worklist.
1070     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1071         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
1072     if (!fusedResult)
1073       continue;
1074 
1075     if (yieldReplacement) {
1076       if (failed(yieldReplacementForFusedProducer(
1077               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1078         return rewriter.notifyMatchFailure(
1079             fusableProducer.getOwner(), "failed to yield replacement value for "
1080                                         "this operation from within the tiled loop");
1081       }
1082       origValToResultNumber[fusableProducer] =
1083           loops.front()->getNumResults() - 1;
1084     }
1085 
1086     if (Operation *tiledAndFusedOp =
1087             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1088       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1089       tiledAndFusedOps.insert(tiledAndFusedOp);
1090       addCandidateSlices(tiledAndFusedOp, candidates);
1091     }
1092   }
1093 
1094   DenseMap<Value, Value> replacements;
1095   for (auto [origVal, resultNumber] : origValToResultNumber) {
1096     replacements[origVal] = loops.front()->getResult(resultNumber);
1097   }
1098 
1099   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1100                                    replacements};
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // tileAndFuseConsumerUsingSCF implementation.
1105 //===----------------------------------------------------------------------===//
1106 
1107 /// A utility function that checks whether the only use of the result of a
1108 /// tensor.insert_slice op is in a scf.yield op.
1109 static LogicalResult
1110 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1111   Value result = candidateSliceOp.getResult();
1112   Value::use_range uses = result.getUses();
1113   if (!llvm::hasSingleElement(uses)) {
1114     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1115     return failure();
1116   }
1117   OpOperand &operandUse = (*uses.begin());
1118   Operation *userOp = operandUse.getOwner();
1119   if (!isa<scf::YieldOp>(userOp)) {
1120     LLVM_DEBUG(llvm::dbgs()
1121                << "Expected scf.yield to be the only user, but got -> "
1122                << (*userOp));
1123     return failure();
1124   }
1125   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1126     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1127                                "be in the same block\n");
1128     return failure();
1129   }
1130   return success();
1131 }
1132 
1133 /// Fetches the OpOperand of the only user (and use) of the value `val` which
1134 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1135 /// failure otherwise.
1136 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1137                                                   Block *containingOpBlock) {
1138   // Step 1. Check that the value has exactly one use.
1139   if (!llvm::hasSingleElement(val.getUses()))
1140     return failure();
1141   // Step 2. Get uses.
1142   OpOperand &operand = (*val.getUses().begin());
1143   Operation *consumerOp = operand.getOwner();
1144   // TODO: We have to initialize the result of the consumer before the scf.for;
1145   //       for now, use DestinationStyleOpInterface to get the result shape from
1146   //       the init. Add support for other ops, e.g. ops with InferTypeOpInterface.
1147   if (!isa<TilingInterface>(consumerOp) ||
1148       !isa<DestinationStyleOpInterface>(consumerOp))
1149     return failure();
1150   if (containingOpBlock != consumerOp->getBlock())
1151     return failure();
1152   return &operand;
1153 }
1154 
1155 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1156 /// tensor.insert_slice. This function makes the following assumptions:
1157 /// 1.  tensor.insert_slice has scf.yield as its only user.
1158 /// 2.  scf.for's corresponding result has only one use.
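/// For example (sketch), given
///
/// ```mlir
/// %0 = scf.for ... iter_args(%arg = %init) -> (tensor<?x?xf32>) {
///   ...
///   %1 = tensor.insert_slice %tile into %arg[...] [...] [1, 1] : ...
///   scf.yield %1 : tensor<?x?xf32>
/// }
/// %2 = linalg.add ins(%0, %x : ...) outs(%out : ...) -> tensor<?x?xf32>
/// ```
///
/// the returned `OpOperand` is the use of `%0` in the untiled consumer `%2`
/// (`linalg.add` is used here purely for illustration).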
1159 static FailureOr<OpOperand *>
1160 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1161   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1162     return failure();
1163   Value sliceResult = candidateSliceOp.getResult();
1164   // Step 1. Fetch the corresponding output.
1165   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1166   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1167   // Step 2. Check that the containing op is scf.for.
1168   Operation *containingOp = candidateSliceOp->getParentOp();
1169   auto forOp = dyn_cast<scf::ForOp>(containingOp);
1170   if (!forOp)
1171     return failure();
1172   Value resultingValue = forOp->getResult(resultNumber);
1173 
1174   return getConsumerFromUses(resultingValue, containingOp->getBlock());
1175 }
1176 
1177 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1178 /// by a tensor.parallel_insert_slice.
1179 static FailureOr<OpOperand *>
1180 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1181   // Step 1. Fetch the corresponding output
1182   Value sliceDest = candidateSliceOp.getDest();
1183   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1184   if (!iterArg)
1185     return failure();
1186   Operation *containingOp = iterArg.getOwner()->getParentOp();
1187   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1188     return failure();
1189   // Step 2. Check that the containing op is scf.forall.
1190   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1191   if (!forallOp)
1192     return failure();
1193   Value resultingValue =
1194       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1195 
1196   return getConsumerFromUses(resultingValue, containingOp->getBlock());
1197 }
1198 
1199 /// This utility checks that the loop satisfies one of the following
1200 /// conditions:
1201 /// 1. It yields exactly one result.
1202 /// 2. The consumer op is its first user and all its other users are in the
1203 /// same block as the consumer op. Since the loop op is cloned right before
1204 /// the consumer op, this ensures that the rewrite does not introduce uses
1205 /// that precede their definition, i.e. no invalid IR is formed.
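///
/// For example (a hypothetical sketch), the following is accepted because
/// `other_user` lives in the same block as `consumer` and comes after it, so
/// cloning the loop right before `consumer` keeps every use dominated by its
/// definition:
///
///   %loop:2 = scf.for ...
///   %a = "consumer"(%loop#0) ...
///   %b = "other_user"(%loop#1) ...
///
/// A user of %loop#1 placed before `consumer` (or in another block) would end
/// up using a result of the cloned loop before the clone is defined, so such
/// cases are rejected.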
1206 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
1207                                             Operation *consumerOp) {
1208   // Check if the loop op yields one result.
1209   if (loopOp->getNumResults() == 1)
1210     return success();
1211   // Check that the consumer op is the first user of the loop op and that all
1212   // other users are in the same block as the consumer op.
1213   Block *parentBlock = consumerOp->getBlock();
1214   for (Operation *userOp : loopOp->getUsers()) {
1215     if (userOp == consumerOp)
1216       continue;
1217     if (parentBlock != userOp->getBlock() ||
1218         !consumerOp->isBeforeInBlock(userOp))
1219       return failure();
1220   }
1221   return success();
1222 }
1223 
1224 /// A utility to fetch an untiled consumer of
1225 /// tensor.insert_slice/tensor.parallel_insert_slice.
1226 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
1227   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1228     return getUntiledConsumerFromSlice(insertSlice);
1229   } else if (auto parallelInsertSlice =
1230                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1231     return getUntiledConsumerFromSlice(parallelInsertSlice);
1232   } else {
1233     return failure();
1234   }
1235 }
1236 
1237 /// After fusing the consumer into the scf.for, rewrite the scf.yield so that
1238 /// it additionally yields the results of the tiled consumer.
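///
/// Schematically (a hedged sketch; names and types are hypothetical), the old
/// terminator
///
///   scf.yield %old0, ..., %oldN : ...
///
/// is replaced with
///
///   %ins = tensor.insert_slice %tiledConsumerResult into %newIterArg
///       [%off0, ...] [%sz0, ...] [1, ...] : ...
///   scf.yield %old0, ..., %oldN, %ins : ...
///
/// with one such tensor.insert_slice per result of the tiled consumer.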
1239 static void
1240 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
1241                       TilingResult &tilingResult,
1242                       ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
1243                       ArrayRef<SmallVector<OpFoldResult>> resultSizes,
1244                       ArrayRef<BlockArgument> bbArgs) {
1245   scf::YieldOp oldTerminatorOp =
1246       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
1247   unsigned totalOldResults = oldTerminatorOp.getResults().size();
1248   unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
1249   SmallVector<Value> newYieldOperands;
1250   newYieldOperands.reserve(totalOldResults + totalTiledResults);
1251   for (auto oldResult : oldTerminatorOp.getResults()) {
1252     newYieldOperands.push_back(oldResult);
1253   }
1254   rewriter.setInsertionPointAfter(oldTerminatorOp);
1255   Location loc = newForOp.getLoc();
1256   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
1257        llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
1258                        resultOffsets, resultSizes)) {
1259     SmallVector<OpFoldResult> strides(resultOffset.size(),
1260                                       rewriter.getIndexAttr(1));
1261     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1262         loc, tiledResult, bbArg, resultOffset, resultSize, strides);
1263     newYieldOperands.push_back(newInsertSliceOp);
1264   }
1265   rewriter.create<scf::YieldOp>(loc, newYieldOperands);
1266   rewriter.eraseOp(oldTerminatorOp);
1267 }
1268 
1269 /// After fusing the consumer into the scf.forall, yield each result of the
1270 /// tiled consumer from within the scf.forall.in_parallel region.
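///
/// Schematically (a hedged sketch; names and types are hypothetical), one
/// tensor.parallel_insert_slice per tiled-consumer result is prepended to the
/// terminator region:
///
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tiledConsumerResult into %newSharedOut
///         [%off0, ...] [%sz0, ...] [1, ...] : ...
///     tensor.parallel_insert_slice %producerTile into %sharedOut ...
///   }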
1271 static void
1272 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
1273                            SmallVector<Value> tiledResults,
1274                            ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
1275                            ArrayRef<SmallVector<OpFoldResult>> resultSizes,
1276                            ArrayRef<BlockArgument> bbArgs) {
1277   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
1278   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
1279   Location firstYieldOpLoc =
1280       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
1281   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
1282        llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
1283     SmallVector<OpFoldResult> strides(resultOffset.size(),
1284                                       rewriter.getIndexAttr(1));
1285     rewriter.create<tensor::ParallelInsertSliceOp>(
1286         firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
1287   }
1288 }
1289 
1290 /// Implementation of fusing the consumer of a single slice into the scf loop
1291 /// by computing the matching slice of the consumer in-place inside the loop.
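///
/// A minimal usage sketch (the pattern wrapper below is hypothetical and not
/// part of this file): drive the fusion from a rewrite pattern anchored on
/// the candidate tensor.insert_slice.
///
///   struct FuseConsumerOfSliceExample
///       : public OpRewritePattern<tensor::InsertSliceOp> {
///     using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
///     LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
///                                   PatternRewriter &rewriter) const override {
///       FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
///           scf::tileAndFuseConsumerOfSlice(rewriter, sliceOp);
///       return failed(fused) ? failure() : success();
///     }
///   };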
1292 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1293 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1294                                       Operation *candidateSliceOp) {
1295   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1296           candidateSliceOp))
1297     return failure();
1298 
1299   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1300 
1301   // 1. Get the consumer of the scf.for/scf.forall result yielded by the
1302   // tensor.insert_slice/parallel_insert_slice.
1303   FailureOr<OpOperand *> maybeConsumerOpOperand =
1304       getUntiledConsumerFromSlice(candidateSliceOp);
1305   if (failed(maybeConsumerOpOperand)) {
1306     return rewriter.notifyMatchFailure(candidateSliceOp,
1307                                        "could not fetch consumer to fuse");
1308   }
1309   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1310   Operation *consumerOp = consumerOpOperand->getOwner();
1311   unsigned operandNumber = consumerOpOperand->getOperandNumber();
1312   unsigned resultNumber = 0;
1313   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1314     resultNumber = producerResult.getResultNumber();
1315   } else {
1316     return rewriter.notifyMatchFailure(
1317         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1318   }
1319 
1320   Operation *oldLoopOp = nullptr;
1321   SmallVector<Value> newOuts;
1322   Block *oldLoopBody = nullptr;
1323   unsigned initSize = 0;
1324   unsigned rank = 1;
1325   if (isInsertSliceOp) {
1326     auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
1327     oldLoopOp = forOp;
1328     llvm::append_range(newOuts, forOp.getInits());
1329     oldLoopBody = forOp.getBody();
1330     initSize = forOp.getInits().size();
1331   } else {
1332     auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
1333     oldLoopOp = forallOp;
1334     llvm::append_range(newOuts, forallOp.getOutputs());
1335     oldLoopBody = forallOp.getBody();
1336     initSize = forallOp.getOutputs().size();
1337     rank = forallOp.getRank();
1338   }
1339 
1340   if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
1341     return rewriter.notifyMatchFailure(
1342         oldLoopOp, "containing loop op should either yield just one value or "
1343                    "have the consumer op as its first user");
1344   }
1345 
1346   OpBuilder::InsertionGuard g(rewriter);
1347 
1348   // 2. Check that the consumer is not using the scf loop's output as init.
1349   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1350   SmallVector<Value> dpsInits =
1351       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1352   if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
1353     return rewriter.notifyMatchFailure(
1354         consumerOp,
1355         "consumer op taking the result of scf.for as init is not supported");
1356   }
1357   newOuts.append(dpsInits);
1358 
1359   Location loc = oldLoopOp->getLoc();
1360 
1361   // 3. Create new scf loop op.
1362   rewriter.setInsertionPoint(consumerOp);
1363   Operation *newLoopOp = nullptr;
1364   Block *newLoopBody = nullptr;
1365   if (isInsertSliceOp) {
1366     auto forOp = cast<scf::ForOp>(oldLoopOp);
1367     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
1368                                                 forOp.getUpperBound(),
1369                                                 forOp.getStep(), newOuts);
1370     newLoopOp = newForOp;
1371     newLoopBody = newForOp.getBody();
1372   } else {
1373     auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1374     auto newForallOp = rewriter.create<scf::ForallOp>(
1375         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1376         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
1377     newLoopOp = newForallOp;
1378     rewriter.eraseOp(newForallOp.getTerminator());
1379     newLoopBody = newForallOp.getBody();
1380   }
1381 
1382   // 4. Move the loop body to the new op.
1383   unsigned oldNumArguments = oldLoopBody->getNumArguments();
1384   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
1385                        newLoopBody->getArguments().take_front(oldNumArguments));
1386 
1387   // 5. Set insertion point before terminator op of the loop and create a new
1388   // tensor.insert_slice. In the scf.for case this is a clone of the
1389   // candidateSliceOp whereas in the scf.forall case this is created from the
1390   // operands of tensor.parallel_insert_slice.
1391   tensor::InsertSliceOp clonedInsertSliceOp;
1392   if (auto sliceOp =
1393           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1394     auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1395     rewriter.setInsertionPoint(newForallOp.getTerminator());
1396     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1397         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1398         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1399   } else {
1400     rewriter.setInsertionPoint(candidateSliceOp);
1401     clonedInsertSliceOp =
1402         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1403   }
1404 
1405   // 6.a. Clone consumer op.
1406   auto newForOpBlockArgsForConsumerDest =
1407       newLoopBody->getArguments().drop_front(oldNumArguments);
1408   auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
1409       rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1410 
1411   // 6.b. In the cloned consumer, replace the use of the loop result with the
1412   // result of the cloned tensor.insert_slice.
1413   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1414   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1415     operandToReplace.set(clonedInsertSliceOp.getResult());
1416   });
1417 
1418   // 7. Perform tiling of the cloned consumer and replace the operand at
1419   // `operandNumber` with the source of the cloned tensor.insert_slice op.
1420   auto ossSliceOp =
1421       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1422   FailureOr<TilingResult> tileAndFuseResult =
1423       tensor::replaceInsertSliceWithTiledConsumer(
1424           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1425   if (failed(tileAndFuseResult)) {
1426     return failure();
1427   }
1428   rewriter.replaceAllUsesWith(
1429       tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
1430       clonedInsertSliceOp.getSource());
1431 
1432   // 8. Extract the offsets/sizes/strides required to create the
1433   // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
1434   SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1435   SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1436   SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1437 
1438   // 9. Check that all insert strides are 1.
1439   if (llvm::any_of(strides, [](OpFoldResult stride) {
1440         return !isConstantIntValue(stride, 1);
1441       })) {
1442     return rewriter.notifyMatchFailure(
1443         candidateSliceOp, "containingOp's result yield with stride");
1444   }
1445 
1446   // 10. Try to get iter domain position from input position.
1447   SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1448   if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
1449           rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1450           iterDomainSizes))) {
1451     return rewriter.notifyMatchFailure(
1452         clonedConsumerOp, "can't get iter domain position from input position");
1453   }
1454 
1455   // 11. Try to fetch the offset and size for all results of the cloned
1456   // consumer. This would then be used to form the corresponding
1457   // tensor.insert_slice/parallel_insert_slice later.
1458   unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
1459   SmallVector<SmallVector<OpFoldResult>> resultOffsets(
1460       totalNumResultsOfConsumer);
1461   SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
1462   for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
1463     if (failed(clonedConsumerOp.getResultTilePosition(
1464             rewriter, idx, iterDomainOffsets, iterDomainSizes,
1465             resultOffsets[idx], resultSizes[idx]))) {
1466       return rewriter.notifyMatchFailure(
1467           clonedConsumerOp,
1468           "can't get result domain position from iter domain position");
1469     }
1470   }
1471 
1472   auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
1473   auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
1474   if (isInsertSliceOp) {
1475     auto newForOp = cast<scf::ForOp>(newLoopOp);
1476     fixTerminatorSCFYield(
1477         rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
1478         newForOp.getBody()->getArguments().drop_front(1 + initSize));
1479   } else {
1480     auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1481     fixTerminatorSCFInParallel(
1482         rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
1483         arrayRefOffsets, arrayRefSizes,
1484         newForallOp.getBody()->getArguments().drop_front(rank + initSize));
1485   }
1486 
1487   // 12. Replace the results of the scf loop and consumer op with the new loop's results.
1488   for (auto &&[oldResult, newResult] :
1489        llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
1490     rewriter.replaceAllUsesWith(oldResult, newResult);
1491   }
1492 
1493   for (auto &&[oldResult, newResult] :
1494        llvm::zip(consumerOp->getResults(),
1495                  newLoopOp->getResults().drop_front(initSize))) {
1496     rewriter.replaceAllUsesWith(oldResult, newResult);
1497   }
1498 
1499   // 13. Erase the old scf loop and the cloned consumer op.
1500   rewriter.eraseOp(oldLoopOp);
1501   rewriter.eraseOp(clonedConsumerOp);
1502 
1503   return scf::SCFFuseConsumerOfSliceResult{
1504       consumerOpOperand,
1505       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1506       tileAndFuseResult->tiledOps};
1507 }
1508 
1509 //===----------------------------------------------------------------------===//
1510 // lowerToLoopsUsingSCFForOp implementation.
1511 //===----------------------------------------------------------------------===//
1512 
1513 FailureOr<SmallVector<scf::ForOp>>
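/// A minimal usage sketch (the helper below is hypothetical and not part of
/// this file): lower a result-less TilingInterface op to a scalar loop nest
/// and erase the original op afterwards.
///
///   static LogicalResult lowerToScalarLoops(RewriterBase &rewriter,
///                                           TilingInterface op) {
///     FailureOr<SmallVector<scf::ForOp>> loops =
///         scf::lowerToLoopsUsingSCFForOp(rewriter, op);
///     if (failed(loops))
///       return failure();
///     rewriter.eraseOp(op);
///     return success();
///   }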
1514 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
1515                                      TilingInterface op) {
1516   // TODO: Handle cases where the op has results if needed.
1517   if (op->getNumResults() > 0) {
1518     return rewriter.notifyMatchFailure(
1519         op, "unable to lower to loops operations with return values");
1520   }
1521 
1522   SmallVector<Range> domain = op.getIterationDomain(rewriter);
1523   SmallVector<Value> ivs;
1524   SmallVector<scf::ForOp> loops;
1525   Location loc = op.getLoc();
1526   for (auto loopRange : domain) {
1527     Value offsetVal =
1528         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1529     Value sizeVal =
1530         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1531     Value strideVal =
1532         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1533     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1534                                             strideVal, ValueRange{});
1535     loops.push_back(loop);
1536     ivs.push_back(loop.getInductionVar());
1537     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1538   }
1539   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1540     return failure();
1541   }
1542   return loops;
1543 }
1544