xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 4435ced94998c00a6589c3500822015b6341c9e3)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
33 scf::SCFTilingOptions &
34 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
35   assert(!tileSizeComputationFunction && "tile sizes already set");
36   auto tileSizes = llvm::to_vector(ts);
37   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38     return tileSizes;
39   };
40   return *this;
41 }
42 
43 /// Helper method to adjust the interchange vector to match the iteration
44 /// domain.
45 static SmallVector<int64_t>
46 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
47                       size_t iterationDomainSize) {
48   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
49   if (filledVector.size() < iterationDomainSize) {
50     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
51     filledVector.append(range.begin(), range.end());
52   }
53   if (filledVector.size() > iterationDomainSize)
54     filledVector.resize(iterationDomainSize);
55   return filledVector;
56 }
57 
58 /// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59 template <typename SrcOpTy>
60 static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
61   return llvm::to_vector(
62       llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
63 }
64 template <typename SrcOpTy>
65 static SmallVector<Operation *>
66 getAsOperations(const SmallVector<SrcOpTy> &ops) {
67   return getAsOperations(ArrayRef<SrcOpTy>(ops));
68 }
69 
70 /// Convert a list of `Operation *` to a list of `DstOpTy.
71 template <typename DstOpTy>
72 static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
73   return llvm::to_vector(
74       llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75 }
76 template <typename DstOpTy>
77 static SmallVector<DstOpTy>
78 castToTypedOperations(const SmallVector<Operation *> &ops) {
79   return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // tileUsingSCFForOp implementation.
84 //===----------------------------------------------------------------------===//
85 
86 // Check if `stride` evenly divides the trip count `size - offset`.
87 static bool tileDividesIterationDomain(Range loopRange) {
88   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
89   if (!offsetAsInt)
90     return false;
91   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
92   if (!sizeAsInt)
93     return false;
94   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
95   if (!strideAsInt)
96     return false;
97   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
98 }
99 
100 /// Returns the bounded tile size given the current `iv`, `loopRange` and
101 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
102 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
103                                        Range loopRange, Value iv,
104                                        OpFoldResult tileSize) {
105   std::optional<int64_t> ts = getConstantIntValue(tileSize);
106   if (ts && ts.value() == 1)
107     return tileSize;
108 
109   if (tileDividesIterationDomain(
110           Range{loopRange.offset, loopRange.size, tileSize}))
111     return tileSize;
112 
113   // The tile size to use (to avoid out of bounds access) is  minimum of
114   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
115   // loop.
116   AffineExpr s0, s1, d0;
117   bindDims(b.getContext(), d0);
118   bindSymbols(b.getContext(), s0, s1);
119   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
120   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
121   return affine::makeComposedFoldedAffineMin(
122       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
123 }
124 
125 /// Clones the operation and updates the destination if the operation
126 /// implements the `DestinationStyleOpInterface`.
127 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
128                                                   Operation *op,
129                                                   ValueRange newDestArgs) {
130   Operation *clonedOp = rewriter.clone(*op);
131   if (newDestArgs.empty())
132     return clonedOp;
133   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
134     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
135   return clonedOp;
136 }
137 
138 /// Generate an empty loop nest that represents the tiled loop nest shell.
139 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
140 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
141 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
142 ///   the tile processed within the inner most loop.
143 /// Note that this methods adds `scf.yield` operation for all but the innermost
144 /// loop. These yield the value returned by the immediately inner loop. The
145 /// caller is expected to add the scf.yield operation for the innermost loop.
146 static SmallVector<scf::ForOp> generateTileLoopNest(
147     OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
148     ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
149     SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
150   if (loopRanges.empty())
151     return {};
152   assert(loopRanges.size() == tileSizes.size() &&
153          "expected as many tile sizes as loop ranges");
154   OpBuilder::InsertionGuard guard(builder);
155   SmallVector<scf::ForOp> loops;
156   offsets.resize(loopRanges.size());
157   sizes.resize(loopRanges.size());
158 
159   for (auto loopRange : llvm::enumerate(loopRanges)) {
160     Value offset =
161         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
162     Value size =
163         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
164     Value tileSize = getValueOrCreateConstantIndexOp(
165         builder, loc, tileSizes[loopRange.index()]);
166     // No loops if tile size is zero. Set offset and size to the loop
167     // offset and size.
168     if (matchPattern(tileSize, m_Zero())) {
169       offsets[loopRange.index()] = offset;
170       sizes[loopRange.index()] = size;
171       continue;
172     }
173 
174     auto loop = builder.create<scf::ForOp>(
175         loc, offset, size, tileSize, destinationTensors,
176         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
177             ValueRange /*iterArgs*/) {
178           sizes[loopRange.index()] =
179               getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
180                                  getAsOpFoldResult(tileSize));
181         });
182     offsets[loopRange.index()] = loop.getInductionVar();
183     loops.push_back(loop);
184     builder.setInsertionPointToEnd(loop.getBody());
185     destinationTensors = loop.getRegionIterArgs();
186   }
187 
188   // Add the scf.yield operations for all the outer loops.
189   if (!loops.empty()) {
190     for (auto [outerLoop, innerLoop] :
191          llvm::zip_equal(MutableArrayRef(loops).drop_back(),
192                          MutableArrayRef(loops).drop_front())) {
193       builder.setInsertionPointToEnd(outerLoop.getBody());
194       builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
195     }
196   }
197   return loops;
198 }
199 
200 /// Method to add new init values to a loop nest. Updates `loops` in-place with
201 /// new loops that use the `newInitValues`.
202 /// The outer-loops are updated to yield the new result values of the inner
203 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get
204 /// the additional values to yield form the innermost loop.
205 static void addInitOperandsToLoopNest(
206     RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops,
207     ValueRange newInitValues,
208     llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv,
209                                           ValueRange newRegionIterArgs)>
210         getNewYieldValsFn) {
211   SmallVector<scf::ForOp> newLoops;
212   if (loops.empty())
213     return;
214   OpBuilder::InsertionGuard g(rewriter);
215   rewriter.setInsertionPoint(loops.front());
216   for (auto &loop : loops) {
217     rewriter.setInsertionPoint(loop);
218 
219     // Create a new loop with the new init values for this loop.
220     SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs());
221     newInits.append(newInitValues.begin(), newInitValues.end());
222     auto newLoop = rewriter.create<scf::ForOp>(
223         loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
224         loop.getStep(), newInits,
225         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
226 
227     // Merge the body of the new loop with the body of the old loops.
228     SmallVector<Value> sourceBlockArgs;
229     sourceBlockArgs.push_back(newLoop.getInductionVar());
230     auto newRegionIterArgs = newLoop.getRegionIterArgs();
231     sourceBlockArgs.append(
232         newRegionIterArgs.begin(),
233         std::next(newRegionIterArgs.begin(), loop.getNumResults()));
234     rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs);
235     rewriter.replaceOp(loop,
236                        newLoop.getResults().take_front(loop.getNumResults()));
237     loop = newLoop;
238     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
239   }
240 
241   // Update the loop body of the innermost loop to get new yield values.
242   scf::ForOp innerMostLoop = loops.back();
243   auto innerMostYieldOp =
244       cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator());
245   rewriter.setInsertionPoint(innerMostYieldOp);
246   SmallVector<Value> newYieldVals =
247       getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(),
248                         innerMostLoop.getRegionIterArgs());
249   SmallVector<Value> newYieldOperands =
250       llvm::to_vector(innerMostYieldOp->getOperands());
251   newYieldOperands.append(newYieldVals);
252   rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands);
253 
254   // Make all other loops except the innermost loops yield the values returned
255   // by the inner loop.
256   for (auto [outerLoop, innerLoop] :
257        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
258     auto outerLoopYield =
259         cast<scf::YieldOp>(outerLoop.getBody()->getTerminator());
260     SmallVector<Value> newYields =
261         llvm::to_vector(outerLoopYield.getOperands());
262     ValueRange additionalYields =
263         innerLoop.getResults().take_back(newInitValues.size());
264     newYields.append(additionalYields.begin(), additionalYields.end());
265     rewriter.setInsertionPoint(outerLoopYield);
266     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
267   }
268 }
269 
270 /// Implementation of tiling transformation of `op` that implements the
271 /// `TilingInterface` using `scf.for` to iterate over the tiles.
272 FailureOr<scf::SCFTilingResult>
273 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
274                              const scf::SCFTilingOptions &options) {
275   OpBuilder::InsertionGuard guard(rewriter);
276   rewriter.setInsertionPointAfter(op);
277 
278   if (!options.tileSizeComputationFunction) {
279     return rewriter.notifyMatchFailure(
280         op, "missing tile size computation function");
281   }
282 
283   // 1. Get the range of the loops that are represented by the operation.
284   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
285   size_t numLoops = iterationDomain.size();
286   if (numLoops == 0) {
287     return rewriter.notifyMatchFailure(
288         op, "unable to tile op with no iteration domain");
289   }
290   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
291   // skips tiling a particular dimension. This convention is significantly
292   // simpler to handle instead of adjusting affine maps to account for missing
293   // dimensions.
294   SmallVector<OpFoldResult> tileSizeVector =
295       options.tileSizeComputationFunction(rewriter, op);
296   if (tileSizeVector.size() < iterationDomain.size()) {
297     auto zero = rewriter.getIndexAttr(0);
298     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
299   }
300 
301   // 3. Find the destination tensors to use for the operation.
302   SmallVector<Value> destinationTensors;
303   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
304                                              destinationTensors))) {
305     return rewriter.notifyMatchFailure(op,
306                                        "unable to create destination tensors");
307   }
308 
309   SmallVector<OpFoldResult> offsets, sizes;
310   SmallVector<scf::ForOp> forLoops;
311   {
312     // If there is an interchange specified, permute the iteration domain and
313     // the tile sizes.
314     SmallVector<int64_t> interchangeVector;
315     if (!options.interchangeVector.empty()) {
316       interchangeVector = fillInterchangeVector(options.interchangeVector,
317                                                 iterationDomain.size());
318     }
319     if (!interchangeVector.empty()) {
320       if (!isPermutationVector(interchangeVector)) {
321         return rewriter.notifyMatchFailure(
322             op, "invalid intechange vector, not a permutation of the entire "
323                 "iteration space");
324       }
325 
326       applyPermutationToVector(iterationDomain, interchangeVector);
327       applyPermutationToVector(tileSizeVector, interchangeVector);
328     }
329 
330     // 4. Materialize an empty loop nest that iterates over the tiles. These
331     // loops for now do not return any values even if the original operation has
332     // results.
333     forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
334                                     tileSizeVector, offsets, sizes,
335                                     destinationTensors);
336 
337     if (!interchangeVector.empty()) {
338       auto inversePermutation = invertPermutationVector(interchangeVector);
339       applyPermutationToVector(offsets, inversePermutation);
340       applyPermutationToVector(sizes, inversePermutation);
341     }
342   }
343 
344   LLVM_DEBUG({
345     if (!forLoops.empty()) {
346       llvm::dbgs() << "LoopNest shell :\n";
347       forLoops.front().dump();
348       llvm::dbgs() << "\n";
349     }
350   });
351 
352   // 5. Generate the tiled implementation within the inner most loop.
353   SmallVector<Value> clonedOpDestination = destinationTensors;
354   if (!forLoops.empty()) {
355     rewriter.setInsertionPointToEnd(forLoops.back().getBody());
356     clonedOpDestination =
357         llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
358                             [](BlockArgument b) -> Value { return b; });
359   }
360 
361   // 5a. Clone the operation within the loop body.
362   auto clonedOp = cast<TilingInterface>(
363       cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
364 
365   // 5b. Early return cloned op if tiling is not happening. We can not return
366   // the original op because it could lead to
367   // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
368   if (llvm::all_of(tileSizeVector, isZeroIndex)) {
369     return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{},
370                                 clonedOp->getResults()};
371   }
372 
373   // 5c. Tile the cloned operation.
374   FailureOr<TilingResult> tiledImplementation =
375       clonedOp.getTiledImplementation(rewriter, offsets, sizes);
376   if (failed(tiledImplementation)) {
377     return rewriter.notifyMatchFailure(op, "failed to tile operation");
378   }
379 
380   // 5d. Delete the cloned operation.
381   rewriter.eraseOp(clonedOp);
382 
383   // If loops are empty, the tiled op is used as the replacement for the untiled
384   // op.
385   if (forLoops.empty()) {
386     return scf::SCFTilingResult{tiledImplementation->tiledOps,
387                                 getAsOperations(forLoops),
388                                 tiledImplementation->tiledValues};
389   }
390 
391   if (op->getNumResults() == 0) {
392     // The innermost loop does not have a `scf.yield` yet. There is nothing to
393     // return, so generate an empty `scf.yield` operation.
394     rewriter.setInsertionPointToEnd(forLoops.back().getBody());
395     rewriter.create<scf::YieldOp>(op->getLoc());
396     return scf::SCFTilingResult{
397         tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
398   }
399 
400   // 6. Yield all the results of the tiled operation.
401   int64_t numResults = op->getNumResults();
402   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
403       resultSizesList(numResults);
404   SmallVector<Value> yieldedValues;
405   for (auto [index, tiledValue] :
406        llvm::enumerate(tiledImplementation->tiledValues)) {
407     SmallVector<OpFoldResult> resultOffsets, resultSizes;
408     if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
409                                         resultOffsets, resultSizes))) {
410       return rewriter.notifyMatchFailure(
411           op, "failed to get slice of result produced");
412     }
413     SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
414                                             rewriter.getIndexAttr(1));
415     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
416         op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
417         resultSizes, resultStrides);
418     yieldedValues.push_back(insertSlice);
419   }
420   rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);
421 
422   SmallVector<Value> replacements = llvm::map_to_vector(
423       forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
424   LLVM_DEBUG({
425     if (!forLoops.empty()) {
426       llvm::dbgs() << "After tiled implementation :\n";
427       forLoops.front().dump();
428       llvm::dbgs() << "\n";
429     }
430   });
431   return scf::SCFTilingResult{tiledImplementation->tiledOps,
432                               getAsOperations(forLoops), replacements};
433 }
434 
435 FailureOr<scf::SCFReductionTilingResult>
436 mlir::scf::tileReductionUsingScf(RewriterBase &b,
437                                  PartialReductionOpInterface op,
438                                  ArrayRef<OpFoldResult> tileSizes) {
439   Location loc = op.getLoc();
440   // Ops implementing PartialReductionOpInterface are expected to implement
441   // TilingInterface.
442   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
443   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
444   auto tileSizesVector = llvm::to_vector(tileSizes);
445   if (tileSizesVector.size() < iterationDomain.size()) {
446     auto zero = b.getIndexAttr(0);
447     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
448                            zero);
449   }
450   if (op->getNumResults() != 1)
451     return b.notifyMatchFailure(
452         op, "don't support ops with multiple results for now");
453   SmallVector<utils::IteratorType> iterators =
454       tilingInterfaceOp.getLoopIteratorTypes();
455 
456   SmallVector<int> reductionDims;
457   for (auto [idx, iteratorType] :
458        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
459     if (iteratorType == utils::IteratorType::reduction)
460       reductionDims.push_back(idx);
461   }
462 
463   // 2. create the inital tensor value.
464   FailureOr<Operation *> identityTensor =
465       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
466                                                   reductionDims);
467   if (failed(identityTensor))
468     return b.notifyMatchFailure(op,
469                                 "cannot create a tensor of identity value.");
470   // 3. Create the nested loops.
471   SmallVector<OpFoldResult> offsets, sizes;
472   SmallVector<scf::ForOp> loops =
473       generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
474                            sizes, identityTensor.value()->getResults());
475 
476   // 4. Generate the tiled implementation within the inner most loop.
477   // 4a. Clone the operation within the loop body.
478   SmallVector<Value> clonedOpDestination =
479       llvm::map_to_vector(identityTensor.value()->getResults(),
480                           [](OpResult res) -> Value { return res; });
481   if (!loops.empty()) {
482     b.setInsertionPointToEnd(loops.back().getBody());
483     clonedOpDestination =
484         llvm::map_to_vector(loops.back().getRegionIterArgs(),
485                             [](BlockArgument b) -> Value { return b; });
486   }
487   auto clonedOp = cast<PartialReductionOpInterface>(
488       cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
489 
490   // 4b. Tile the cloned operation.
491   Operation *parallelOp = clonedOp.tileToPartialReduction(
492       b, loc, clonedOpDestination, offsets, sizes, reductionDims);
493   // 4c. Delete the cloned operation.
494   b.eraseOp(clonedOp);
495 
496   SmallVector<OpFoldResult> outSizes;
497   for (size_t i = 0; i < offsets.size(); i++) {
498     outSizes.push_back(
499         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
500   }
501   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
502   SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
503   SmallVector<Value> yieldedVals;
504   auto bbArgs = loops.back().getRegionIterArgs();
505   for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
506     Value insert = b.create<tensor::InsertSliceOp>(
507         loc, result, bbArg, outOffsets, outSizes, outStrides);
508     yieldedVals.push_back(insert);
509   }
510   b.create<scf::YieldOp>(loc, yieldedVals);
511 
512   SmallVector<Value> replacements = llvm::map_to_vector(
513       loops.front().getResults(), [](OpResult r) -> Value { return r; });
514 
515   // 5. Apply the merge reduction to combine all the partial values.
516   b.setInsertionPointAfter(*loops.begin());
517   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
518   b.replaceOp(op, mergeOp->getResults());
519 
520   SCFReductionTilingResult results;
521   results.initialOp = *identityTensor;
522   results.loops = std::move(loops);
523   results.parallelTiledOp = parallelOp;
524   results.mergeOp = mergeOp;
525   return results;
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
530 //===----------------------------------------------------------------------===//
531 
532 /// Return the untiled producer whose slice is used in a tiled consumer. The
533 /// method traverses the tile loop nest (`loops`) if needed, and returns the
534 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
535 /// indicates that this is a destination operand of the consumer. If there was
536 /// no loop traversal needed, the second value of the returned tuple is empty.
537 static std::tuple<OpResult, std::optional<OpOperand *>>
538 getUntiledProducerFromSliceSource(OpOperand *source,
539                                   ArrayRef<scf::ForOp> loops) {
540   std::optional<OpOperand *> destinationIterArg;
541   auto loopIt = loops.rbegin();
542   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
543     scf::ForOp loop = *loopIt;
544     if (iterArg.getOwner()->getParentOp() != loop)
545       break;
546     source = loop.getTiedLoopInit(iterArg);
547     loopIt++;
548   }
549   if (loopIt == loops.rend())
550     destinationIterArg = source;
551   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
552 }
553 
554 /// Implementation of fusing producer of a single slice by computing the
555 /// slice of the producer in-place.
556 std::optional<scf::SCFFuseProducerOfSliceResult>
557 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
558                                       tensor::ExtractSliceOp candidateSliceOp,
559                                       MutableArrayRef<scf::ForOp> loops) {
560   // 1. Get the producer of the source (potentially walking through
561   // `iter_args` of nested `scf.for`)
562   auto [fusableProducer, destinationInitArg] =
563       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
564                                         loops);
565   if (!fusableProducer)
566     return std::nullopt;
567   unsigned resultNumber = fusableProducer.getResultNumber();
568 
569   OpBuilder::InsertionGuard g(rewriter);
570   rewriter.setInsertionPoint(candidateSliceOp);
571 
572   // 2. Clone the fused producer
573   // 2a. Compute the destination operands to use for the cloned operation.
574   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
575   Operation *fusableProducerOp = fusableProducer.getOwner();
576   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
577       failed(tensor::getOrCreateDestinations(
578           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
579           origDestinationTensors)))
580     return std::nullopt;
581 
582   clonedOpDestinationTensors = origDestinationTensors;
583   if (destinationInitArg &&
584       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
585     // 2b. If the producer is also destination style, then to maintain the
586     // destination passing style, update the destination of the producer to be
587     // the source of the slice.
588     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
589   }
590   // 2c. Clone the fused producer.
591   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
592       rewriter, fusableProducerOp, clonedOpDestinationTensors);
593   // 2d. Update the source of the candidateSlice to be the cloned producer.
594   //     Easier to just clone the slice with different source since replacements
595   //     and DCE of cloned ops becomes easier
596   SmallVector<Value> candidateSliceOpOperands =
597       llvm::to_vector(candidateSliceOp->getOperands());
598   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
599   tensor::ExtractSliceOp clonedCandidateSliceOp =
600       mlir::clone(rewriter, candidateSliceOp,
601                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
602 
603   // 3. Generate the tiled implementation of the producer of the source
604   FailureOr<TilingResult> tileAndFuseResult =
605       tensor::replaceExtractSliceWithTiledProducer(
606           rewriter, clonedCandidateSliceOp,
607           clonedProducerOp->getResult(resultNumber));
608   if (failed(tileAndFuseResult))
609     return std::nullopt;
610   // Note: Do not delete the candidateSliceOp, since its passed in from the
611   // caller.
612   rewriter.replaceAllUsesWith(candidateSliceOp,
613                               tileAndFuseResult->tiledValues[0]);
614   rewriter.eraseOp(clonedCandidateSliceOp);
615   rewriter.eraseOp(clonedProducerOp);
616 
617   // 3. If the slice is for a destination operand, for example,
618   //
619   // ```mlir
620   // %0 = linalg.init
621   // %1 = linalg.fill .. outs(%0 : )
622   // %2 = scf.for .. iter_args(%arg0 = %1) {
623   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
624   //     %4 = tensor.extract_slice %arg1 [..]
625   //     .. = linalg.matmul .. outs(%4 : )
626   //   }
627   // }
628   // ```
629   //
630   // the IR is currently
631   //
632   // ```
633   // %0 = linalg.init
634   // %1 = linalg.fill
635   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
636   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
637   //     %4 = tensor.extract_slice %arg1[..]
638   //     %5 = linalg.fill .. outs(%4 : )
639   //     .. = linalg.matmul .. outs(%5 : )
640   //   }
641   // }
642   // ```
643   //
644   // The untiled `linalg.fill` is still used as the `init_value` since it
645   // was originally a destination operand of the untiled `linalg.matmul`.
646   // When fusing an operand that is a destination operand, the iter_arg of
647   // the outer most loop should be changed to use the destination of the
648   // fused operation. With this the IR will be.
649   //
650   // ```
651   // %0 = linalg.init
652   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
653   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
654   //     %3 = tensor.extract_slice %arg1[..]
655   //     %4 = linalg.fill .. outs(%3 : )
656   //     .. = linalg.matmul .. outs(%4 : )
657   //   }
658   // }
659   // ```
660   if (destinationInitArg &&
661       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
662     loops.front()
663         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
664         .set(origDestinationTensors[resultNumber]);
665   }
666   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
667                                            tileAndFuseResult->tiledValues[0],
668                                            tileAndFuseResult->tiledOps};
669 }
670 
671 /// Reconstruct the fused producer from within the tiled-and-fused code.
672 void mlir::scf::yieldReplacementForFusedProducer(
673     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
674     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
675     MutableArrayRef<scf::ForOp> loops) {
676   if (loops.empty())
677     return;
678 
679   OpResult fusableProducer = fusedProducerInfo.origProducer;
680   Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
681   FailureOr<Value> initValue = tensor::getOrCreateDestination(
682       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
683   if (succeeded(initValue)) {
684 
685     auto newYieldValuesFn =
686         [&](RewriterBase &innerRewriter, Value iv,
687             ValueRange newRegionIterArgs) -> SmallVector<Value> {
688       OpBuilder::InsertionGuard g(innerRewriter);
689       if (auto tiledDestStyleOp =
690               tiledAndFusedProducer
691                   .getDefiningOp<DestinationStyleOpInterface>()) {
692         rewriter.setInsertionPoint(tiledDestStyleOp);
693         BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
694         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
695             sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
696             sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
697         unsigned resultNumber = fusableProducer.getResultNumber();
698         rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
699           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
700         });
701       }
702       Block *block = rewriter.getInsertionPoint()->getBlock();
703       rewriter.setInsertionPoint(block->getTerminator());
704       Value replacement = rewriter.create<tensor::InsertSliceOp>(
705           fusedProducerInfo.origProducer.getLoc(),
706           fusedProducerInfo.tiledAndFusedProducer,
707           loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
708           sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
709       return {replacement};
710     };
711 
712     addInitOperandsToLoopNest(rewriter, loops,
713                               SmallVector<Value>{initValue.value()},
714                               newYieldValuesFn);
715   }
716 }
717 
718 /// Implementation of tile consumer and fuse producer greedily.
719 FailureOr<scf::SCFTileAndFuseResult>
720 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
721     RewriterBase &rewriter, TilingInterface consumer,
722     const scf::SCFTileAndFuseOptions &options) {
723   // This transformation is only valid for ops that return values (i.e. not
724   // valid to use with operations that have memref operands).
725   if (!consumer->getNumResults()) {
726     return rewriter.notifyMatchFailure(
727         consumer, "invalid pattern for op with no results");
728   }
729 
730   // 1. First tile the consumer.
731   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
732   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
733   FailureOr<scf::SCFTilingResult> tilingResult =
734       tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
735   if (failed(tilingResult))
736     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
737   for (auto *tiledOp : tilingResult->tiledOps)
738     tiledAndFusedOps.insert(tiledOp);
739   SmallVector<scf::ForOp> forLoops =
740       castToTypedOperations<scf::ForOp>(tilingResult->loops);
741 
742   // If there are no loops generated, fusion is immaterial.
743   if (forLoops.empty()) {
744     DenseMap<Value, Value> replacements;
745     for (auto [origVal, replacement] :
746          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
747       replacements[origVal] = replacement;
748     }
749     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
750                                      getAsOperations(forLoops), replacements};
751   }
752 
753   // To keep track of replacements for now just record the map from the original
754   // untiled value to the result number of the for loop. Since the loop gets
755   // potentially replaced during fusion, keeping the value directly wont work.
756   DenseMap<Value, size_t> origValToResultNumber;
757   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
758     origValToResultNumber[result] = index;
759   }
760 
761   // 2. Typically, the operands of the tiled operation are slices of the
762   //    operands of the untiled operation. These are expressed in IR using
763   //    `tensor.extract_slice` operations with source being the operands of the
764   //    untiled operation. Create a worklist of these `tensor.extract_slice`
765   //    operations. If the producers of the source of the `tensor.extract_slice`
766   //    can be tiled such that the tiled value is generated in-place, that
767   //    effectively tiles + fuses the operations.
768   auto addCandidateSlices = [](Operation *fusedOp,
769                                std::deque<tensor::ExtractSliceOp> &candidates) {
770     for (Value operand : fusedOp->getOperands())
771       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
772         candidates.push_back(sliceOp);
773   };
774 
775   std::deque<tensor::ExtractSliceOp> candidates;
776   addCandidateSlices(tiledAndFusedOps.back(), candidates);
777   OpBuilder::InsertionGuard g(rewriter);
778   while (!candidates.empty()) {
779     // Traverse the slices in BFS fashion.
780     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
781     candidates.pop_front();
782 
783     // Find the original producer of the slice.
784     auto [fusableProducer, destinationInitArg] =
785         getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
786                                           forLoops);
787     if (!fusableProducer)
788       continue;
789 
790     auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
791         candidateSliceOp, fusableProducer, destinationInitArg.has_value());
792     if (!fuseSlice)
793       continue;
794 
795     // The operands of the fused producer might themselved be slices of
796     // values produced by operations that implement the `TilingInterface`.
797     // Add these operations to the worklist.
798     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
799         tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
800     if (!fusedResult)
801       continue;
802 
803     if (yieldReplacement) {
804       yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
805                                        fusedResult.value(), forLoops);
806       origValToResultNumber[fusableProducer] =
807           forLoops.front().getNumResults() - 1;
808     }
809 
810     if (Operation *tiledAndFusedOp =
811             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
812       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
813       tiledAndFusedOps.insert(tiledAndFusedOp);
814       addCandidateSlices(tiledAndFusedOp, candidates);
815     }
816   }
817 
818   DenseMap<Value, Value> replacements;
819   for (auto [origVal, resultNumber] : origValToResultNumber) {
820     replacements[origVal] = forLoops.front()->getResult(resultNumber);
821   }
822 
823   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
824                                    getAsOperations(forLoops), replacements};
825 }
826 
827 //===----------------------------------------------------------------------===//
828 // tileUsingSCFForAllOp implementation.
829 //===----------------------------------------------------------------------===//
830 
831 FailureOr<scf::SCFTilingResult>
832 mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
833                                 const scf::SCFTilingOptions &options) {
834   Location loc = op->getLoc();
835   OpBuilder::InsertionGuard g(rewriter);
836 
837   // 1. Get the range of loops that are represented by the operation.
838   SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
839   if (loopRanges.empty())
840     return op->emitOpError("expected non-empty loop ranges");
841   auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
842   if (llvm::any_of(loopRanges, hasStrideOne))
843     return op->emitOpError("only stride-1 supported atm");
844 
845   // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
846   // To make it easier, pad the tile sizes to loopRanges.size with value 0.
847   SmallVector<OpFoldResult> tileSizeVector =
848       options.tileSizeComputationFunction(rewriter, op);
849   tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
850 
851   // 3. Build the offsets, sizes and steps for the tile and distributed loops.
852   SmallVector<OpFoldResult> lbs, ubs, steps;
853   for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
854     if (isConstantIntValue(tileSize, 0))
855       continue;
856     lbs.push_back(loopRange.offset);
857     ubs.push_back(loopRange.size);
858     steps.push_back(tileSize);
859   }
860 
861   // 4. Gather destination tensors.
862   SmallVector<Value> dest;
863   if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
864     return op->emitOpError("failed to get destination tensors");
865 
866   // 5. Build the device mapping attribute.
867   std::optional<ArrayAttr> mappingAttr;
868   if (!options.mappingVector.empty()) {
869     mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
870   }
871 
872   // 6. Create the ForallOp. We don't use the lambda body-builder
873   // version because we require the use of RewriterBase in the body, so we
874   // manually move the insertion point to the body below.
875   auto forallOp =
876       rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
877 
878   // 7. Get the tile offset and sizes.
879   rewriter.setInsertionPoint(forallOp.getTerminator());
880   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
881   ValueRange ivs = forallOp.getInductionVars();
882   {
883     int materializedLoopNum = 0;
884     for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
885       if (isConstantIntValue(tileSize, 0)) {
886         tiledOffsets.push_back(loopRange.offset);
887         tiledSizes.push_back(loopRange.size);
888         continue;
889       }
890       Value iv = ivs[materializedLoopNum++];
891       tiledOffsets.push_back(iv);
892       tiledSizes.push_back(
893           getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
894     }
895   }
896 
897   // 8. Tile the operation. Clone the operation to allow fix up of destination
898   // operands.
899   ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
900   Operation *clonedOp =
901       cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
902   FailureOr<TilingResult> tilingResult =
903       cast<TilingInterface>(clonedOp).getTiledImplementation(
904           rewriter, tiledOffsets, tiledSizes);
905   if (failed(tilingResult))
906     return clonedOp->emitError("failed to tile op: ");
907   rewriter.eraseOp(clonedOp);
908 
909   // 9. Parallel insert back into the result tensor.
910   for (auto [index, tiledValue, destBBArg] :
911        llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
912     // 9.a. Partial subset information is inserted just before the terminator.
913     rewriter.setInsertionPoint(forallOp.getTerminator());
914 
915     SmallVector<OpFoldResult> resultOffsets, resultSizes;
916     if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
917                                         tiledSizes, resultOffsets,
918                                         resultSizes))) {
919       return op->emitOpError("output offsets couldn't be calculated");
920     }
921 
922     SmallVector<OpFoldResult> strides(resultSizes.size(),
923                                       rewriter.getIndexAttr(1));
924     // 9.b. Parallel insertions are inserted at the end of the combining
925     // terminator.
926     rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
927     rewriter.create<tensor::ParallelInsertSliceOp>(
928         loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
929   }
930 
931   // 10. Return the tiling result.
932   return scf::SCFTilingResult{
933       tilingResult->tiledOps,
934       {forallOp.getOperation()},
935       llvm::map_to_vector(forallOp.getResults(),
936                           [](auto val) -> Value { return val; })};
937 }
938 
939 //===----------------------------------------------------------------------===//
940 // lowerToLoopsUsingSCFForOp implementation.
941 //===----------------------------------------------------------------------===//
942 
943 FailureOr<SmallVector<scf::ForOp>>
944 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
945                                      TilingInterface op) {
946   // TODO: Handle cases where the op has results if needed.
947   if (op->getNumResults() > 0) {
948     return rewriter.notifyMatchFailure(
949         op, "unable to lower to loops operations with return values");
950   }
951 
952   SmallVector<Range> domain = op.getIterationDomain(rewriter);
953   SmallVector<Value> ivs;
954   SmallVector<scf::ForOp> loops;
955   Location loc = op.getLoc();
956   for (auto loopRange : domain) {
957     Value offsetVal =
958         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
959     Value sizeVal =
960         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
961     Value strideVal =
962         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
963     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
964                                             strideVal, ValueRange{});
965     loops.push_back(loop);
966     ivs.push_back(loop.getInductionVar());
967     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
968   }
969   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
970     return failure();
971   }
972   return loops;
973 }
974