xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2cf6a7c19SMahesh Ravishankar //
3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cf6a7c19SMahesh Ravishankar //
7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8cf6a7c19SMahesh Ravishankar //
9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface.
10cf6a7c19SMahesh Ravishankar //
11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
12cf6a7c19SMahesh Ravishankar 
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14cf6a7c19SMahesh Ravishankar 
159bc3102bSYun-Fly #include "mlir/Analysis/SliceAnalysis.h"
169bc3102bSYun-Fly #include "mlir/Analysis/TopologicalSortUtils.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
19abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
20cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
21cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h"
22cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
23b1d3afc9SHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h"
242b2ce50fSAbhishek Varma #include "mlir/IR/Dominance.h"
25cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h"
26cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
27b169643fSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h"
28cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
299144fed3SQuinn Dawkins #include "mlir/Rewrite/FrozenRewritePatternSet.h"
309144fed3SQuinn Dawkins #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
316e3631d0SKunwar Grover #include "llvm/ADT/ScopeExit.h"
3276ead96cSMaheshRavishankar #include "llvm/ADT/TypeSwitch.h"
33cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h"
345c9013e2SKazu Hirata #include <optional>
35cf6a7c19SMahesh Ravishankar 
36cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface"
37cf6a7c19SMahesh Ravishankar 
38cf6a7c19SMahesh Ravishankar using namespace mlir;
39cf6a7c19SMahesh Ravishankar 
40cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions &
41170a25a7SMaheshRavishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
42cf6a7c19SMahesh Ravishankar   assert(!tileSizeComputationFunction && "tile sizes already set");
43170a25a7SMaheshRavishankar   auto tileSizes = llvm::to_vector(ts);
44cf6a7c19SMahesh Ravishankar   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45170a25a7SMaheshRavishankar     return tileSizes;
46cf6a7c19SMahesh Ravishankar   };
47cf6a7c19SMahesh Ravishankar   return *this;
48cf6a7c19SMahesh Ravishankar }
49cf6a7c19SMahesh Ravishankar 
506740d701SMaheshRavishankar scf::SCFTilingOptions &
516740d701SMaheshRavishankar scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
526740d701SMaheshRavishankar   assert(!numThreadsComputationFunction && "num tiles already set");
536740d701SMaheshRavishankar   auto numThreads = llvm::to_vector(nt);
546740d701SMaheshRavishankar   numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
556740d701SMaheshRavishankar     return numThreads;
566740d701SMaheshRavishankar   };
576740d701SMaheshRavishankar   return *this;
586740d701SMaheshRavishankar }
596740d701SMaheshRavishankar 
60b8a1f00dSMahesh Ravishankar /// Helper method to adjust the interchange vector to match the iteration
61b8a1f00dSMahesh Ravishankar /// domain.
6279150279SNicolas Vasilache static SmallVector<int64_t>
6379150279SNicolas Vasilache fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
64b8a1f00dSMahesh Ravishankar                       size_t iterationDomainSize) {
6579150279SNicolas Vasilache   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
66b8a1f00dSMahesh Ravishankar   if (filledVector.size() < iterationDomainSize) {
6779150279SNicolas Vasilache     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68b8a1f00dSMahesh Ravishankar     filledVector.append(range.begin(), range.end());
69b8a1f00dSMahesh Ravishankar   }
70b8a1f00dSMahesh Ravishankar   if (filledVector.size() > iterationDomainSize)
71b8a1f00dSMahesh Ravishankar     filledVector.resize(iterationDomainSize);
72b8a1f00dSMahesh Ravishankar   return filledVector;
73b8a1f00dSMahesh Ravishankar }
74b8a1f00dSMahesh Ravishankar 
752f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
7676ead96cSMaheshRavishankar // tileUsingSCF implementation.
772f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
782f637fe7SMahesh Ravishankar 
796740d701SMaheshRavishankar /// Verify the tile size options are set in a consistent manner.
806740d701SMaheshRavishankar static LogicalResult
816740d701SMaheshRavishankar verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
826740d701SMaheshRavishankar                       const scf::SCFTilingOptions &options) {
836740d701SMaheshRavishankar   // Specifying number of threads is only supported on `scf.forall` op.
846740d701SMaheshRavishankar   if (options.numThreadsComputationFunction &&
856740d701SMaheshRavishankar       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
866740d701SMaheshRavishankar     return rewriter.notifyMatchFailure(
876740d701SMaheshRavishankar         loc, "number of threads can only by specified when loop type is "
886740d701SMaheshRavishankar              "set to use `scf.forall`");
896740d701SMaheshRavishankar   }
906740d701SMaheshRavishankar 
916740d701SMaheshRavishankar   // If specified, check that the interchange vector is a permutation.
926740d701SMaheshRavishankar   if (!options.interchangeVector.empty()) {
936740d701SMaheshRavishankar     if (!isPermutationVector(options.interchangeVector)) {
946740d701SMaheshRavishankar       return rewriter.notifyMatchFailure(
956740d701SMaheshRavishankar           loc, "invalid interchange vector, not a permutation of the entire "
966740d701SMaheshRavishankar                "iteration space");
976740d701SMaheshRavishankar     }
986740d701SMaheshRavishankar   }
996740d701SMaheshRavishankar   return success();
1006740d701SMaheshRavishankar }
1016740d701SMaheshRavishankar 
1026740d701SMaheshRavishankar /// Method to instantiate the tile sizes and/or number of threads specified
1036740d701SMaheshRavishankar /// by the user.
1046740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
1056740d701SMaheshRavishankar getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
1066740d701SMaheshRavishankar                               ArrayRef<Range> iterationDomain,
1076740d701SMaheshRavishankar                               const scf::SCFTilingOptions &options) {
1086740d701SMaheshRavishankar   OpFoldResult zero = rewriter.getIndexAttr(0);
1096740d701SMaheshRavishankar   SmallVector<OpFoldResult> tileSizes, numThreads;
1106740d701SMaheshRavishankar   size_t numLoops = iterationDomain.size();
1116740d701SMaheshRavishankar 
1126740d701SMaheshRavishankar   // Check whether the number of tiles to use is specified.
1136740d701SMaheshRavishankar   if (options.numThreadsComputationFunction) {
1146740d701SMaheshRavishankar     numThreads = options.numThreadsComputationFunction(rewriter, op);
1156740d701SMaheshRavishankar     numThreads.resize(numLoops, zero);
1166740d701SMaheshRavishankar 
1176740d701SMaheshRavishankar     // If the number of tiles is also specified, use that.
1186740d701SMaheshRavishankar     if (options.tileSizeComputationFunction) {
1196740d701SMaheshRavishankar       tileSizes = options.tileSizeComputationFunction(rewriter, op);
1206740d701SMaheshRavishankar       tileSizes.resize(numLoops, zero);
1216740d701SMaheshRavishankar       return {tileSizes, numThreads};
1226740d701SMaheshRavishankar     }
1236740d701SMaheshRavishankar 
1246740d701SMaheshRavishankar     // Compute the tile sizes from the iteration domain and number
1256740d701SMaheshRavishankar     // of tiles as follows
1266740d701SMaheshRavishankar     // - niters = ceilDiv(ub - lb, step)
1276740d701SMaheshRavishankar     // - tileSize = ceilDiv(niters, numThreads)
1286740d701SMaheshRavishankar     AffineExpr s0, s1, s2;
1296740d701SMaheshRavishankar     bindSymbols(rewriter.getContext(), s0, s1, s2);
1306740d701SMaheshRavishankar     // TODO: The step here is assumed to be 1.
1316740d701SMaheshRavishankar     AffineExpr numItersExpr = (s1 - s0);
1326740d701SMaheshRavishankar     AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
1336740d701SMaheshRavishankar     tileSizes.resize(numLoops, zero);
1346740d701SMaheshRavishankar     for (auto [index, range, nt] :
1356740d701SMaheshRavishankar          llvm::enumerate(iterationDomain, numThreads)) {
1366740d701SMaheshRavishankar       if (isConstantIntValue(nt, 0))
1376740d701SMaheshRavishankar         continue;
1386740d701SMaheshRavishankar 
1396740d701SMaheshRavishankar       tileSizes[index] = affine::makeComposedFoldedAffineApply(
1406740d701SMaheshRavishankar           rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
1416740d701SMaheshRavishankar     }
1426740d701SMaheshRavishankar     tileSizes.resize(numLoops, zero);
1436740d701SMaheshRavishankar     return {tileSizes, numThreads};
1446740d701SMaheshRavishankar   }
1456740d701SMaheshRavishankar 
1466740d701SMaheshRavishankar   // Enforce the convention that "tiling by zero"
1476740d701SMaheshRavishankar   // skips tiling a particular dimension. This convention is significantly
1486740d701SMaheshRavishankar   // simpler to handle instead of adjusting affine maps to account for missing
1496740d701SMaheshRavishankar   // dimensions.
1506740d701SMaheshRavishankar   assert(options.tileSizeComputationFunction &&
1516740d701SMaheshRavishankar          "expected tile sizes to be specified");
1526740d701SMaheshRavishankar   tileSizes = options.tileSizeComputationFunction(rewriter, op);
1536740d701SMaheshRavishankar   tileSizes.resize(numLoops, zero);
1546740d701SMaheshRavishankar 
1556740d701SMaheshRavishankar   return {tileSizes, numThreads};
1566740d701SMaheshRavishankar }
1576740d701SMaheshRavishankar 
1586740d701SMaheshRavishankar /// Checks if any of the tiled loops are not parallel.
1596740d701SMaheshRavishankar static void checkSafeToTileToForall(TilingInterface op,
1606740d701SMaheshRavishankar                                     ArrayRef<OpFoldResult> tileSizes,
1616740d701SMaheshRavishankar                                     ArrayRef<OpFoldResult> numThreads) {
1626740d701SMaheshRavishankar   auto iterators = op.getLoopIteratorTypes();
1636740d701SMaheshRavishankar   assert(iterators.size() == tileSizes.size() &&
1646740d701SMaheshRavishankar          "expected as many tile size values as number of loops");
1656740d701SMaheshRavishankar   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
1666740d701SMaheshRavishankar          "when specified, expected number of threads to use for each loop");
1676740d701SMaheshRavishankar 
1686740d701SMaheshRavishankar   for (auto [index, iterator, tileSize] :
1696740d701SMaheshRavishankar        llvm::enumerate(iterators, tileSizes)) {
1706740d701SMaheshRavishankar     // If num threads is specified, check that it is greater than one only for
1716740d701SMaheshRavishankar     // parallel dimensions.
1726740d701SMaheshRavishankar     if (!numThreads.empty()) {
1736740d701SMaheshRavishankar       if (std::optional<int64_t> constNumThreads =
1746740d701SMaheshRavishankar               getConstantIntValue(numThreads[index])) {
1756740d701SMaheshRavishankar         if (constNumThreads.value() > 1 &&
1766740d701SMaheshRavishankar             iterator != utils::IteratorType::parallel) {
1776740d701SMaheshRavishankar           op.emitWarning() << "tiling is not thread safe at axis #" << index;
1786740d701SMaheshRavishankar         }
1796740d701SMaheshRavishankar       }
1806740d701SMaheshRavishankar       continue;
1816740d701SMaheshRavishankar     }
1826740d701SMaheshRavishankar 
1836740d701SMaheshRavishankar     if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
1846740d701SMaheshRavishankar       if (constTileSize.value() > 0 &&
1856740d701SMaheshRavishankar           iterator != utils::IteratorType::parallel) {
1866740d701SMaheshRavishankar         op.emitWarning() << "tiling is not thread safe at axis #" << index;
1876740d701SMaheshRavishankar       }
1886740d701SMaheshRavishankar     }
1896740d701SMaheshRavishankar   }
1906740d701SMaheshRavishankar }
1916740d701SMaheshRavishankar 
1926740d701SMaheshRavishankar /// Check if `stride` evenly divides the trip count `size - offset`.
193954de25aSlorenzo chelini static bool tileDividesIterationDomain(Range loopRange) {
19422426110SRamkumar Ramachandra   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
195954de25aSlorenzo chelini   if (!offsetAsInt)
196954de25aSlorenzo chelini     return false;
19722426110SRamkumar Ramachandra   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
198954de25aSlorenzo chelini   if (!sizeAsInt)
199954de25aSlorenzo chelini     return false;
20022426110SRamkumar Ramachandra   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
201954de25aSlorenzo chelini   if (!strideAsInt)
202954de25aSlorenzo chelini     return false;
203954de25aSlorenzo chelini   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204954de25aSlorenzo chelini }
205954de25aSlorenzo chelini 
2066740d701SMaheshRavishankar /// Returns the bounded tile size given the current `offset`, `loopRange` and
2076740d701SMaheshRavishankar /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
20871cf48a6SHanhan Wang static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
2096740d701SMaheshRavishankar                                        Range loopRange, OpFoldResult offset,
210d871daeaSMaheshRavishankar                                        OpFoldResult tileSize) {
211d2b7a8e8SAdrian Kuegel   std::optional<int64_t> ts = getConstantIntValue(tileSize);
212d2b7a8e8SAdrian Kuegel   if (ts && ts.value() == 1)
213d871daeaSMaheshRavishankar     return tileSize;
21471cf48a6SHanhan Wang 
21571cf48a6SHanhan Wang   if (tileDividesIterationDomain(
21671cf48a6SHanhan Wang           Range{loopRange.offset, loopRange.size, tileSize}))
21771cf48a6SHanhan Wang     return tileSize;
21871cf48a6SHanhan Wang 
21971cf48a6SHanhan Wang   // The tile size to use (to avoid out of bounds access) is  minimum of
22071cf48a6SHanhan Wang   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
22171cf48a6SHanhan Wang   // loop.
22271cf48a6SHanhan Wang   AffineExpr s0, s1, d0;
22371cf48a6SHanhan Wang   bindDims(b.getContext(), d0);
22471cf48a6SHanhan Wang   bindSymbols(b.getContext(), s0, s1);
2256740d701SMaheshRavishankar   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
22671cf48a6SHanhan Wang   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
2274c48f016SMatthias Springer   return affine::makeComposedFoldedAffineMin(
2286740d701SMaheshRavishankar       b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
2296740d701SMaheshRavishankar }
2306740d701SMaheshRavishankar 
2316740d701SMaheshRavishankar /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
2326740d701SMaheshRavishankar /// than `iterationSize`.
2336740d701SMaheshRavishankar static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
2346740d701SMaheshRavishankar                                            OpFoldResult numThreads,
2356740d701SMaheshRavishankar                                            OpFoldResult iterationSize) {
2366740d701SMaheshRavishankar   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
2376740d701SMaheshRavishankar   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
2386740d701SMaheshRavishankar   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
2396740d701SMaheshRavishankar   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
2406740d701SMaheshRavishankar     return false;
2416740d701SMaheshRavishankar   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
2426740d701SMaheshRavishankar }
2436740d701SMaheshRavishankar 
2446740d701SMaheshRavishankar /// Compute the `OpFoldResult`s that represents the multi-dimensional
2456740d701SMaheshRavishankar /// `offset`s and `size`s of the tile of the iteration space that the
2466740d701SMaheshRavishankar /// innermost loop body of the generated tiled loops corresponds to.
2476740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
2486740d701SMaheshRavishankar getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
2496740d701SMaheshRavishankar                       ArrayRef<Range> iterationDomain,
2506740d701SMaheshRavishankar                       ArrayRef<OpFoldResult> tileSizes,
2516740d701SMaheshRavishankar                       ArrayRef<OpFoldResult> numThreads) {
2526740d701SMaheshRavishankar   SmallVector<OpFoldResult> offsets, sizes;
2536740d701SMaheshRavishankar   int materializedLoopNum = 0;
2546740d701SMaheshRavishankar 
2556740d701SMaheshRavishankar   if (!numThreads.empty()) {
2566740d701SMaheshRavishankar     AffineExpr d0, d1, s0, s1;
2576740d701SMaheshRavishankar     AffineExpr offsetExpr, residualTileSizeExpr;
2586740d701SMaheshRavishankar     bindDims(rewriter.getContext(), d0, d1);
2596740d701SMaheshRavishankar     bindSymbols(rewriter.getContext(), s0, s1);
2606740d701SMaheshRavishankar     offsetExpr = d0 + d1 * s0;
2616740d701SMaheshRavishankar     residualTileSizeExpr = s1 - (d0 + d1 * s0);
2626740d701SMaheshRavishankar 
2636740d701SMaheshRavishankar     for (auto [nt, tileSize, loopRange] :
2646740d701SMaheshRavishankar          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
2656740d701SMaheshRavishankar 
2666740d701SMaheshRavishankar       // Non-tiled cases, set the offset and size to the
2676740d701SMaheshRavishankar       // `loopRange.offset/size`.
2686740d701SMaheshRavishankar       if (isConstantIntValue(nt, 0)) {
2696740d701SMaheshRavishankar         offsets.push_back(loopRange.offset);
2706740d701SMaheshRavishankar         sizes.push_back(loopRange.size);
2716740d701SMaheshRavishankar         continue;
2726740d701SMaheshRavishankar       }
2736740d701SMaheshRavishankar 
2746740d701SMaheshRavishankar       Value iv = ivs[materializedLoopNum++];
2756740d701SMaheshRavishankar       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
2766740d701SMaheshRavishankar           rewriter, loc, offsetExpr,
2776740d701SMaheshRavishankar           ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
2786740d701SMaheshRavishankar       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
2796740d701SMaheshRavishankar           rewriter, loc, residualTileSizeExpr,
2806740d701SMaheshRavishankar           {loopRange.offset, nt, tileSize, loopRange.size});
2816740d701SMaheshRavishankar 
2826740d701SMaheshRavishankar       OpFoldResult size = tileSize;
2836740d701SMaheshRavishankar       if (!isConstantIntValue(residualTileSize, 0)) {
2846740d701SMaheshRavishankar         OpFoldResult sizeMinusOffsetPerThread =
2856740d701SMaheshRavishankar             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
2866740d701SMaheshRavishankar                                                   {offset, loopRange.size});
2876740d701SMaheshRavishankar         size = affine::makeComposedFoldedAffineMin(
2886740d701SMaheshRavishankar             rewriter, loc,
2896740d701SMaheshRavishankar             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
2906740d701SMaheshRavishankar             {sizeMinusOffsetPerThread, tileSize});
2916740d701SMaheshRavishankar       }
2926740d701SMaheshRavishankar 
2936740d701SMaheshRavishankar       // Consider the case where the original loop was `[0, 100)`.
2946740d701SMaheshRavishankar       // If number of threads are `7`, the tile size would be computed as
2956740d701SMaheshRavishankar       // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
2966740d701SMaheshRavishankar       // - `offset = 0 + 6 * 15 = 105`
2976740d701SMaheshRavishankar       // - `tileSize = min(15, 100 - 105) = -5`
2986740d701SMaheshRavishankar       // To avoid negative tile sizes, we need to do a further
2996740d701SMaheshRavishankar       // `nonNegativeTileSize = affine.max(0, tileSize)`.
3006740d701SMaheshRavishankar       // This `max` can be avoided if
3016740d701SMaheshRavishankar       //  `offset + tileSize * (numThreads - 1) < (ub - lb)`
3026740d701SMaheshRavishankar       if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
3036740d701SMaheshRavishankar         AffineMap maxMap =
3046740d701SMaheshRavishankar             AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
3056740d701SMaheshRavishankar         size = affine::makeComposedFoldedAffineMax(
3066740d701SMaheshRavishankar             rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
3076740d701SMaheshRavishankar       }
3086740d701SMaheshRavishankar 
3096740d701SMaheshRavishankar       offsets.push_back(offset);
3106740d701SMaheshRavishankar       sizes.push_back(size);
3116740d701SMaheshRavishankar     }
3126740d701SMaheshRavishankar     return {offsets, sizes};
3136740d701SMaheshRavishankar   } else {
3146740d701SMaheshRavishankar     for (auto [tileSize, loopRange] :
3156740d701SMaheshRavishankar          llvm::zip_equal(tileSizes, iterationDomain)) {
3166740d701SMaheshRavishankar 
3176740d701SMaheshRavishankar       // Non-tiled cases, set the offset and size to the
3186740d701SMaheshRavishankar       // `loopRange.offset/size`.
3196740d701SMaheshRavishankar       if (isConstantIntValue(tileSize, 0)) {
3206740d701SMaheshRavishankar         offsets.push_back(loopRange.offset);
3216740d701SMaheshRavishankar         sizes.push_back(loopRange.size);
3226740d701SMaheshRavishankar         continue;
3236740d701SMaheshRavishankar       }
3246740d701SMaheshRavishankar 
3256740d701SMaheshRavishankar       Value iv = ivs[materializedLoopNum++];
3266740d701SMaheshRavishankar       OpFoldResult offset = getAsOpFoldResult(iv);
3276740d701SMaheshRavishankar       offsets.push_back(offset);
3286740d701SMaheshRavishankar       OpFoldResult size =
3296740d701SMaheshRavishankar           getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
3306740d701SMaheshRavishankar       sizes.push_back(size);
3316740d701SMaheshRavishankar     }
3326740d701SMaheshRavishankar     return {offsets, sizes};
3336740d701SMaheshRavishankar   }
3346740d701SMaheshRavishankar }
3356740d701SMaheshRavishankar 
3366740d701SMaheshRavishankar /// Function to return the bounds of the loops to be generated.
3376740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
3386740d701SMaheshRavishankar                   SmallVector<OpFoldResult>>
3396740d701SMaheshRavishankar getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
3406740d701SMaheshRavishankar               ArrayRef<OpFoldResult> tileSizes) {
3416740d701SMaheshRavishankar   SmallVector<OpFoldResult> lbs, ubs, steps;
3426740d701SMaheshRavishankar   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
3436740d701SMaheshRavishankar     // No loop if the tile size is 0.
3446740d701SMaheshRavishankar     if (isConstantIntValue(tileSize, 0))
3456740d701SMaheshRavishankar       continue;
3466740d701SMaheshRavishankar     lbs.push_back(loopRange.offset);
3476740d701SMaheshRavishankar     ubs.push_back(loopRange.size);
3486740d701SMaheshRavishankar     steps.push_back(tileSize);
3496740d701SMaheshRavishankar   }
3506740d701SMaheshRavishankar   return {lbs, ubs, steps};
35171cf48a6SHanhan Wang }
35271cf48a6SHanhan Wang 
/// Callback type that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`. The arguments are:
/// - `rewriter` the rewriter to create operations with.
/// - `loc` the location to use for created operations.
/// - `ivs` induction variables of the loops.
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
/// - `tiledValues` (output) the tiled values to return. Must be of same size
///   as `newBbArgs`, each element of this array is inserted into the
///   corresponding element in `newBbArgs`.
/// - `resultOffsets` (output) is of the same size as `tiledValues` and
///   represents the offsets to use when inserting corresponding element from
///   `tiledValues` into the element from `newBbArgs`.
/// - `resultSizes` (output) is of the same size as `tiledValues` and
///   represents the size of the corresponding element from `tiledValues`
///   inserted into the element from `newBbArgs`.
/// In case the callback needs to return `failure()` it is expected to clean
/// up any operations it inserted.
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
37376ead96cSMaheshRavishankar 
374d871daeaSMaheshRavishankar /// Clones the operation and updates the destination if the operation
375d871daeaSMaheshRavishankar /// implements the `DestinationStyleOpInterface`.
376d871daeaSMaheshRavishankar static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
377d871daeaSMaheshRavishankar                                                   Operation *op,
378d871daeaSMaheshRavishankar                                                   ValueRange newDestArgs) {
379d871daeaSMaheshRavishankar   Operation *clonedOp = rewriter.clone(*op);
3804a020018SMaheshRavishankar   if (newDestArgs.empty())
3814a020018SMaheshRavishankar     return clonedOp;
3824a020018SMaheshRavishankar   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383d871daeaSMaheshRavishankar     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384d871daeaSMaheshRavishankar   return clonedOp;
385d871daeaSMaheshRavishankar }
386d871daeaSMaheshRavishankar 
38776ead96cSMaheshRavishankar /// Generate the tile-loop nest using `scf.for` operation.
388cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389170a25a7SMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
39076ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop.
39176ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
39276ead96cSMaheshRavishankar /// most
39376ead96cSMaheshRavishankar ///    loop.
39476ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are
39576ead96cSMaheshRavishankar ///    populated.
39676ead96cSMaheshRavishankar static LogicalResult generateLoopNestUsingForOp(
39776ead96cSMaheshRavishankar     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
39876ead96cSMaheshRavishankar     ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
39976ead96cSMaheshRavishankar     YieldTiledValuesFn yieldTiledValuesFn,
40076ead96cSMaheshRavishankar     SmallVector<LoopLikeOpInterface> &loops) {
40176ead96cSMaheshRavishankar   assert(!loopRanges.empty() && "unexpected empty loop ranges");
402170a25a7SMaheshRavishankar   assert(loopRanges.size() == tileSizes.size() &&
403cf6a7c19SMahesh Ravishankar          "expected as many tile sizes as loop ranges");
40476ead96cSMaheshRavishankar   OpBuilder::InsertionGuard guard(rewriter);
4056740d701SMaheshRavishankar 
4066740d701SMaheshRavishankar   SmallVector<OpFoldResult> lbs, ubs, steps;
4076740d701SMaheshRavishankar   std::tie(lbs, ubs, steps) =
4086740d701SMaheshRavishankar       getLoopBounds(rewriter, loc, loopRanges, tileSizes);
4096740d701SMaheshRavishankar   SmallVector<Value> lbVals =
4106740d701SMaheshRavishankar       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
4116740d701SMaheshRavishankar   SmallVector<Value> ubVals =
4126740d701SMaheshRavishankar       getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
4136740d701SMaheshRavishankar   SmallVector<Value> stepVals =
4146740d701SMaheshRavishankar       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
4156740d701SMaheshRavishankar 
41676ead96cSMaheshRavishankar   SmallVector<Value> ivs;
4176740d701SMaheshRavishankar   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
41876ead96cSMaheshRavishankar     auto loop =
41976ead96cSMaheshRavishankar         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
42076ead96cSMaheshRavishankar                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
42176ead96cSMaheshRavishankar                                        Value iv, ValueRange /*iterArgs*/) {});
422cf6a7c19SMahesh Ravishankar     loops.push_back(loop);
42376ead96cSMaheshRavishankar     ivs.push_back(loop.getInductionVar());
42476ead96cSMaheshRavishankar     rewriter.setInsertionPointToEnd(loop.getBody());
4254a020018SMaheshRavishankar     destinationTensors = loop.getRegionIterArgs();
4264a020018SMaheshRavishankar   }
4274a020018SMaheshRavishankar 
42876ead96cSMaheshRavishankar   SmallVector<Value> tiledResults;
42976ead96cSMaheshRavishankar   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
43076ead96cSMaheshRavishankar   if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
43176ead96cSMaheshRavishankar                                 tiledResults, resultOffsets, resultSizes))) {
43276ead96cSMaheshRavishankar     return rewriter.notifyMatchFailure(
43376ead96cSMaheshRavishankar         loc, "failed to generate inner tile loop body");
43476ead96cSMaheshRavishankar   }
43576ead96cSMaheshRavishankar   if (loops.empty())
43676ead96cSMaheshRavishankar     return success();
43776ead96cSMaheshRavishankar 
4389329b20dSKunwar Grover   assert(tiledResults.size() == destinationTensors.size() &&
4399329b20dSKunwar Grover          "Number of results of body should be equal to number of iter args");
4409329b20dSKunwar Grover 
44176ead96cSMaheshRavishankar   // 6. Yield all the results of the tiled operation.
44276ead96cSMaheshRavishankar   SmallVector<Value> yieldedValues;
44376ead96cSMaheshRavishankar   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
44476ead96cSMaheshRavishankar        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
44576ead96cSMaheshRavishankar                        resultSizes)) {
44676ead96cSMaheshRavishankar     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
44776ead96cSMaheshRavishankar                                            rewriter.getIndexAttr(1));
44876ead96cSMaheshRavishankar     auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
44976ead96cSMaheshRavishankar         loc, tiledValue, destinationTensor, resultOffset, resultSize,
45076ead96cSMaheshRavishankar         resultStride);
45176ead96cSMaheshRavishankar     yieldedValues.push_back(insertSlice);
45276ead96cSMaheshRavishankar   }
45376ead96cSMaheshRavishankar   rewriter.create<scf::YieldOp>(loc, yieldedValues);
45476ead96cSMaheshRavishankar 
4554a020018SMaheshRavishankar   // Add the scf.yield operations for all the outer loops.
4564a020018SMaheshRavishankar   for (auto [outerLoop, innerLoop] :
4574a020018SMaheshRavishankar        llvm::zip_equal(MutableArrayRef(loops).drop_back(),
4584a020018SMaheshRavishankar                        MutableArrayRef(loops).drop_front())) {
45976ead96cSMaheshRavishankar     rewriter.setInsertionPointToEnd(
46076ead96cSMaheshRavishankar         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
46176ead96cSMaheshRavishankar     rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
4624a020018SMaheshRavishankar   }
46376ead96cSMaheshRavishankar   return success();
464cf6a7c19SMahesh Ravishankar }
46576ead96cSMaheshRavishankar 
46676ead96cSMaheshRavishankar /// Generate the tile-loop nest using `scf.forall` operation.
46776ead96cSMaheshRavishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
46876ead96cSMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
46976ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop.
47076ead96cSMaheshRavishankar /// - `mappingVector` is the mapping attributes to use for loop construction.
47176ead96cSMaheshRavishankar ///   Can be empty.
47276ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
47376ead96cSMaheshRavishankar /// most
47476ead96cSMaheshRavishankar ///    loop.
47576ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are
47676ead96cSMaheshRavishankar ///    populated.
47776ead96cSMaheshRavishankar static LogicalResult generateLoopNestUsingForallOp(
47876ead96cSMaheshRavishankar     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
4796740d701SMaheshRavishankar     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
4806740d701SMaheshRavishankar     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
4816740d701SMaheshRavishankar     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
48276ead96cSMaheshRavishankar   assert(!loopRanges.empty() && "unexpected empty loop ranges");
48376ead96cSMaheshRavishankar   assert(loopRanges.size() == tileSizes.size() &&
48476ead96cSMaheshRavishankar          "expected as many tile sizes as loop ranges");
48576ead96cSMaheshRavishankar   OpBuilder::InsertionGuard guard(rewriter);
48676ead96cSMaheshRavishankar   SmallVector<OpFoldResult> offsets(loopRanges.size()),
48776ead96cSMaheshRavishankar       sizes(loopRanges.size());
48876ead96cSMaheshRavishankar 
48976ead96cSMaheshRavishankar   std::optional<ArrayAttr> mappingAttr;
49076ead96cSMaheshRavishankar   if (!mappingVector.empty())
49176ead96cSMaheshRavishankar     mappingAttr = rewriter.getArrayAttr(mappingVector);
49276ead96cSMaheshRavishankar 
4936740d701SMaheshRavishankar   scf::ForallOp forallOp;
4946740d701SMaheshRavishankar   bool useNumThreads = !numThreads.empty();
4956740d701SMaheshRavishankar 
4966740d701SMaheshRavishankar   if (useNumThreads) {
4976740d701SMaheshRavishankar     // Prune the zero numthreads.
4986740d701SMaheshRavishankar     SmallVector<OpFoldResult> nonZeroNumThreads;
4996740d701SMaheshRavishankar     for (auto nt : numThreads) {
5006740d701SMaheshRavishankar       if (isConstantIntValue(nt, 0))
5016740d701SMaheshRavishankar         continue;
5026740d701SMaheshRavishankar       nonZeroNumThreads.push_back(nt);
5036740d701SMaheshRavishankar     }
5046740d701SMaheshRavishankar     forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
5056740d701SMaheshRavishankar                                               destinationTensors, mappingAttr);
5066740d701SMaheshRavishankar   } else {
5076740d701SMaheshRavishankar     SmallVector<OpFoldResult> lbs, ubs, steps;
5086740d701SMaheshRavishankar     std::tie(lbs, ubs, steps) =
5096740d701SMaheshRavishankar         getLoopBounds(rewriter, loc, loopRanges, tileSizes);
5106740d701SMaheshRavishankar     forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
5116740d701SMaheshRavishankar                                               destinationTensors, mappingAttr);
5126740d701SMaheshRavishankar   }
51376ead96cSMaheshRavishankar   loops.push_back(forallOp);
51476ead96cSMaheshRavishankar 
51576ead96cSMaheshRavishankar   rewriter.setInsertionPoint(forallOp.getTerminator());
51676ead96cSMaheshRavishankar   destinationTensors = forallOp.getRegionOutArgs();
51776ead96cSMaheshRavishankar 
51876ead96cSMaheshRavishankar   SmallVector<Value> tiledResults;
51976ead96cSMaheshRavishankar   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
52076ead96cSMaheshRavishankar   if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
52176ead96cSMaheshRavishankar                          destinationTensors, tiledResults, resultOffsets,
52276ead96cSMaheshRavishankar                          resultSizes)))
52376ead96cSMaheshRavishankar     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
52476ead96cSMaheshRavishankar 
52576ead96cSMaheshRavishankar   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
52676ead96cSMaheshRavishankar   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
52776ead96cSMaheshRavishankar        llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
52876ead96cSMaheshRavishankar                        resultSizes)) {
52976ead96cSMaheshRavishankar     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
53076ead96cSMaheshRavishankar                                            rewriter.getIndexAttr(1));
53176ead96cSMaheshRavishankar 
53276ead96cSMaheshRavishankar     rewriter.create<tensor::ParallelInsertSliceOp>(
53376ead96cSMaheshRavishankar         loc, tiledValue, destinationTensor, resultOffset, resultSize,
53476ead96cSMaheshRavishankar         resultStride);
53576ead96cSMaheshRavishankar   }
53676ead96cSMaheshRavishankar   return success();
53776ead96cSMaheshRavishankar }
53876ead96cSMaheshRavishankar 
53976ead96cSMaheshRavishankar /// Generate the tile-loop nest using the loop construct specifed in `options`.
54076ead96cSMaheshRavishankar /// - `options`: Tiling options specified.
54176ead96cSMaheshRavishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
54276ead96cSMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
54376ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop.
54476ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
54576ead96cSMaheshRavishankar /// most
54676ead96cSMaheshRavishankar ///    loop.
54776ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are
54876ead96cSMaheshRavishankar ///    populated.
5496740d701SMaheshRavishankar static LogicalResult generateLoopNest(
5506740d701SMaheshRavishankar     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
5516740d701SMaheshRavishankar     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
5526740d701SMaheshRavishankar     ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
5536740d701SMaheshRavishankar     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
55476ead96cSMaheshRavishankar   // If the tile sizes are all zero, no loops are generated. Just call the
55576ead96cSMaheshRavishankar   // callback function to handle untiled case.
55676ead96cSMaheshRavishankar   if (llvm::all_of(tileSizes, isZeroIndex)) {
55776ead96cSMaheshRavishankar     SmallVector<Value> tiledResults;
55876ead96cSMaheshRavishankar     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
55976ead96cSMaheshRavishankar     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
56076ead96cSMaheshRavishankar                        tiledResults, resultOffsets, resultSizes);
56176ead96cSMaheshRavishankar   }
56276ead96cSMaheshRavishankar   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
56376ead96cSMaheshRavishankar     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
56476ead96cSMaheshRavishankar                                       destinationTensors, tiledBodyFn, loops);
56576ead96cSMaheshRavishankar   }
56676ead96cSMaheshRavishankar   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
56776ead96cSMaheshRavishankar     return generateLoopNestUsingForallOp(
5686740d701SMaheshRavishankar         rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
56976ead96cSMaheshRavishankar         destinationTensors, tiledBodyFn, loops);
57076ead96cSMaheshRavishankar   }
57176ead96cSMaheshRavishankar   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
57276ead96cSMaheshRavishankar }
57376ead96cSMaheshRavishankar 
5744b563458SKunwar Grover static FailureOr<SmallVector<Value>>
5754b563458SKunwar Grover createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
5764b563458SKunwar Grover                               ArrayRef<OpFoldResult> tileSizes,
5774b563458SKunwar Grover                               const scf::SCFTilingOptions &options) {
5784b563458SKunwar Grover   SmallVector<Value> initTensors;
5794b563458SKunwar Grover   Location loc = op->getLoc();
5804b563458SKunwar Grover   switch (options.reductionStrategy) {
5814b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
5824b563458SKunwar Grover     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
5834b563458SKunwar Grover       return failure();
5844b563458SKunwar Grover     return initTensors;
5854b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::
5864b563458SKunwar Grover       PartialReductionOuterReduction: {
5874b563458SKunwar Grover     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
5884b563458SKunwar Grover     if (!redOp) {
5894b563458SKunwar Grover       return rewriter.notifyMatchFailure(
5904b563458SKunwar Grover           op, "PartialReductionOuterReduction tiling strategy is only supported"
5914b563458SKunwar Grover               "for operations implementing PartialReductionOpInterface");
5924b563458SKunwar Grover     }
5934b563458SKunwar Grover     // Get reduction dimensions.
5944b563458SKunwar Grover     // TODO: PartialReductionOpInterface should really query TilingInterface
5954b563458SKunwar Grover     // itself and find reduction dimensions.
5964b563458SKunwar Grover     SmallVector<int> reductionDims;
5974b563458SKunwar Grover     for (auto [idx, iteratorType] :
5984b563458SKunwar Grover          llvm::enumerate(op.getLoopIteratorTypes())) {
5994b563458SKunwar Grover       if (iteratorType == utils::IteratorType::reduction)
6004b563458SKunwar Grover         reductionDims.push_back(idx);
6014b563458SKunwar Grover     }
6024b563458SKunwar Grover     return redOp.generateInitialTensorForPartialReduction(
6034b563458SKunwar Grover         rewriter, loc, tileSizes, reductionDims);
6044b563458SKunwar Grover   }
6054b563458SKunwar Grover   default:
6064b563458SKunwar Grover     return rewriter.notifyMatchFailure(op,
6074b563458SKunwar Grover                                        "unhandled reduction tiling strategy");
6084b563458SKunwar Grover   }
6094b563458SKunwar Grover }
6104b563458SKunwar Grover 
6114b563458SKunwar Grover static FailureOr<TilingResult>
6124b563458SKunwar Grover getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
6134b563458SKunwar Grover                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
6144b563458SKunwar Grover                        ArrayRef<OpFoldResult> sizes,
6154b563458SKunwar Grover                        const scf::SCFTilingOptions &options) {
6164b563458SKunwar Grover   switch (options.reductionStrategy) {
6174b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
6184b563458SKunwar Grover     return op.getTiledImplementation(rewriter, offsets, sizes);
6194b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::
6204b563458SKunwar Grover       PartialReductionOuterReduction: {
6214b563458SKunwar Grover     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
6224b563458SKunwar Grover     if (!redOp) {
6234b563458SKunwar Grover       return rewriter.notifyMatchFailure(
6244b563458SKunwar Grover           op, "PartialReductionOuterReduction tiling strategy is only "
6254b563458SKunwar Grover               "supported for operations "
6264b563458SKunwar Grover               "implementing PartialReductionOpInterface");
6274b563458SKunwar Grover     }
6284b563458SKunwar Grover     // Get reduction dimensions.
6294b563458SKunwar Grover     // TODO: PartialReductionOpInterface should really query TilingInterface
6304b563458SKunwar Grover     // itself and find reduction dimensions.
6314b563458SKunwar Grover     SmallVector<int> reductionDims;
6324b563458SKunwar Grover     for (auto [idx, iteratorType] :
6334b563458SKunwar Grover          llvm::enumerate(op.getLoopIteratorTypes())) {
6344b563458SKunwar Grover       if (iteratorType == utils::IteratorType::reduction)
6354b563458SKunwar Grover         reductionDims.push_back(idx);
6364b563458SKunwar Grover     }
6374b563458SKunwar Grover     return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
6384b563458SKunwar Grover                                         offsets, sizes, reductionDims);
6394b563458SKunwar Grover   }
6404b563458SKunwar Grover   default:
6414b563458SKunwar Grover     return rewriter.notifyMatchFailure(op,
6424b563458SKunwar Grover                                        "unhandled reduction tiling strategy");
6434b563458SKunwar Grover   }
6444b563458SKunwar Grover }
6454b563458SKunwar Grover 
6464b563458SKunwar Grover static LogicalResult
6474b563458SKunwar Grover getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
6484b563458SKunwar Grover                       TilingInterface op, ArrayRef<OpFoldResult> offsets,
6494b563458SKunwar Grover                       ArrayRef<OpFoldResult> sizes,
6504b563458SKunwar Grover                       SmallVector<OpFoldResult> &resultOffset,
6514b563458SKunwar Grover                       SmallVector<OpFoldResult> &resultSize,
6524b563458SKunwar Grover                       const scf::SCFTilingOptions &options) {
6534b563458SKunwar Grover 
6544b563458SKunwar Grover   switch (options.reductionStrategy) {
6554b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
6564b563458SKunwar Grover     return op.getResultTilePosition(rewriter, index, offsets, sizes,
6574b563458SKunwar Grover                                     resultOffset, resultSize);
6584b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::
6594b563458SKunwar Grover       PartialReductionOuterReduction: {
660*91bbebc7SKunwar Grover     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661*91bbebc7SKunwar Grover     if (!redOp) {
662*91bbebc7SKunwar Grover       return rewriter.notifyMatchFailure(
663*91bbebc7SKunwar Grover           op, "PartialReductionOuterReduction tiling strategy is only supported"
664*91bbebc7SKunwar Grover               "for operations implementing PartialReductionOpInterface");
6654b563458SKunwar Grover     }
666*91bbebc7SKunwar Grover     // Get reduction dimensions.
667*91bbebc7SKunwar Grover     // TODO: PartialReductionOpInterface should really query TilingInterface
668*91bbebc7SKunwar Grover     // itself and find reduction dimensions.
669*91bbebc7SKunwar Grover     SmallVector<int> reductionDims;
670*91bbebc7SKunwar Grover     for (auto [idx, iteratorType] :
671*91bbebc7SKunwar Grover          llvm::enumerate(op.getLoopIteratorTypes())) {
672*91bbebc7SKunwar Grover       if (iteratorType == utils::IteratorType::reduction)
673*91bbebc7SKunwar Grover         reductionDims.push_back(idx);
674*91bbebc7SKunwar Grover     }
675*91bbebc7SKunwar Grover     return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
676*91bbebc7SKunwar Grover                                               resultOffset, resultSize,
677*91bbebc7SKunwar Grover                                               reductionDims);
678*91bbebc7SKunwar Grover   }
6794b563458SKunwar Grover   default:
6804b563458SKunwar Grover     return rewriter.notifyMatchFailure(op,
6814b563458SKunwar Grover                                        "unhandled reduction tiling strategy");
6824b563458SKunwar Grover   }
6834b563458SKunwar Grover }
6844b563458SKunwar Grover 
6854b563458SKunwar Grover static FailureOr<MergeResult>
6864b563458SKunwar Grover mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
6874b563458SKunwar Grover                    ValueRange partialResults,
6884b563458SKunwar Grover                    const scf::SCFTilingOptions &options) {
6894b563458SKunwar Grover   switch (options.reductionStrategy) {
6904b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
6914b563458SKunwar Grover     // No need to merge results for reduction tiling strategy.
6924b563458SKunwar Grover     return MergeResult{{}, partialResults};
6934b563458SKunwar Grover   case scf::SCFTilingOptions::ReductionTilingStrategy::
6944b563458SKunwar Grover       PartialReductionOuterReduction: {
6954b563458SKunwar Grover     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
6964b563458SKunwar Grover     if (!redOp) {
6974b563458SKunwar Grover       return rewriter.notifyMatchFailure(
6984b563458SKunwar Grover           op, "PartialReductionOuterReduction tiling strategy is only "
6994b563458SKunwar Grover               "supported for operations "
7004b563458SKunwar Grover               "implementing PartialReductionOpInterface");
7014b563458SKunwar Grover     }
7024b563458SKunwar Grover     // Get reduction dimensions.
7034b563458SKunwar Grover     // TODO: PartialReductionOpInterface should really query TilingInterface
7044b563458SKunwar Grover     // itself and find reduction dimensions.
7054b563458SKunwar Grover     SmallVector<int> reductionDims;
7064b563458SKunwar Grover     for (auto [idx, iteratorType] :
7074b563458SKunwar Grover          llvm::enumerate(op.getLoopIteratorTypes())) {
7084b563458SKunwar Grover       if (iteratorType == utils::IteratorType::reduction)
7094b563458SKunwar Grover         reductionDims.push_back(idx);
7104b563458SKunwar Grover     }
7114b563458SKunwar Grover     return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
7124b563458SKunwar Grover                                  reductionDims);
7134b563458SKunwar Grover   }
7144b563458SKunwar Grover   default:
7154b563458SKunwar Grover     return rewriter.notifyMatchFailure(op,
7164b563458SKunwar Grover                                        "unhandled reduction tiling strategy");
7174b563458SKunwar Grover   }
7184b563458SKunwar Grover }
7194b563458SKunwar Grover 
72076ead96cSMaheshRavishankar /// Append the specified additional `newInitOperands` operands to the
72176ead96cSMaheshRavishankar /// loops existing `init` operands (or similar), and replace `loopOp` with
72276ead96cSMaheshRavishankar /// the new loop that has the additional init operands. The loop body of
72376ead96cSMaheshRavishankar /// this loop is moved over to the new loop. `yieldTiledValuesFn`
72476ead96cSMaheshRavishankar /// is called to get the new tiled values returned, and the offset
72576ead96cSMaheshRavishankar /// and sizes at which the tiled value is inserted into the
72676ead96cSMaheshRavishankar /// new region iter_args that correspond to the newly added init operands.
72776ead96cSMaheshRavishankar template <typename LoopType>
72876ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface>
72976ead96cSMaheshRavishankar yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
73076ead96cSMaheshRavishankar                                ValueRange newInitOperands,
73176ead96cSMaheshRavishankar                                YieldTiledValuesFn yieldTiledValuesFn) {
73276ead96cSMaheshRavishankar   return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
73376ead96cSMaheshRavishankar }
73476ead96cSMaheshRavishankar 
73576ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
73676ead96cSMaheshRavishankar template <>
73776ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
73876ead96cSMaheshRavishankar     scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
73976ead96cSMaheshRavishankar     YieldTiledValuesFn yieldTiledValuesFn) {
74076ead96cSMaheshRavishankar   OpBuilder::InsertionGuard g(rewriter);
74176ead96cSMaheshRavishankar   Location loc = loopOp.getLoc();
74276ead96cSMaheshRavishankar   rewriter.setInsertionPoint(loopOp);
74376ead96cSMaheshRavishankar 
74476ead96cSMaheshRavishankar   auto inits = llvm::to_vector(loopOp.getInitArgs());
74576ead96cSMaheshRavishankar   inits.append(newInitOperands.begin(), newInitOperands.end());
74676ead96cSMaheshRavishankar   auto newLoop = rewriter.create<scf::ForOp>(
74776ead96cSMaheshRavishankar       loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
74876ead96cSMaheshRavishankar       inits, [](OpBuilder &, Location, Value, ValueRange) {});
74976ead96cSMaheshRavishankar 
75076ead96cSMaheshRavishankar   // Move the loop body to the new op.
75176ead96cSMaheshRavishankar   Block *loopBody = loopOp.getBody();
75276ead96cSMaheshRavishankar   Block *newLoopBody = newLoop.getBody();
75376ead96cSMaheshRavishankar   rewriter.mergeBlocks(
75476ead96cSMaheshRavishankar       loopBody, newLoopBody,
75576ead96cSMaheshRavishankar       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
75676ead96cSMaheshRavishankar 
75776ead96cSMaheshRavishankar   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
75876ead96cSMaheshRavishankar   rewriter.setInsertionPoint(yieldOp);
75976ead96cSMaheshRavishankar 
76076ead96cSMaheshRavishankar   SmallVector<Value> tiledValues;
76176ead96cSMaheshRavishankar   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
76276ead96cSMaheshRavishankar   ValueRange newRegionIterArgs =
76376ead96cSMaheshRavishankar       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
76476ead96cSMaheshRavishankar   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
76576ead96cSMaheshRavishankar                                 newRegionIterArgs, tiledValues, resultOffsets,
76676ead96cSMaheshRavishankar                                 resultSizes))) {
76776ead96cSMaheshRavishankar     rewriter.eraseOp(newLoop);
76876ead96cSMaheshRavishankar     return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
76976ead96cSMaheshRavishankar   }
77076ead96cSMaheshRavishankar 
77176ead96cSMaheshRavishankar   SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
77276ead96cSMaheshRavishankar   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
77376ead96cSMaheshRavishankar        llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
77476ead96cSMaheshRavishankar                        resultSizes)) {
77576ead96cSMaheshRavishankar     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
77676ead96cSMaheshRavishankar                                            rewriter.getIndexAttr(1));
77776ead96cSMaheshRavishankar     Value insert = rewriter.create<tensor::InsertSliceOp>(
77876ead96cSMaheshRavishankar         yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
77976ead96cSMaheshRavishankar         resultStride);
78076ead96cSMaheshRavishankar     newYieldValues.push_back(insert);
78176ead96cSMaheshRavishankar   }
78276ead96cSMaheshRavishankar 
78376ead96cSMaheshRavishankar   rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
78476ead96cSMaheshRavishankar   rewriter.replaceOp(loopOp,
78576ead96cSMaheshRavishankar                      newLoop->getResults().take_front(loopOp.getNumResults()));
78676ead96cSMaheshRavishankar   return cast<LoopLikeOpInterface>(newLoop.getOperation());
78776ead96cSMaheshRavishankar }
78876ead96cSMaheshRavishankar 
78976ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
79076ead96cSMaheshRavishankar template <>
79176ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
79276ead96cSMaheshRavishankar     scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
79376ead96cSMaheshRavishankar     YieldTiledValuesFn yieldTiledValuesFn) {
79476ead96cSMaheshRavishankar   OpBuilder::InsertionGuard g(rewriter);
79576ead96cSMaheshRavishankar   Location loc = loopOp.getLoc();
79676ead96cSMaheshRavishankar   rewriter.setInsertionPoint(loopOp);
79776ead96cSMaheshRavishankar   auto inits = llvm::to_vector(loopOp.getOutputs());
79876ead96cSMaheshRavishankar   inits.append(newInitOperands.begin(), newInitOperands.end());
79976ead96cSMaheshRavishankar   auto newLoop = rewriter.create<scf::ForallOp>(
80076ead96cSMaheshRavishankar       loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
80176ead96cSMaheshRavishankar       loopOp.getMixedStep(), inits, loopOp.getMapping(),
80276ead96cSMaheshRavishankar       [](OpBuilder &, Location, ValueRange) {});
80376ead96cSMaheshRavishankar 
80476ead96cSMaheshRavishankar   // Move the region of the current block to the newly created op.
80576ead96cSMaheshRavishankar   Block *loopBody = loopOp.getBody();
80676ead96cSMaheshRavishankar   Block *newLoopBody = newLoop.getBody();
80776ead96cSMaheshRavishankar   rewriter.mergeBlocks(
80876ead96cSMaheshRavishankar       loopBody, newLoopBody,
80976ead96cSMaheshRavishankar       newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
81076ead96cSMaheshRavishankar 
81176ead96cSMaheshRavishankar   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
81276ead96cSMaheshRavishankar   rewriter.setInsertionPoint(terminator);
81376ead96cSMaheshRavishankar   SmallVector<Value> tiledValues;
81476ead96cSMaheshRavishankar   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
81576ead96cSMaheshRavishankar   ValueRange regionIterArgs =
81676ead96cSMaheshRavishankar       newLoop.getRegionIterArgs().take_back(newInitOperands.size());
81776ead96cSMaheshRavishankar   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
81876ead96cSMaheshRavishankar                                 regionIterArgs, tiledValues, resultOffsets,
81976ead96cSMaheshRavishankar                                 resultSizes))) {
82076ead96cSMaheshRavishankar     rewriter.eraseOp(newLoop);
82176ead96cSMaheshRavishankar     return rewriter.notifyMatchFailure(loopOp,
82276ead96cSMaheshRavishankar                                        "failed to get yielded tiled values");
82376ead96cSMaheshRavishankar   }
82476ead96cSMaheshRavishankar 
82576ead96cSMaheshRavishankar   // Update the terminator.
82676ead96cSMaheshRavishankar   rewriter.setInsertionPointToEnd(terminator.getBody());
82776ead96cSMaheshRavishankar 
82876ead96cSMaheshRavishankar   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
82976ead96cSMaheshRavishankar            tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
83076ead96cSMaheshRavishankar     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
83176ead96cSMaheshRavishankar                                            rewriter.getIndexAttr(1));
83276ead96cSMaheshRavishankar     rewriter.create<tensor::ParallelInsertSliceOp>(
83376ead96cSMaheshRavishankar         terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
83476ead96cSMaheshRavishankar         resultStride);
83576ead96cSMaheshRavishankar   }
83676ead96cSMaheshRavishankar 
83776ead96cSMaheshRavishankar   rewriter.replaceOp(loopOp,
83876ead96cSMaheshRavishankar                      newLoop->getResults().take_front(loopOp.getNumResults()));
83976ead96cSMaheshRavishankar   return cast<LoopLikeOpInterface>(newLoop.getOperation());
84076ead96cSMaheshRavishankar }
84176ead96cSMaheshRavishankar 
84276ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for
84376ead96cSMaheshRavishankar /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
84476ead96cSMaheshRavishankar /// supported loop type.
84576ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
84676ead96cSMaheshRavishankar     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
84776ead96cSMaheshRavishankar     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
84876ead96cSMaheshRavishankar   return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
84976ead96cSMaheshRavishankar              loopLikeOp.getOperation())
85076ead96cSMaheshRavishankar       .Case<scf::ForOp, scf::ForallOp>(
85176ead96cSMaheshRavishankar           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
85276ead96cSMaheshRavishankar             return yieldTiledValuesAndReplaceLoop(
85376ead96cSMaheshRavishankar                 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
85476ead96cSMaheshRavishankar           })
85576ead96cSMaheshRavishankar       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
85676ead96cSMaheshRavishankar         return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
85776ead96cSMaheshRavishankar       });
858cf6a7c19SMahesh Ravishankar }
859cf6a7c19SMahesh Ravishankar 
/// Method to add new init values to a loop nest. Updates `loops` in-place
/// with new loops that use the `newInitValues`. The outer loops are updated
/// to yield the new result values of the inner loop. For the innermost loop,
/// the callback `getNewTiledYieldsFn` is invoked to get the additional values
/// to yield from the innermost loop.
86576ead96cSMaheshRavishankar static LogicalResult addInitOperandsToLoopNest(
86676ead96cSMaheshRavishankar     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
86776ead96cSMaheshRavishankar     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
8684a020018SMaheshRavishankar   SmallVector<scf::ForOp> newLoops;
8694a020018SMaheshRavishankar   if (loops.empty())
87076ead96cSMaheshRavishankar     return success();
8714a020018SMaheshRavishankar   OpBuilder::InsertionGuard g(rewriter);
8724a020018SMaheshRavishankar   rewriter.setInsertionPoint(loops.front());
87376ead96cSMaheshRavishankar 
87476ead96cSMaheshRavishankar   SmallVector<Value> ivs;
87576ead96cSMaheshRavishankar   for (auto &loop : loops.drop_back()) {
8764a020018SMaheshRavishankar     rewriter.setInsertionPoint(loop);
87797f91982SMahesh Ravishankar 
87876ead96cSMaheshRavishankar     // if loops.size() > 1 we assume that scf.for is used for the loops.
87976ead96cSMaheshRavishankar     auto forLoop = cast<scf::ForOp>(loop.getOperation());
88076ead96cSMaheshRavishankar 
8814a020018SMaheshRavishankar     // Create a new loop with the new init values for this loop.
88276ead96cSMaheshRavishankar     SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
8834a020018SMaheshRavishankar     newInits.append(newInitValues.begin(), newInitValues.end());
8844a020018SMaheshRavishankar     auto newLoop = rewriter.create<scf::ForOp>(
88576ead96cSMaheshRavishankar         forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
88676ead96cSMaheshRavishankar         forLoop.getStep(), newInits,
8874a020018SMaheshRavishankar         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
8884a020018SMaheshRavishankar 
8894a020018SMaheshRavishankar     // Merge the body of the new loop with the body of the old loops.
8904a020018SMaheshRavishankar     SmallVector<Value> sourceBlockArgs;
8914a020018SMaheshRavishankar     sourceBlockArgs.push_back(newLoop.getInductionVar());
8924a020018SMaheshRavishankar     auto newRegionIterArgs = newLoop.getRegionIterArgs();
8934a020018SMaheshRavishankar     sourceBlockArgs.append(
8944a020018SMaheshRavishankar         newRegionIterArgs.begin(),
89576ead96cSMaheshRavishankar         std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
89676ead96cSMaheshRavishankar     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
89776ead96cSMaheshRavishankar     rewriter.replaceOp(
89876ead96cSMaheshRavishankar         forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
8994a020018SMaheshRavishankar     loop = newLoop;
90076ead96cSMaheshRavishankar     ivs.push_back(newLoop.getInductionVar());
9014a020018SMaheshRavishankar     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
90297f91982SMahesh Ravishankar   }
90397f91982SMahesh Ravishankar 
9044a020018SMaheshRavishankar   // Update the loop body of the innermost loop to get new yield values.
90576ead96cSMaheshRavishankar   LoopLikeOpInterface innerMostLoop = loops.back();
90676ead96cSMaheshRavishankar   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
90776ead96cSMaheshRavishankar       yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
90876ead96cSMaheshRavishankar                                      getNewTiledYieldsFn);
90976ead96cSMaheshRavishankar 
91076ead96cSMaheshRavishankar   if (failed(newInnerMostLoop))
91176ead96cSMaheshRavishankar     return innerMostLoop.emitOpError("failed to return additional yields");
91276ead96cSMaheshRavishankar   loops.back() = newInnerMostLoop.value();
9134a020018SMaheshRavishankar 
9144a020018SMaheshRavishankar   // Make all other loops except the innermost loops yield the values returned
9154a020018SMaheshRavishankar   // by the inner loop.
9164a020018SMaheshRavishankar   for (auto [outerLoop, innerLoop] :
9174a020018SMaheshRavishankar        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
91876ead96cSMaheshRavishankar     // Again assume that all the outer loops are scf.for operations.
91976ead96cSMaheshRavishankar     auto outerForLoop = cast<scf::ForOp>(outerLoop);
9204a020018SMaheshRavishankar     auto outerLoopYield =
92176ead96cSMaheshRavishankar         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
9224a020018SMaheshRavishankar     SmallVector<Value> newYields =
9234a020018SMaheshRavishankar         llvm::to_vector(outerLoopYield.getOperands());
9244a020018SMaheshRavishankar     ValueRange additionalYields =
92576ead96cSMaheshRavishankar         innerLoop->getResults().take_back(newInitValues.size());
9264a020018SMaheshRavishankar     newYields.append(additionalYields.begin(), additionalYields.end());
9274a020018SMaheshRavishankar     rewriter.setInsertionPoint(outerLoopYield);
9284a020018SMaheshRavishankar     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
92994f2a6ddSMahesh Ravishankar   }
93076ead96cSMaheshRavishankar   return success();
931809e3d8cSMahesh Ravishankar }
93294f2a6ddSMahesh Ravishankar 
93397f91982SMahesh Ravishankar /// Implementation of tiling transformation of `op` that implements the
93497f91982SMahesh Ravishankar /// `TilingInterface` using `scf.for` to iterate over the tiles.
935cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult>
93676ead96cSMaheshRavishankar mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
9376d4baa74SMehdi Amini                         const scf::SCFTilingOptions &options) {
9386740d701SMaheshRavishankar   if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
9396740d701SMaheshRavishankar     return failure();
9406740d701SMaheshRavishankar   }
9416740d701SMaheshRavishankar 
942cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(rewriter);
943cf6a7c19SMahesh Ravishankar   rewriter.setInsertionPointAfter(op);
944cf6a7c19SMahesh Ravishankar 
945cf6a7c19SMahesh Ravishankar   // 1. Get the range of the loops that are represented by the operation.
946cf6a7c19SMahesh Ravishankar   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
947aa2a96a2SMaheshRavishankar 
9486740d701SMaheshRavishankar   // 2. Materialize the tile sizes and/or number of threads;
9496740d701SMaheshRavishankar   SmallVector<OpFoldResult> tileSizes, numThreads;
9506740d701SMaheshRavishankar   std::tie(tileSizes, numThreads) =
9516740d701SMaheshRavishankar       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
9526740d701SMaheshRavishankar 
9536740d701SMaheshRavishankar   // Check if it is safe to tile. This is hold over from previous iterations
9546740d701SMaheshRavishankar   // of tile to for-all. Consider dropping it.
9556740d701SMaheshRavishankar   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
9566740d701SMaheshRavishankar     checkSafeToTileToForall(op, tileSizes, numThreads);
957cf6a7c19SMahesh Ravishankar   }
958cf6a7c19SMahesh Ravishankar 
95976ead96cSMaheshRavishankar   // 3. If there is an interchange specified, permute the iteration domain and
960b8a1f00dSMahesh Ravishankar   // the tile sizes.
96179150279SNicolas Vasilache   SmallVector<int64_t> interchangeVector;
962b8a1f00dSMahesh Ravishankar   if (!options.interchangeVector.empty()) {
963b8a1f00dSMahesh Ravishankar     interchangeVector = fillInterchangeVector(options.interchangeVector,
964b8a1f00dSMahesh Ravishankar                                               iterationDomain.size());
9656740d701SMaheshRavishankar     assert(isPermutationVector(interchangeVector) &&
9666740d701SMaheshRavishankar            "expected interchange vector to be a permutation");
967b8a1f00dSMahesh Ravishankar 
968b8a1f00dSMahesh Ravishankar     applyPermutationToVector(iterationDomain, interchangeVector);
96976ead96cSMaheshRavishankar     applyPermutationToVector(tileSizes, interchangeVector);
9706740d701SMaheshRavishankar     if (!numThreads.empty())
9716740d701SMaheshRavishankar       applyPermutationToVector(numThreads, interchangeVector);
972b8a1f00dSMahesh Ravishankar   }
973b8a1f00dSMahesh Ravishankar 
97476ead96cSMaheshRavishankar   FailureOr<TilingResult> tilingResult;
97576ead96cSMaheshRavishankar   // 4. Define the lambda function used later to generate the body of the
97676ead96cSMaheshRavishankar   // innermost tiled loop.
97776ead96cSMaheshRavishankar   YieldTiledValuesFn innerYieldTiledValuesFn =
97876ead96cSMaheshRavishankar       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
97976ead96cSMaheshRavishankar           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
98076ead96cSMaheshRavishankar           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
98176ead96cSMaheshRavishankar           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
98276ead96cSMaheshRavishankar       -> LogicalResult {
98376ead96cSMaheshRavishankar     // 4a. Compute the `offsets` and `sizes` to use for tiling.
98476ead96cSMaheshRavishankar     SmallVector<OpFoldResult> offsets, sizes;
9856740d701SMaheshRavishankar     std::tie(offsets, sizes) = getTileOffsetAndSizes(
9866740d701SMaheshRavishankar         rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
987cf6a7c19SMahesh Ravishankar 
98876ead96cSMaheshRavishankar     // 4b. If interchange was provided, apply inverse of the interchange
98976ead96cSMaheshRavishankar     //     to get back the offsets/sizes in the order to be specified.
990b8a1f00dSMahesh Ravishankar     if (!interchangeVector.empty()) {
991b8a1f00dSMahesh Ravishankar       auto inversePermutation = invertPermutationVector(interchangeVector);
992b1d3afc9SHanhan Wang       applyPermutationToVector(offsets, inversePermutation);
993b1d3afc9SHanhan Wang       applyPermutationToVector(sizes, inversePermutation);
994b8a1f00dSMahesh Ravishankar     }
995cf6a7c19SMahesh Ravishankar 
9964a020018SMaheshRavishankar     // 5. Generate the tiled implementation within the inner most loop.
99797f91982SMahesh Ravishankar 
9984a020018SMaheshRavishankar     // 5a. Clone the operation within the loop body.
9994a020018SMaheshRavishankar     auto clonedOp = cast<TilingInterface>(
100076ead96cSMaheshRavishankar         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
10014a020018SMaheshRavishankar 
10024b563458SKunwar Grover     // 5b. Early return cloned op if tiling is not happening. We can not
10034b563458SKunwar Grover     // return the original op because it could lead to `rewriter.replaceOp(op,
10044b563458SKunwar Grover     // op->getResults())` and users would get crash.
100576ead96cSMaheshRavishankar     if (llvm::all_of(tileSizes, isZeroIndex)) {
100676ead96cSMaheshRavishankar       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
100776ead96cSMaheshRavishankar       tilingResult =
1008d5f0969cSMaheshRavishankar           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1009d5f0969cSMaheshRavishankar                        /*generatedSlices=*/{}};
101076ead96cSMaheshRavishankar       return success();
1011899c2bedSHan-Chung Wang     }
1012899c2bedSHan-Chung Wang 
1013899c2bedSHan-Chung Wang     // 5c. Tile the cloned operation.
10144b563458SKunwar Grover     tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
10154b563458SKunwar Grover                                           offsets, sizes, options);
101676ead96cSMaheshRavishankar     if (failed(tilingResult)) {
101776ead96cSMaheshRavishankar       rewriter.eraseOp(clonedOp);
101876ead96cSMaheshRavishankar       return op.emitOpError("faild to tile operation");
10194a020018SMaheshRavishankar     }
10204a020018SMaheshRavishankar 
1021899c2bedSHan-Chung Wang     // 5d. Delete the cloned operation.
10224a020018SMaheshRavishankar     rewriter.eraseOp(clonedOp);
10234a020018SMaheshRavishankar 
102476ead96cSMaheshRavishankar     // 5e. Compute the offsets at which the result values are to be inserted
102576ead96cSMaheshRavishankar     //     back into its destinations.
10264a020018SMaheshRavishankar     for (auto [index, tiledValue] :
102776ead96cSMaheshRavishankar          llvm::enumerate(tilingResult->tiledValues)) {
102876ead96cSMaheshRavishankar       tiledResults.push_back(tiledValue);
102976ead96cSMaheshRavishankar       SmallVector<OpFoldResult> resultOffset, resultSize;
10304b563458SKunwar Grover       if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
10314b563458SKunwar Grover                                        sizes, resultOffset, resultSize,
10324b563458SKunwar Grover                                        options))) {
103376ead96cSMaheshRavishankar         for (auto op : tilingResult->tiledOps) {
103476ead96cSMaheshRavishankar           rewriter.eraseOp(op);
103576ead96cSMaheshRavishankar         }
103697f91982SMahesh Ravishankar         return rewriter.notifyMatchFailure(
103797f91982SMahesh Ravishankar             op, "failed to get slice of result produced");
103897f91982SMahesh Ravishankar       }
103976ead96cSMaheshRavishankar       resultOffsets.emplace_back(std::move(resultOffset));
104076ead96cSMaheshRavishankar       resultSizes.emplace_back(std::move(resultSize));
104197f91982SMahesh Ravishankar     }
104276ead96cSMaheshRavishankar 
104376ead96cSMaheshRavishankar     return success();
104476ead96cSMaheshRavishankar   };
104576ead96cSMaheshRavishankar 
104676ead96cSMaheshRavishankar   // 6. Find the destination tensors to use for the operation.
10474b563458SKunwar Grover   FailureOr<SmallVector<Value>> maybeInits =
10484b563458SKunwar Grover       createInitialTensorsForTiling(rewriter, op, tileSizes, options);
10494b563458SKunwar Grover   if (failed(maybeInits)) {
10504b563458SKunwar Grover     return rewriter.notifyMatchFailure(
10514b563458SKunwar Grover         op, "unable to create initial tensors for tiling");
105276ead96cSMaheshRavishankar   }
10534b563458SKunwar Grover   SmallVector<Value> &initTensors = maybeInits.value();
105476ead96cSMaheshRavishankar 
105576ead96cSMaheshRavishankar   // 7. Generate the tiled loops nest using the callback defined above.
105676ead96cSMaheshRavishankar   SmallVector<LoopLikeOpInterface> loops;
105776ead96cSMaheshRavishankar   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
10584b563458SKunwar Grover                               tileSizes, numThreads, initTensors,
105976ead96cSMaheshRavishankar                               innerYieldTiledValuesFn, loops)))
106076ead96cSMaheshRavishankar     return op.emitOpError("failed to generate tiling loops");
106176ead96cSMaheshRavishankar   assert(succeeded(tilingResult) &&
106276ead96cSMaheshRavishankar          "expected tiling result to be computed after loop generation");
106376ead96cSMaheshRavishankar 
10644b563458SKunwar Grover   SmallVector<Value> partialResults;
106576ead96cSMaheshRavishankar   if (loops.empty()) {
10664b563458SKunwar Grover     // If loops are empty, the tiled op is used as the replacement for the
10674b563458SKunwar Grover     // untiled op.
10684b563458SKunwar Grover     partialResults = tilingResult->tiledValues;
10694b563458SKunwar Grover   } else {
10704b563458SKunwar Grover     partialResults = llvm::map_to_vector(loops.front()->getResults(),
10714b563458SKunwar Grover                                          [](OpResult r) -> Value { return r; });
10724b563458SKunwar Grover   }
10734b563458SKunwar Grover 
10744b563458SKunwar Grover   FailureOr<MergeResult> mergeResult =
10754b563458SKunwar Grover       mergeTilingResults(rewriter, op, partialResults, options);
10764b563458SKunwar Grover   if (failed(mergeResult)) {
10774b563458SKunwar Grover     return rewriter.notifyMatchFailure(
10784b563458SKunwar Grover         op, "Failed to merge partial results from tiling");
10794b563458SKunwar Grover   }
10804b563458SKunwar Grover 
10814b563458SKunwar Grover   return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
10824b563458SKunwar Grover                               mergeResult.value(),
1083d5f0969cSMaheshRavishankar                               tilingResult->generatedSlices};
108476ead96cSMaheshRavishankar }
108597f91982SMahesh Ravishankar 
10864b563458SKunwar Grover FailureOr<scf::SCFTilingResult>
10871cff4cbdSNicolas Vasilache mlir::scf::tileReductionUsingScf(RewriterBase &b,
10883310fe55SThomas Raoux                                  PartialReductionOpInterface op,
1089170a25a7SMaheshRavishankar                                  ArrayRef<OpFoldResult> tileSizes) {
10904b563458SKunwar Grover   SCFTilingOptions options;
10914b563458SKunwar Grover   options.setLoopType(SCFTilingOptions::LoopType::ForOp);
10924b563458SKunwar Grover   options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
10934b563458SKunwar Grover                                          PartialReductionOuterReduction);
10944b563458SKunwar Grover   options.setTileSizes(tileSizes);
10952cc5f5d4SGroverkss 
10964b563458SKunwar Grover   TilingInterface tilingInterfaceOp =
10974b563458SKunwar Grover       dyn_cast<TilingInterface>(op.getOperation());
10984b563458SKunwar Grover   if (!tilingInterfaceOp) {
10994b563458SKunwar Grover     return b.notifyMatchFailure(
11004b563458SKunwar Grover         op,
11014b563458SKunwar Grover         "Operation implementing PartialReductionOpInterface should implement "
11024b563458SKunwar Grover         "TilingInterface");
11033310fe55SThomas Raoux   }
1104faac8989SAlex Zinenko 
11054b563458SKunwar Grover   return tileUsingSCF(b, tilingInterfaceOp, options);
11063310fe55SThomas Raoux }
110793c42299SMaheshRavishankar 
11082f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
110976ead96cSMaheshRavishankar // tileConsumerAndFuseProducersUsingSCF implementation.
11102f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
11112f637fe7SMahesh Ravishankar 
11127ee34550SMahesh Ravishankar /// Return the untiled producer whose slice is used in a tiled consumer. The
11137ee34550SMahesh Ravishankar /// method traverses the tile loop nest (`loops`) if needed, and returns the
11144b563458SKunwar Grover /// `iter_args` of the outer most that is encountered. Traversing the
11154b563458SKunwar Grover /// iter_args indicates that this is a destination operand of the consumer. If
11164b563458SKunwar Grover /// there was no loop traversal needed, the second value of the returned tuple
11174b563458SKunwar Grover /// is empty.
111822426110SRamkumar Ramachandra static std::tuple<OpResult, std::optional<OpOperand *>>
11197ee34550SMahesh Ravishankar getUntiledProducerFromSliceSource(OpOperand *source,
112076ead96cSMaheshRavishankar                                   ArrayRef<LoopLikeOpInterface> loops) {
112122426110SRamkumar Ramachandra   std::optional<OpOperand *> destinationIterArg;
11227ee34550SMahesh Ravishankar   auto loopIt = loops.rbegin();
11235550c821STres Popp   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
112476ead96cSMaheshRavishankar     auto loop = *loopIt;
11257ee34550SMahesh Ravishankar     if (iterArg.getOwner()->getParentOp() != loop)
11267ee34550SMahesh Ravishankar       break;
11273cd2a0bcSMatthias Springer     source = loop.getTiedLoopInit(iterArg);
11287ee34550SMahesh Ravishankar     loopIt++;
11292f637fe7SMahesh Ravishankar   }
11307ee34550SMahesh Ravishankar   if (loopIt == loops.rend())
11317ee34550SMahesh Ravishankar     destinationIterArg = source;
11325550c821STres Popp   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
11332ed7c3fdSlorenzo chelini }
11342ed7c3fdSlorenzo chelini 
11359db7d4edSMahesh Ravishankar /// Implementation of fusing producer of a single slice by computing the
11369db7d4edSMahesh Ravishankar /// slice of the producer in-place.
11379db7d4edSMahesh Ravishankar std::optional<scf::SCFFuseProducerOfSliceResult>
113876ead96cSMaheshRavishankar mlir::scf::tileAndFuseProducerOfSlice(
113976ead96cSMaheshRavishankar     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
114076ead96cSMaheshRavishankar     MutableArrayRef<LoopLikeOpInterface> loops) {
1141ce349ff1SMahesh Ravishankar   // 1. Get the producer of the source (potentially walking through
1142ce349ff1SMahesh Ravishankar   // `iter_args` of nested `scf.for`)
11436923a315SMatthias Springer   auto [fusableProducer, destinationInitArg] =
11448823e961SMatthias Springer       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1145ce349ff1SMahesh Ravishankar                                         loops);
1146ce349ff1SMahesh Ravishankar   if (!fusableProducer)
1147ce349ff1SMahesh Ravishankar     return std::nullopt;
11484a020018SMaheshRavishankar   unsigned resultNumber = fusableProducer.getResultNumber();
1149ce349ff1SMahesh Ravishankar 
1150ce349ff1SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
1151ce349ff1SMahesh Ravishankar   rewriter.setInsertionPoint(candidateSliceOp);
11524a020018SMaheshRavishankar 
11534a020018SMaheshRavishankar   // 2. Clone the fused producer
11544a020018SMaheshRavishankar   // 2a. Compute the destination operands to use for the cloned operation.
11554a020018SMaheshRavishankar   SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
11564a020018SMaheshRavishankar   Operation *fusableProducerOp = fusableProducer.getOwner();
11574a020018SMaheshRavishankar   if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
11584a020018SMaheshRavishankar       failed(tensor::getOrCreateDestinations(
11594a020018SMaheshRavishankar           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
11604a020018SMaheshRavishankar           origDestinationTensors)))
11614a020018SMaheshRavishankar     return std::nullopt;
11624a020018SMaheshRavishankar 
11634a020018SMaheshRavishankar   clonedOpDestinationTensors = origDestinationTensors;
11644a020018SMaheshRavishankar   if (destinationInitArg &&
11654a020018SMaheshRavishankar       isa<DestinationStyleOpInterface>(fusableProducerOp)) {
11664a020018SMaheshRavishankar     // 2b. If the producer is also destination style, then to maintain the
11674a020018SMaheshRavishankar     // destination passing style, update the destination of the producer to be
11684a020018SMaheshRavishankar     // the source of the slice.
11694a020018SMaheshRavishankar     clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
11704a020018SMaheshRavishankar   }
11714a020018SMaheshRavishankar   // 2c. Clone the fused producer.
11724a020018SMaheshRavishankar   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
11734a020018SMaheshRavishankar       rewriter, fusableProducerOp, clonedOpDestinationTensors);
11744a020018SMaheshRavishankar   // 2d. Update the source of the candidateSlice to be the cloned producer.
11754b563458SKunwar Grover   //     Easier to just clone the slice with different source since
11764b563458SKunwar Grover   //     replacements and DCE of cloned ops becomes easier
11774a020018SMaheshRavishankar   SmallVector<Value> candidateSliceOpOperands =
11784a020018SMaheshRavishankar       llvm::to_vector(candidateSliceOp->getOperands());
11794a020018SMaheshRavishankar   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
11804a020018SMaheshRavishankar   tensor::ExtractSliceOp clonedCandidateSliceOp =
11814a020018SMaheshRavishankar       mlir::clone(rewriter, candidateSliceOp,
11824a020018SMaheshRavishankar                   candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
11834a020018SMaheshRavishankar 
11844a020018SMaheshRavishankar   // 3. Generate the tiled implementation of the producer of the source
1185809e3d8cSMahesh Ravishankar   FailureOr<TilingResult> tileAndFuseResult =
11864a020018SMaheshRavishankar       tensor::replaceExtractSliceWithTiledProducer(
11874a020018SMaheshRavishankar           rewriter, clonedCandidateSliceOp,
11884a020018SMaheshRavishankar           clonedProducerOp->getResult(resultNumber));
1189809e3d8cSMahesh Ravishankar   if (failed(tileAndFuseResult))
1190ce349ff1SMahesh Ravishankar     return std::nullopt;
11914a020018SMaheshRavishankar   // Note: Do not delete the candidateSliceOp, since its passed in from the
11924a020018SMaheshRavishankar   // caller.
1193809e3d8cSMahesh Ravishankar   rewriter.replaceAllUsesWith(candidateSliceOp,
1194809e3d8cSMahesh Ravishankar                               tileAndFuseResult->tiledValues[0]);
11954a020018SMaheshRavishankar   rewriter.eraseOp(clonedCandidateSliceOp);
11964a020018SMaheshRavishankar   rewriter.eraseOp(clonedProducerOp);
1197ce349ff1SMahesh Ravishankar 
1198ce349ff1SMahesh Ravishankar   // 3. If the slice is for a destination operand, for example,
1199ce349ff1SMahesh Ravishankar   //
1200ce349ff1SMahesh Ravishankar   // ```mlir
1201ce349ff1SMahesh Ravishankar   // %0 = linalg.init
1202ce349ff1SMahesh Ravishankar   // %1 = linalg.fill .. outs(%0 : )
1203ce349ff1SMahesh Ravishankar   // %2 = scf.for .. iter_args(%arg0 = %1) {
1204ce349ff1SMahesh Ravishankar   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
1205ce349ff1SMahesh Ravishankar   //     %4 = tensor.extract_slice %arg1 [..]
1206ce349ff1SMahesh Ravishankar   //     .. = linalg.matmul .. outs(%4 : )
1207ce349ff1SMahesh Ravishankar   //   }
1208ce349ff1SMahesh Ravishankar   // }
1209ce349ff1SMahesh Ravishankar   // ```
1210ce349ff1SMahesh Ravishankar   //
1211ce349ff1SMahesh Ravishankar   // the IR is currently
1212ce349ff1SMahesh Ravishankar   //
1213ce349ff1SMahesh Ravishankar   // ```
1214ce349ff1SMahesh Ravishankar   // %0 = linalg.init
1215ce349ff1SMahesh Ravishankar   // %1 = linalg.fill
1216ce349ff1SMahesh Ravishankar   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1217ce349ff1SMahesh Ravishankar   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
12184a020018SMaheshRavishankar   //     %4 = tensor.extract_slice %arg1[..]
1219ce349ff1SMahesh Ravishankar   //     %5 = linalg.fill .. outs(%4 : )
1220ce349ff1SMahesh Ravishankar   //     .. = linalg.matmul .. outs(%5 : )
1221ce349ff1SMahesh Ravishankar   //   }
1222ce349ff1SMahesh Ravishankar   // }
1223ce349ff1SMahesh Ravishankar   // ```
1224ce349ff1SMahesh Ravishankar   //
1225ce349ff1SMahesh Ravishankar   // The untiled `linalg.fill` is still used as the `init_value` since it
1226ce349ff1SMahesh Ravishankar   // was originally a destination operand of the untiled `linalg.matmul`.
12274a020018SMaheshRavishankar   // When fusing an operand that is a destination operand, the iter_arg of
12284a020018SMaheshRavishankar   // the outer most loop should be changed to use the destination of the
12294a020018SMaheshRavishankar   // fused operation. With this the IR will be.
1230ce349ff1SMahesh Ravishankar   //
1231ce349ff1SMahesh Ravishankar   // ```
1232ce349ff1SMahesh Ravishankar   // %0 = linalg.init
1233ce349ff1SMahesh Ravishankar   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1234ce349ff1SMahesh Ravishankar   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
12354a020018SMaheshRavishankar   //     %3 = tensor.extract_slice %arg1[..]
1236ce349ff1SMahesh Ravishankar   //     %4 = linalg.fill .. outs(%3 : )
1237ce349ff1SMahesh Ravishankar   //     .. = linalg.matmul .. outs(%4 : )
1238ce349ff1SMahesh Ravishankar   //   }
1239ce349ff1SMahesh Ravishankar   // }
1240ce349ff1SMahesh Ravishankar   // ```
12416923a315SMatthias Springer   if (destinationInitArg &&
12424a020018SMaheshRavishankar       isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
12434a020018SMaheshRavishankar     loops.front()
12444a020018SMaheshRavishankar         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
12454a020018SMaheshRavishankar         .set(origDestinationTensors[resultNumber]);
1246ce349ff1SMahesh Ravishankar   }
1247d5f0969cSMaheshRavishankar   return scf::SCFFuseProducerOfSliceResult{
1248d5f0969cSMaheshRavishankar       fusableProducer, tileAndFuseResult->tiledValues[0],
1249d5f0969cSMaheshRavishankar       tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
12509db7d4edSMahesh Ravishankar }
12519db7d4edSMahesh Ravishankar 
12529db7d4edSMahesh Ravishankar /// Reconstruct the fused producer from within the tiled-and-fused code.
1253d5f0969cSMaheshRavishankar FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
12549db7d4edSMahesh Ravishankar     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
12559db7d4edSMahesh Ravishankar     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
12567ef08eacSYun-Fly     MutableArrayRef<LoopLikeOpInterface> loops,
12577ef08eacSYun-Fly     ArrayRef<unsigned> yieldResultNumber) {
12584a020018SMaheshRavishankar   if (loops.empty())
125976ead96cSMaheshRavishankar     return success();
12604a020018SMaheshRavishankar 
12617ef08eacSYun-Fly   Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
12627ef08eacSYun-Fly             *tiledOwner = fusedProducerInfo.tiledOps[0];
12637ef08eacSYun-Fly 
12647ef08eacSYun-Fly   Location loc = originalOwner->getLoc();
12657ef08eacSYun-Fly   // a. collect all init Value to be appended
12667ef08eacSYun-Fly   SmallVector<unsigned> initNumberList =
12677ef08eacSYun-Fly       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
12687ef08eacSYun-Fly                                       0, originalOwner->getNumResults()))
12697ef08eacSYun-Fly                                 : llvm::to_vector(yieldResultNumber);
12707ef08eacSYun-Fly   SmallVector<Value> initValueList;
12717ef08eacSYun-Fly   for (const auto &resultNumber : initNumberList) {
12729db7d4edSMahesh Ravishankar     FailureOr<Value> initValue = tensor::getOrCreateDestination(
12737ef08eacSYun-Fly         rewriter, loc, originalOwner->getResult(resultNumber));
12749db7d4edSMahesh Ravishankar     if (succeeded(initValue)) {
12757ef08eacSYun-Fly       initValueList.push_back(initValue.value());
12767ef08eacSYun-Fly     } else {
12777ef08eacSYun-Fly       return failure();
12787ef08eacSYun-Fly     }
12797ef08eacSYun-Fly   }
12804a020018SMaheshRavishankar 
1281d5f0969cSMaheshRavishankar   SmallVector<Operation *> generatedSlices;
128276ead96cSMaheshRavishankar   YieldTiledValuesFn newYieldValuesFn =
128376ead96cSMaheshRavishankar       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
128476ead96cSMaheshRavishankar           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
128576ead96cSMaheshRavishankar           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
12867ef08eacSYun-Fly           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
12874a020018SMaheshRavishankar     OpBuilder::InsertionGuard g(innerRewriter);
12887ef08eacSYun-Fly 
12897ef08eacSYun-Fly     // get sliceOp tile information
12907ef08eacSYun-Fly     SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
12917ef08eacSYun-Fly                               sliceSizes = sliceOp.getMixedSizes();
12927ef08eacSYun-Fly 
12937ef08eacSYun-Fly     // expect all strides of sliceOp being 1
12947ef08eacSYun-Fly     if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
12957ef08eacSYun-Fly           return !isConstantIntValue(ofr, 1);
12967ef08eacSYun-Fly         }))
12977ef08eacSYun-Fly       return failure();
12987ef08eacSYun-Fly 
12997ef08eacSYun-Fly     unsigned sliceResultNumber =
13007ef08eacSYun-Fly         fusedProducerInfo.origProducer.getResultNumber();
13017ef08eacSYun-Fly 
13027ef08eacSYun-Fly     auto tilableOp = cast<TilingInterface>(originalOwner);
13037ef08eacSYun-Fly     // b. get iterDomain Offset and Sizes based on sliceOp tile
13047ef08eacSYun-Fly     SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
13057ef08eacSYun-Fly     // skip tensor.pack/unpack/pad, which expects single opResult
13067ef08eacSYun-Fly     if (tilableOp->getNumResults() > 1 &&
13077ef08eacSYun-Fly         failed(tilableOp.getIterationDomainTileFromResultTile(
13087ef08eacSYun-Fly             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
13097ef08eacSYun-Fly             iterDomainOffset, iterDomainSizes))) {
13104b563458SKunwar Grover       // In theory, it is unnecessary to raise an error here. Actually
13114b563458SKunwar Grover       // although it fails to reconstruct the result tensor, it should not
13124b563458SKunwar Grover       // broke current fusion anyway. The reason why we must return failure
13134b563458SKunwar Grover       // currently is that the callback function `newYieldValuesFn` will be
13144b563458SKunwar Grover       // called after new init operand(s) has already been appended. It will
13154b563458SKunwar Grover       // take more refactoring to make sure the init operands are added
13164b563458SKunwar Grover       // consistently in the future. For more details, please refer to:
13177ef08eacSYun-Fly       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
13187ef08eacSYun-Fly       return failure();
13197ef08eacSYun-Fly     }
13207ef08eacSYun-Fly 
13217ef08eacSYun-Fly     // c. calculate offsets and sizes info of all OpResults respectively based
13227ef08eacSYun-Fly     // on iteration Domain Tile
13237ef08eacSYun-Fly     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
13247ef08eacSYun-Fly     for (const auto &resultNumber : initNumberList) {
13257ef08eacSYun-Fly       if (resultNumber == sliceResultNumber) {
13267ef08eacSYun-Fly         offsetList.push_back(sliceOffset);
13277ef08eacSYun-Fly         sizesList.push_back(sliceSizes);
13287ef08eacSYun-Fly       } else {
13297ef08eacSYun-Fly         assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
13307ef08eacSYun-Fly         // infer result tile according to the iteration domain tile
13317ef08eacSYun-Fly         SmallVector<OpFoldResult> offset, sizes;
13327ef08eacSYun-Fly         if (failed(tilableOp.getResultTilePosition(
13337ef08eacSYun-Fly                 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
13347ef08eacSYun-Fly                 offset, sizes))) {
13357ef08eacSYun-Fly           return failure();
13367ef08eacSYun-Fly         }
13377ef08eacSYun-Fly         offsetList.push_back(offset);
13387ef08eacSYun-Fly         sizesList.push_back(sizes);
13397ef08eacSYun-Fly       }
13407ef08eacSYun-Fly     }
13417ef08eacSYun-Fly 
13424b563458SKunwar Grover     // d. create `extract_slice` for `iter_args` for DPS operation if
13434b563458SKunwar Grover     // necessary
13444a020018SMaheshRavishankar     if (auto tiledDestStyleOp =
13457ef08eacSYun-Fly             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
13464a020018SMaheshRavishankar       rewriter.setInsertionPoint(tiledDestStyleOp);
13477ef08eacSYun-Fly       for (const auto &&[index, newRegionArg] :
13487ef08eacSYun-Fly            llvm::enumerate(newRegionIterArgs)) {
13494a020018SMaheshRavishankar         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
13507ef08eacSYun-Fly             loc, newRegionArg, offsetList[index], sizesList[index],
13517ef08eacSYun-Fly             SmallVector<OpFoldResult>(offsetList[index].size(),
13527ef08eacSYun-Fly                                       rewriter.getIndexAttr(1)));
1353d5f0969cSMaheshRavishankar         generatedSlices.push_back(destSlice);
13547ef08eacSYun-Fly         unsigned resultNumber = initNumberList[index];
13555fcf907bSMatthias Springer         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
13564a020018SMaheshRavishankar           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
13574a020018SMaheshRavishankar         });
1358ec1086f2SMaheshRavishankar       }
13597ef08eacSYun-Fly     }
13607ef08eacSYun-Fly 
13617ef08eacSYun-Fly     // e. prepare tiled offset and sizes for later `insert_slice` creation by
13627ef08eacSYun-Fly     // caller
13634a020018SMaheshRavishankar     Block *block = rewriter.getInsertionPoint()->getBlock();
13644a020018SMaheshRavishankar     rewriter.setInsertionPoint(block->getTerminator());
13657ef08eacSYun-Fly     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
13667ef08eacSYun-Fly       tiledResult.push_back(tiledOwner->getResult(resultNumber));
13677ef08eacSYun-Fly       tiledOffset.emplace_back(offsetList[index]);
13687ef08eacSYun-Fly       tiledSizes.emplace_back(sizesList[index]);
13697ef08eacSYun-Fly     }
137076ead96cSMaheshRavishankar     return success();
13714a020018SMaheshRavishankar   };
13724a020018SMaheshRavishankar 
1373d5f0969cSMaheshRavishankar   if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1374d5f0969cSMaheshRavishankar                                        newYieldValuesFn))) {
1375d5f0969cSMaheshRavishankar     return failure();
1376d5f0969cSMaheshRavishankar   }
1377d5f0969cSMaheshRavishankar   return generatedSlices;
13789db7d4edSMahesh Ravishankar }
1379ce349ff1SMahesh Ravishankar 
13809144fed3SQuinn Dawkins namespace {
13819144fed3SQuinn Dawkins 
13829144fed3SQuinn Dawkins //===----------------------------------------------------------------------===//
13839144fed3SQuinn Dawkins // SliceTrackingListener
13849144fed3SQuinn Dawkins //===----------------------------------------------------------------------===//
13859144fed3SQuinn Dawkins 
13869144fed3SQuinn Dawkins /// This class is a listener for tracking the insertion and removal of
13879144fed3SQuinn Dawkins /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
13889144fed3SQuinn Dawkins /// fusion algorithm to apply cleanup patterns in between fusion steps.
13899144fed3SQuinn Dawkins class SliceTrackingListener : public RewriterBase::Listener {
13909144fed3SQuinn Dawkins public:
13919144fed3SQuinn Dawkins   explicit SliceTrackingListener(
13929144fed3SQuinn Dawkins       std::optional<FrozenRewritePatternSet> patterns);
13939144fed3SQuinn Dawkins   SliceTrackingListener() = default;
13949144fed3SQuinn Dawkins 
13954b563458SKunwar Grover   /// Adds the given list of operations to the worklist, and if present,
13964b563458SKunwar Grover   /// applies the list of `patterns` to the newly added operations. This only
13974b563458SKunwar Grover   /// processes the given operations and any newly inserted ones by the
13984b563458SKunwar Grover   /// pattern set.
13999144fed3SQuinn Dawkins   LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
14009144fed3SQuinn Dawkins 
14019144fed3SQuinn Dawkins   /// Add to the new operation worklist if it is an extract_slice.
14029144fed3SQuinn Dawkins   void notifyOperationInserted(Operation *op,
14039144fed3SQuinn Dawkins                                OpBuilder::InsertPoint previous) override;
14049144fed3SQuinn Dawkins 
14059144fed3SQuinn Dawkins   /// Shared helper for operation removal from the worklist.
14069144fed3SQuinn Dawkins   void removeOp(Operation *op);
14079144fed3SQuinn Dawkins 
14089144fed3SQuinn Dawkins   /// Remove the operation from the worklist.
14099144fed3SQuinn Dawkins   void notifyOperationErased(Operation *op) override;
14109144fed3SQuinn Dawkins 
14119144fed3SQuinn Dawkins   /// Remove the operation from the worklist.
14129144fed3SQuinn Dawkins   void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
14139144fed3SQuinn Dawkins 
14149144fed3SQuinn Dawkins   /// The worklist for this transformation keeps track of the slices to visit
14159144fed3SQuinn Dawkins   /// next for fusion.
14169144fed3SQuinn Dawkins   std::deque<tensor::ExtractSliceOp> worklist;
14179144fed3SQuinn Dawkins 
14189144fed3SQuinn Dawkins private:
14194b563458SKunwar Grover   /// Optional pattern set to apply when adding new operations to the
14204b563458SKunwar Grover   /// worklist.
14219144fed3SQuinn Dawkins   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
14229144fed3SQuinn Dawkins };
14239144fed3SQuinn Dawkins 
14249144fed3SQuinn Dawkins SliceTrackingListener::SliceTrackingListener(
14259144fed3SQuinn Dawkins     std::optional<FrozenRewritePatternSet> p) {
14269144fed3SQuinn Dawkins   patterns = std::move(p);
14279144fed3SQuinn Dawkins }
14289144fed3SQuinn Dawkins 
14299144fed3SQuinn Dawkins LogicalResult
14309144fed3SQuinn Dawkins SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14319144fed3SQuinn Dawkins   for (Operation *op : ops) {
14329144fed3SQuinn Dawkins     if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
14339144fed3SQuinn Dawkins       worklist.push_back(slice);
14349144fed3SQuinn Dawkins   }
14359144fed3SQuinn Dawkins 
14369144fed3SQuinn Dawkins   if (!patterns)
14379144fed3SQuinn Dawkins     return success();
14389144fed3SQuinn Dawkins 
14399144fed3SQuinn Dawkins   GreedyRewriteConfig config;
14409144fed3SQuinn Dawkins   config.listener = this;
14419144fed3SQuinn Dawkins   config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
144209dfc571SJacques Pienaar   return applyOpPatternsGreedily(ops, patterns.value(), config);
14439144fed3SQuinn Dawkins }
14449144fed3SQuinn Dawkins 
14459144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationInserted(
14469144fed3SQuinn Dawkins     Operation *op, OpBuilder::InsertPoint previous) {
14479144fed3SQuinn Dawkins   auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
14489144fed3SQuinn Dawkins   if (!slice)
14499144fed3SQuinn Dawkins     return;
14509144fed3SQuinn Dawkins   worklist.push_back(slice);
14519144fed3SQuinn Dawkins }
14529144fed3SQuinn Dawkins 
14534b563458SKunwar Grover // Scan the worklist for the given op and remove it if present. The
14544b563458SKunwar Grover // expectation is for the worklist to be small and for removal to be
14554b563458SKunwar Grover // relatively rare.
14569144fed3SQuinn Dawkins void SliceTrackingListener::removeOp(Operation *op) {
14579144fed3SQuinn Dawkins   if (!isa<tensor::ExtractSliceOp>(op))
14589144fed3SQuinn Dawkins     return;
14599144fed3SQuinn Dawkins   auto iter = worklist.begin();
14609144fed3SQuinn Dawkins   while (iter != worklist.end()) {
14619144fed3SQuinn Dawkins     if (*iter == op)
14629144fed3SQuinn Dawkins       break;
14639144fed3SQuinn Dawkins     iter++;
14649144fed3SQuinn Dawkins   }
14659144fed3SQuinn Dawkins   if (iter == worklist.end())
14669144fed3SQuinn Dawkins     return;
14679144fed3SQuinn Dawkins 
14689144fed3SQuinn Dawkins   worklist.erase(iter);
14699144fed3SQuinn Dawkins }
14709144fed3SQuinn Dawkins 
14719144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationErased(Operation *op) {
14729144fed3SQuinn Dawkins   removeOp(op);
14739144fed3SQuinn Dawkins }
14749144fed3SQuinn Dawkins 
14759144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14769144fed3SQuinn Dawkins                                                     ValueRange replacement) {
14779144fed3SQuinn Dawkins   removeOp(op);
14789144fed3SQuinn Dawkins }
14796e3631d0SKunwar Grover 
14806e3631d0SKunwar Grover //===----------------------------------------------------------------------===//
14816e3631d0SKunwar Grover // ReplacementListener
14826e3631d0SKunwar Grover //===----------------------------------------------------------------------===//
14836e3631d0SKunwar Grover 
14846e3631d0SKunwar Grover /// Listener that tracks updates replacements for values which can be mutated.
14856e3631d0SKunwar Grover /// This listener runs on top of the existing listener for the rewriter,
14866e3631d0SKunwar Grover /// to make sure external users can still run listeners.
14876e3631d0SKunwar Grover class ReplacementListener : public RewriterBase::ForwardingListener {
14886e3631d0SKunwar Grover public:
14896e3631d0SKunwar Grover   ReplacementListener(DenseMap<Value, Value> &replacements,
14906e3631d0SKunwar Grover                       OpBuilder::Listener *listener)
14916e3631d0SKunwar Grover       : ForwardingListener(listener), replacements(replacements) {}
14926e3631d0SKunwar Grover 
14936e3631d0SKunwar Grover   void updateReplacementValues(ValueRange origValues,
14946e3631d0SKunwar Grover                                ValueRange replaceValues) {
14956e3631d0SKunwar Grover     // This can probably be written better, but just iterates over the map
14966e3631d0SKunwar Grover     // and the new replacements for now.
14976e3631d0SKunwar Grover     for (auto &[key, val] : replacements) {
14986e3631d0SKunwar Grover       for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
14996e3631d0SKunwar Grover         if (val == orig) {
15006e3631d0SKunwar Grover           val = replace;
15016e3631d0SKunwar Grover         }
15026e3631d0SKunwar Grover       }
15036e3631d0SKunwar Grover     }
15046e3631d0SKunwar Grover   }
15056e3631d0SKunwar Grover 
15066e3631d0SKunwar Grover   void notifyOperationReplaced(Operation *op, Operation *newOp) override {
15076e3631d0SKunwar Grover     ForwardingListener::notifyOperationReplaced(op, newOp);
15086e3631d0SKunwar Grover     updateReplacementValues(op->getResults(), newOp->getResults());
15096e3631d0SKunwar Grover   }
15106e3631d0SKunwar Grover 
15116e3631d0SKunwar Grover   void notifyOperationReplaced(Operation *op, ValueRange values) override {
15126e3631d0SKunwar Grover     ForwardingListener::notifyOperationReplaced(op, values);
15136e3631d0SKunwar Grover     updateReplacementValues(op->getResults(), values);
15146e3631d0SKunwar Grover   }
15156e3631d0SKunwar Grover 
15166e3631d0SKunwar Grover private:
15176e3631d0SKunwar Grover   DenseMap<Value, Value> &replacements;
15186e3631d0SKunwar Grover };
15196e3631d0SKunwar Grover 
15209144fed3SQuinn Dawkins } // namespace
15219144fed3SQuinn Dawkins 
152297f91982SMahesh Ravishankar /// Implementation of tile consumer and fuse producer greedily.
15232f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult>
152476ead96cSMaheshRavishankar mlir::scf::tileConsumerAndFuseProducersUsingSCF(
152597f91982SMahesh Ravishankar     RewriterBase &rewriter, TilingInterface consumer,
15266d4baa74SMehdi Amini     const scf::SCFTileAndFuseOptions &options) {
15272f637fe7SMahesh Ravishankar   // This transformation is only valid for ops that return values (i.e. not
15282f637fe7SMahesh Ravishankar   // valid to use with operations that have memref operands).
152997f91982SMahesh Ravishankar   if (!consumer->getNumResults()) {
15302f637fe7SMahesh Ravishankar     return rewriter.notifyMatchFailure(
153197f91982SMahesh Ravishankar         consumer, "invalid pattern for op with no results");
15322f637fe7SMahesh Ravishankar   }
15332f637fe7SMahesh Ravishankar 
15342f637fe7SMahesh Ravishankar   // 1. First tile the consumer.
153593c42299SMaheshRavishankar   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
15364435ced9SMaheshRavishankar   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
153776ead96cSMaheshRavishankar 
153897f91982SMahesh Ravishankar   FailureOr<scf::SCFTilingResult> tilingResult =
153976ead96cSMaheshRavishankar       tileUsingSCF(rewriter, consumer, options.tilingOptions);
154076ead96cSMaheshRavishankar 
154197f91982SMahesh Ravishankar   if (failed(tilingResult))
154297f91982SMahesh Ravishankar     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1543fbfca43eSMehdi Amini   for (auto *tiledOp : tilingResult->tiledOps)
154493c42299SMaheshRavishankar     tiledAndFusedOps.insert(tiledOp);
154597f91982SMahesh Ravishankar 
15464435ced9SMaheshRavishankar   DenseMap<Value, Value> replacements;
15474b563458SKunwar Grover   for (auto [origVal, replacement] : llvm::zip_equal(
15484b563458SKunwar Grover            consumer->getResults(), tilingResult->mergeResult.replacements)) {
15494435ced9SMaheshRavishankar     replacements[origVal] = replacement;
15504435ced9SMaheshRavishankar   }
15516e3631d0SKunwar Grover 
15526e3631d0SKunwar Grover   // If there are no loops generated, fusion is immaterial.
15536e3631d0SKunwar Grover   auto &loops = tilingResult->loops;
15546e3631d0SKunwar Grover   if (loops.empty()) {
155576ead96cSMaheshRavishankar     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
155676ead96cSMaheshRavishankar                                      replacements};
155793c42299SMaheshRavishankar   }
15582f637fe7SMahesh Ravishankar 
15596e3631d0SKunwar Grover   // Since the loop gets potentially replaced during fusion, we need to track
15606e3631d0SKunwar Grover   // the mutation of replacement values. To do this, we attach a listener to
15616e3631d0SKunwar Grover   // update the replacements as they happen.
15626e3631d0SKunwar Grover   OpBuilder::Listener *previousListener = rewriter.getListener();
15636e3631d0SKunwar Grover   auto resetListener =
15646e3631d0SKunwar Grover       llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
15656e3631d0SKunwar Grover   ReplacementListener replaceListener(replacements, previousListener);
15666e3631d0SKunwar Grover   rewriter.setListener(&replaceListener);
15674435ced9SMaheshRavishankar 
15682f637fe7SMahesh Ravishankar   // 2. Typically, the operands of the tiled operation are slices of the
15692f637fe7SMahesh Ravishankar   //    operands of the untiled operation. These are expressed in IR using
15704b563458SKunwar Grover   //    `tensor.extract_slice` operations with source being the operands of
15714b563458SKunwar Grover   //    the untiled operation. Create a worklist of these
15724b563458SKunwar Grover   //    `tensor.extract_slice` operations. If the producers of the source of
15734b563458SKunwar Grover   //    the `tensor.extract_slice` can be tiled such that the tiled value is
15744b563458SKunwar Grover   //    generated in-place, that effectively tiles + fuses the operations.
1575d5f0969cSMaheshRavishankar   struct WorklistItem {
1576d5f0969cSMaheshRavishankar     tensor::ExtractSliceOp candidateSlice;
1577d5f0969cSMaheshRavishankar     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
15782f637fe7SMahesh Ravishankar   };
15799144fed3SQuinn Dawkins 
15809144fed3SQuinn Dawkins   SliceTrackingListener sliceTracker =
15819144fed3SQuinn Dawkins       SliceTrackingListener(options.cleanupPatterns);
15829144fed3SQuinn Dawkins 
15839144fed3SQuinn Dawkins   if (failed(
15849144fed3SQuinn Dawkins           sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
15859144fed3SQuinn Dawkins     return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
15869144fed3SQuinn Dawkins   }
15879144fed3SQuinn Dawkins   OpBuilder::InsertionGuard g(rewriter);
15889144fed3SQuinn Dawkins   while (!sliceTracker.worklist.empty()) {
15899144fed3SQuinn Dawkins     auto candidateSlice = sliceTracker.worklist.front();
15909144fed3SQuinn Dawkins     sliceTracker.worklist.pop_front();
15912f637fe7SMahesh Ravishankar 
15924435ced9SMaheshRavishankar     auto [fusableProducer, destinationInitArg] =
15939144fed3SQuinn Dawkins         getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
15949144fed3SQuinn Dawkins                                           loops);
15954435ced9SMaheshRavishankar     if (!fusableProducer)
15964435ced9SMaheshRavishankar       continue;
15979144fed3SQuinn Dawkins 
1598d5f0969cSMaheshRavishankar     std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
15999144fed3SQuinn Dawkins         options.fusionControlFn(candidateSlice, fusableProducer,
1600d5f0969cSMaheshRavishankar                                 destinationInitArg.has_value());
1601d5f0969cSMaheshRavishankar     if (!controlFnResult)
16024435ced9SMaheshRavishankar       continue;
1603d5f0969cSMaheshRavishankar 
16049144fed3SQuinn Dawkins     WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
16054435ced9SMaheshRavishankar 
1606ce349ff1SMahesh Ravishankar     // The operands of the fused producer might themselved be slices of
16072f637fe7SMahesh Ravishankar     // values produced by operations that implement the `TilingInterface`.
16082f637fe7SMahesh Ravishankar     // Add these operations to the worklist.
160993c42299SMaheshRavishankar     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1610d5f0969cSMaheshRavishankar         tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1611d5f0969cSMaheshRavishankar                                    loops);
161293c42299SMaheshRavishankar     if (!fusedResult)
1613ce349ff1SMahesh Ravishankar       continue;
16142f637fe7SMahesh Ravishankar 
16159144fed3SQuinn Dawkins     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
16169144fed3SQuinn Dawkins 
1617d5f0969cSMaheshRavishankar     if (worklistItem.controlFnResult.yieldProducerReplacement) {
16184b563458SKunwar Grover       // Reconstruct and yield all opResult of fusableProducerOp by default.
16194b563458SKunwar Grover       // The caller can specific which one to yield by designating optional
16204b563458SKunwar Grover       // argument named `yieldResultNumber` of
16214b563458SKunwar Grover       // `yieldReplacementForFusedProducer`.
1622d5f0969cSMaheshRavishankar       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1623d5f0969cSMaheshRavishankar       FailureOr<SmallVector<Operation *>> newSlices =
1624d5f0969cSMaheshRavishankar           yieldReplacementForFusedProducer(rewriter,
1625d5f0969cSMaheshRavishankar                                            worklistItem.candidateSlice,
1626d5f0969cSMaheshRavishankar                                            fusedResult.value(), loops);
1627d5f0969cSMaheshRavishankar       if (failed(newSlices)) {
162876ead96cSMaheshRavishankar         return rewriter.notifyMatchFailure(
16297ef08eacSYun-Fly             fusableProducerOp, "failed to replacement value for this "
16307ef08eacSYun-Fly                                "operation from within the tiled loop");
163176ead96cSMaheshRavishankar       }
16329144fed3SQuinn Dawkins       worklistCandidates.append(newSlices.value());
16337ef08eacSYun-Fly       for (auto [index, result] :
16347ef08eacSYun-Fly            llvm::enumerate(fusableProducerOp->getResults())) {
16356e3631d0SKunwar Grover         replacements[result] = loops.front()->getResult(
16366e3631d0SKunwar Grover             loops.front()->getNumResults() -
16376e3631d0SKunwar Grover             fusableProducerOp->getNumResults() + index);
16387ef08eacSYun-Fly       }
16394435ced9SMaheshRavishankar     }
16409db7d4edSMahesh Ravishankar     if (Operation *tiledAndFusedOp =
164193c42299SMaheshRavishankar             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
164293c42299SMaheshRavishankar       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
164393c42299SMaheshRavishankar       tiledAndFusedOps.insert(tiledAndFusedOp);
16449db7d4edSMahesh Ravishankar     }
16459144fed3SQuinn Dawkins 
16469144fed3SQuinn Dawkins     if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
16479144fed3SQuinn Dawkins       return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
16489144fed3SQuinn Dawkins     }
16492f637fe7SMahesh Ravishankar   }
16504435ced9SMaheshRavishankar 
165176ead96cSMaheshRavishankar   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
165276ead96cSMaheshRavishankar                                    replacements};
1653d871daeaSMaheshRavishankar }
1654d871daeaSMaheshRavishankar 
1655d871daeaSMaheshRavishankar //===----------------------------------------------------------------------===//
16562b2ce50fSAbhishek Varma // tileAndFuseConsumerUsingSCF implementation.
16572b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===//
16582b2ce50fSAbhishek Varma 
16592b2ce50fSAbhishek Varma /// A utility function that checks whether the only use of the result of a
16602b2ce50fSAbhishek Varma /// tensor.insert_slice op is in a scf.yield op.
16612b2ce50fSAbhishek Varma static LogicalResult
16622b2ce50fSAbhishek Varma checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
16632b2ce50fSAbhishek Varma   Value result = candidateSliceOp.getResult();
16642b2ce50fSAbhishek Varma   Value::use_range uses = result.getUses();
16652b2ce50fSAbhishek Varma   if (!llvm::hasSingleElement(uses)) {
16662b2ce50fSAbhishek Varma     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
16672b2ce50fSAbhishek Varma     return failure();
16682b2ce50fSAbhishek Varma   }
16692b2ce50fSAbhishek Varma   OpOperand &operandUse = (*uses.begin());
16702b2ce50fSAbhishek Varma   Operation *userOp = operandUse.getOwner();
16712b2ce50fSAbhishek Varma   if (!isa<scf::YieldOp>(userOp)) {
16722b2ce50fSAbhishek Varma     LLVM_DEBUG(llvm::dbgs()
16732b2ce50fSAbhishek Varma                << "Expected scf.yield to be the only user, but got -> "
16742b2ce50fSAbhishek Varma                << (*userOp));
16752b2ce50fSAbhishek Varma     return failure();
16762b2ce50fSAbhishek Varma   }
16772b2ce50fSAbhishek Varma   if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
16782b2ce50fSAbhishek Varma     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
16792b2ce50fSAbhishek Varma                                "be in the same block\n");
16802b2ce50fSAbhishek Varma     return failure();
16812b2ce50fSAbhishek Varma   }
16822b2ce50fSAbhishek Varma   return success();
16832b2ce50fSAbhishek Varma }
16842b2ce50fSAbhishek Varma 
16854b563458SKunwar Grover /// An utility to get the first user of the given loopOp. If any of user stay
16864b563458SKunwar Grover /// in different block of loopOp, return failure.
16879bc3102bSYun-Fly static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
16889bc3102bSYun-Fly   if (!isa<LoopLikeOpInterface>(loopOp))
16892b2ce50fSAbhishek Varma     return failure();
16909bc3102bSYun-Fly   Operation *firstUserOfLoop = nullptr;
16919bc3102bSYun-Fly   for (Operation *userOp : loopOp->getUsers()) {
16929bc3102bSYun-Fly     // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
16939bc3102bSYun-Fly     // block with any other types of operation. Thus, just redirecting to its
16949bc3102bSYun-Fly     // parent `InParallelOp`. E.g.
16959bc3102bSYun-Fly     //
16969bc3102bSYun-Fly     // ```
16979bc3102bSYun-Fly     // %1 = scf.for {
16989bc3102bSYun-Fly     //   ...
16999bc3102bSYun-Fly     // }
17009bc3102bSYun-Fly     // %2 = consumerOp ins(%1, ...)
17019bc3102bSYun-Fly     // scf.forall.in_parallel {
17029bc3102bSYun-Fly     //    tensor.parallel_insert_slice %1
17039bc3102bSYun-Fly     // }
17049bc3102bSYun-Fly     // ```
17059bc3102bSYun-Fly     // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
17069bc3102bSYun-Fly     // same block with `consumerOp`.
17079bc3102bSYun-Fly     if (isa<tensor::ParallelInsertSliceOp>(userOp))
17089bc3102bSYun-Fly       userOp = userOp->getParentOfType<scf::InParallelOp>();
17099bc3102bSYun-Fly 
17109bc3102bSYun-Fly     if (loopOp->getBlock() != userOp->getBlock())
17112b2ce50fSAbhishek Varma       return failure();
17129bc3102bSYun-Fly 
17139bc3102bSYun-Fly     if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
17149bc3102bSYun-Fly       firstUserOfLoop = userOp;
17159bc3102bSYun-Fly   }
17169bc3102bSYun-Fly   return firstUserOfLoop;
1717b8c974f0SAbhishek Varma }
1718b8c974f0SAbhishek Varma 
17194b563458SKunwar Grover /// This utility currently checks whether the first userOp of loop is NOT
17204b563458SKunwar Grover /// before the last defineOp of consumer operand. Because that we need to move
17214b563458SKunwar Grover /// the whole loop structure right before the `firstUserOfLoop`. This utility
17224b563458SKunwar Grover /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
17234b563458SKunwar Grover /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
17249bc3102bSYun-Fly ///
17259bc3102bSYun-Fly /// ```
17269bc3102bSYun-Fly /// %0 = scf.for() {
17279bc3102bSYun-Fly ///   ...
17289bc3102bSYun-Fly /// }
17299bc3102bSYun-Fly /// ...
17309bc3102bSYun-Fly /// %1 = firstUserOfLoop(%0)
17319bc3102bSYun-Fly /// ...
17329bc3102bSYun-Fly /// %2 = lastDefOfConsumerOperand
17339bc3102bSYun-Fly /// ...
17349bc3102bSYun-Fly /// %3 = consumerOp(%2)
17359bc3102bSYun-Fly /// ```
17369bc3102bSYun-Fly ///
17374b563458SKunwar Grover /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
17384b563458SKunwar Grover /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
17394b563458SKunwar Grover /// a.k.a. use-def chain violation:
17409bc3102bSYun-Fly ///
17419bc3102bSYun-Fly /// ```
17429bc3102bSYun-Fly /// %0:2 = scf.for() {
17439bc3102bSYun-Fly ///    // use before define error
17449bc3102bSYun-Fly ///    %3 = tiledConsumerOp(%2)
17459bc3102bSYun-Fly /// }
17469bc3102bSYun-Fly /// %1 = firstUserOfLoop(%0)
17479bc3102bSYun-Fly /// ...
17489bc3102bSYun-Fly /// %2 = lastDefOfConsumerOperand
17499bc3102bSYun-Fly /// ```
17509bc3102bSYun-Fly ///
17519bc3102bSYun-Fly /// @param loopOp: loop operation
17529bc3102bSYun-Fly /// @param consumerOp: consumer operation
17534b563458SKunwar Grover /// @param reorderOperations: the flag controls whether to reorder the
17544b563458SKunwar Grover /// backward slice w.r.t. the defineOp of `consumerOp` operands.
17554b563458SKunwar Grover /// @return: computed backward slice of consumerOp, but excluding those
17564b563458SKunwar Grover /// already dominates `firstUserOfLoop`.
17579bc3102bSYun-Fly static FailureOr<llvm::SetVector<Operation *>>
17589bc3102bSYun-Fly checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
17599bc3102bSYun-Fly                        bool reorderOperations) {
17609bc3102bSYun-Fly   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
17619bc3102bSYun-Fly   if (failed(firstUserOfLoop))
17629bc3102bSYun-Fly     return failure();
17639bc3102bSYun-Fly 
17649bc3102bSYun-Fly   BackwardSliceOptions options;
17659bc3102bSYun-Fly   DominanceInfo dominanceInfo;
17669bc3102bSYun-Fly   options.inclusive = true;
17679bc3102bSYun-Fly   options.omitBlockArguments = true;
17689bc3102bSYun-Fly   bool includeLoopOp = false;
17699bc3102bSYun-Fly   options.filter = [&](Operation *op) {
17709bc3102bSYun-Fly     if (op == loopOp) {
17719bc3102bSYun-Fly       includeLoopOp = true;
17729bc3102bSYun-Fly       return false;
17739bc3102bSYun-Fly     }
17749bc3102bSYun-Fly     // Cut off the slice to not include any operation that already dominates
17759bc3102bSYun-Fly     // firstUserOfLoop.
17769bc3102bSYun-Fly     return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
17779bc3102bSYun-Fly   };
17789bc3102bSYun-Fly   llvm::SetVector<Operation *> slice;
17799bc3102bSYun-Fly   for (auto operand : consumerOp->getOperands()) {
17809bc3102bSYun-Fly     getBackwardSlice(operand, &slice, options);
17819bc3102bSYun-Fly   }
17829bc3102bSYun-Fly 
17839bc3102bSYun-Fly   if (!slice.empty()) {
17849bc3102bSYun-Fly     // If consumerOp has one producer, which is also the user of loopOp.
17859bc3102bSYun-Fly     // E.g.
17869bc3102bSYun-Fly     // ```
17879bc3102bSYun-Fly     //  %0 = %loopOp
17889bc3102bSYun-Fly     //  %1 = consumerOp1 ins(%0)
17899bc3102bSYun-Fly     //  %2 = consumerOp2 ins(%0, %1)
17909bc3102bSYun-Fly     // ```
17919bc3102bSYun-Fly     // We can not fuse consumerOp2 into loopOp due to UD chain, unless
17929bc3102bSYun-Fly     // consumerOp1 has already been fused into loopOp before.
17939bc3102bSYun-Fly     if (includeLoopOp || !reorderOperations)
17949bc3102bSYun-Fly       return failure();
17959bc3102bSYun-Fly   }
17969bc3102bSYun-Fly 
17979bc3102bSYun-Fly   return slice;
17989bc3102bSYun-Fly }
17999bc3102bSYun-Fly 
18009bc3102bSYun-Fly /// Fetches the OpOperand of the first valid user (and use) of the value `val`
18019bc3102bSYun-Fly /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
18029bc3102bSYun-Fly /// Returns failure otherwise.
18039bc3102bSYun-Fly static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18049bc3102bSYun-Fly                                                       Operation *loopOp,
18059bc3102bSYun-Fly                                                       unsigned resultNumber) {
18069bc3102bSYun-Fly   if (!isa<LoopLikeOpInterface>(loopOp))
18079bc3102bSYun-Fly     return failure();
18089bc3102bSYun-Fly   Value val = loopOp->getResult(resultNumber);
18099bc3102bSYun-Fly   Block *loopBlock = loopOp->getBlock();
18109bc3102bSYun-Fly   for (OpOperand &opOperand : val.getUses()) {
18119bc3102bSYun-Fly     Operation *consumerOp = opOperand.getOwner();
18129bc3102bSYun-Fly     // Step 1. Check if the user is tilable.
181354ae9e7bSQuinn Dawkins     if (!isa<TilingInterface>(consumerOp) ||
181454ae9e7bSQuinn Dawkins         !isa<DestinationStyleOpInterface>(consumerOp)) {
18159bc3102bSYun-Fly       // TODO: We have to init result of consumer before scf.for, use
18164b563458SKunwar Grover       // DestinationStyleOpInterface to get result shape from init for now.
18174b563458SKunwar Grover       // Add support for other op such as op has InferTypeOpInterface.
18189bc3102bSYun-Fly       continue;
18199bc3102bSYun-Fly     }
18209bc3102bSYun-Fly     // Step 2. Check if user stay in the same block.
18219bc3102bSYun-Fly     if (loopBlock != consumerOp->getBlock())
18229bc3102bSYun-Fly       continue;
18239bc3102bSYun-Fly     // Step 3. Check if user has succeeding user. Otherwise, it usually
18249bc3102bSYun-Fly     // represents already tiled.
18259bc3102bSYun-Fly     if (consumerOp->use_empty())
18269bc3102bSYun-Fly       continue;
18279bc3102bSYun-Fly     // Step 4. Check assumption for loop with `reorderOperations` enabled.
18289bc3102bSYun-Fly     FailureOr<llvm::SetVector<Operation *>> slice =
18299bc3102bSYun-Fly         checkAssumptionForLoop(loopOp, consumerOp, true);
18309bc3102bSYun-Fly     if (failed(slice))
18319bc3102bSYun-Fly       continue;
18324b563458SKunwar Grover     // Step 5. If backward sice is not empty, move them before
18334b563458SKunwar Grover     // firstUserOfLoop.
18349bc3102bSYun-Fly     if (!slice->empty()) {
18359bc3102bSYun-Fly       mlir::topologicalSort(*slice);
18369bc3102bSYun-Fly       FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
18379bc3102bSYun-Fly       assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
18389bc3102bSYun-Fly       for (auto op : *slice) {
18399bc3102bSYun-Fly         rewriter.moveOpBefore(op, *firstUserOfLoop);
18409bc3102bSYun-Fly       }
18419bc3102bSYun-Fly     }
18429bc3102bSYun-Fly     return &opOperand;
18439bc3102bSYun-Fly   }
1844b8c974f0SAbhishek Varma   return failure();
18452b2ce50fSAbhishek Varma }
18462b2ce50fSAbhishek Varma 
18474b563458SKunwar Grover /// Find the perfectly nested loops outside of given loop(included) sorted
18484b563458SKunwar Grover /// from outer to inner.
1849a9ba1b6dSYun-Fly ///
1850a9ba1b6dSYun-Fly /// E.g.
1851a9ba1b6dSYun-Fly ///
1852a9ba1b6dSYun-Fly /// ```
1853a9ba1b6dSYun-Fly ///  %0 = scf.for()
1854a9ba1b6dSYun-Fly ///    %1 = scf.for()
1855a9ba1b6dSYun-Fly ///      %2 = scf.for()
1856a9ba1b6dSYun-Fly ///         %3 = ...
1857a9ba1b6dSYun-Fly ///         yield %3
1858a9ba1b6dSYun-Fly ///      yield %2
1859a9ba1b6dSYun-Fly ///    yield %1
1860a9ba1b6dSYun-Fly /// ```
1861a9ba1b6dSYun-Fly ///
1862a9ba1b6dSYun-Fly /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1863a9ba1b6dSYun-Fly /// target inner loop is %2.
1864a9ba1b6dSYun-Fly static SmallVector<scf::ForOp>
1865a9ba1b6dSYun-Fly getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1866a9ba1b6dSYun-Fly   SmallVector<scf::ForOp> nestLoops = {loop};
1867a9ba1b6dSYun-Fly   auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1868a9ba1b6dSYun-Fly 
1869a9ba1b6dSYun-Fly   // Check if it is the ForOp that yield the result of inner loop.
1870a9ba1b6dSYun-Fly   auto isForOpYieldResultOfInnerLoop =
1871a9ba1b6dSYun-Fly       [](scf::ForOp outerLoop) -> LogicalResult {
1872a9ba1b6dSYun-Fly     Block *body = outerLoop.getBody();
1873a9ba1b6dSYun-Fly     if (!llvm::hasSingleElement(body->without_terminator()))
1874a9ba1b6dSYun-Fly       return failure();
1875a9ba1b6dSYun-Fly     auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1876a9ba1b6dSYun-Fly     auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1877a9ba1b6dSYun-Fly     if (!innerForOp)
1878a9ba1b6dSYun-Fly       return failure();
1879a9ba1b6dSYun-Fly     // All of innerForOp results should be yielded.
1880a9ba1b6dSYun-Fly     return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1881a9ba1b6dSYun-Fly   };
1882a9ba1b6dSYun-Fly 
1883a9ba1b6dSYun-Fly   while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1884a9ba1b6dSYun-Fly     nestLoops.push_back(outerLoop);
1885a9ba1b6dSYun-Fly     outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1886a9ba1b6dSYun-Fly   }
1887a9ba1b6dSYun-Fly   // sorted from outer to inner
1888a9ba1b6dSYun-Fly   return {nestLoops.rbegin(), nestLoops.rend()};
1889a9ba1b6dSYun-Fly }
1890a9ba1b6dSYun-Fly 
18912b2ce50fSAbhishek Varma /// Fetch the untiled consumer of a scf.for's result which is yielded by a
18922b2ce50fSAbhishek Varma /// tensor.insert_slice. This function makes the following assumptions :
18932b2ce50fSAbhishek Varma /// 1.  tensor.insert_slice has scf.yield as its only user.
18942b2ce50fSAbhishek Varma /// 2.  scf.for's corresponding result has only one use.
18952b2ce50fSAbhishek Varma static FailureOr<OpOperand *>
18969bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter,
18979bc3102bSYun-Fly                             tensor::InsertSliceOp candidateSliceOp) {
18982b2ce50fSAbhishek Varma   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
18992b2ce50fSAbhishek Varma     return failure();
19002b2ce50fSAbhishek Varma   Value sliceResult = candidateSliceOp.getResult();
19012b2ce50fSAbhishek Varma   // Step 1. Fetch the corresponding output.
19022b2ce50fSAbhishek Varma   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
19032b2ce50fSAbhishek Varma   unsigned resultNumber = yieldOpOperand.getOperandNumber();
19042b2ce50fSAbhishek Varma   // Step 2. Check containing op is scf.for.
19052b2ce50fSAbhishek Varma   Operation *containingOp = candidateSliceOp->getParentOp();
19062b2ce50fSAbhishek Varma   auto forOp = dyn_cast<scf::ForOp>(containingOp);
19072b2ce50fSAbhishek Varma   if (!forOp)
19082b2ce50fSAbhishek Varma     return failure();
1909a9ba1b6dSYun-Fly   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
19102b2ce50fSAbhishek Varma 
19119bc3102bSYun-Fly   return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
19122b2ce50fSAbhishek Varma }
19132b2ce50fSAbhishek Varma 
19142b2ce50fSAbhishek Varma /// Fetch the first untiled consumer of a scf.forall's result which is yielded
19152b2ce50fSAbhishek Varma /// by a tensor.parallel_insert_slice.
19162b2ce50fSAbhishek Varma static FailureOr<OpOperand *>
19179bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter,
19189bc3102bSYun-Fly                             tensor::ParallelInsertSliceOp candidateSliceOp) {
19192b2ce50fSAbhishek Varma   // Step 1. Fetch the corresponding output
19202b2ce50fSAbhishek Varma   Value sliceDest = candidateSliceOp.getDest();
19212b2ce50fSAbhishek Varma   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
19222b2ce50fSAbhishek Varma   if (!iterArg)
19232b2ce50fSAbhishek Varma     return failure();
19242b2ce50fSAbhishek Varma   Operation *containingOp = iterArg.getOwner()->getParentOp();
19252b2ce50fSAbhishek Varma   if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
19262b2ce50fSAbhishek Varma     return failure();
19272b2ce50fSAbhishek Varma   // Step 2. Check that the containing op is scf.forall.
19282b2ce50fSAbhishek Varma   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
19292b2ce50fSAbhishek Varma   if (!forallOp)
19302b2ce50fSAbhishek Varma     return failure();
19319bc3102bSYun-Fly   unsigned resultNumber =
19329bc3102bSYun-Fly       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
19339bc3102bSYun-Fly           .getResultNumber();
19342b2ce50fSAbhishek Varma 
19359bc3102bSYun-Fly   return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
19362b2ce50fSAbhishek Varma }
19372b2ce50fSAbhishek Varma 
19382b2ce50fSAbhishek Varma /// A utility to fetch an untiled consumer of
19392b2ce50fSAbhishek Varma /// tensor.insert_slice/tensor.parallel_insert_slice.
19409bc3102bSYun-Fly static FailureOr<OpOperand *>
19419bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
19422b2ce50fSAbhishek Varma   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
19439bc3102bSYun-Fly     return getUntiledConsumerFromSlice(rewriter, insertSlice);
19442b2ce50fSAbhishek Varma   } else if (auto parallelInsertSlice =
19452b2ce50fSAbhishek Varma                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
19469bc3102bSYun-Fly     return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
19472b2ce50fSAbhishek Varma   } else {
19482b2ce50fSAbhishek Varma     return failure();
19492b2ce50fSAbhishek Varma   }
19502b2ce50fSAbhishek Varma }
19512b2ce50fSAbhishek Varma 
19522b2ce50fSAbhishek Varma /// Implementation of fusing consumer of a single slice by computing the
19532b2ce50fSAbhishek Varma /// slice of the consumer in-place for scf loop.
19542b2ce50fSAbhishek Varma FailureOr<scf::SCFFuseConsumerOfSliceResult>
19552b2ce50fSAbhishek Varma mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19562b2ce50fSAbhishek Varma                                       Operation *candidateSliceOp) {
19572b2ce50fSAbhishek Varma   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
19582b2ce50fSAbhishek Varma           candidateSliceOp))
19592b2ce50fSAbhishek Varma     return failure();
19602b2ce50fSAbhishek Varma 
19612b2ce50fSAbhishek Varma   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
19622b2ce50fSAbhishek Varma 
19632b2ce50fSAbhishek Varma   // 1. Get the consumer of scf.for for the result yielded by
19642b2ce50fSAbhishek Varma   // tensor.insert_slice/parallel_insert_slice.
19652b2ce50fSAbhishek Varma   FailureOr<OpOperand *> maybeConsumerOpOperand =
19669bc3102bSYun-Fly       getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
19672b2ce50fSAbhishek Varma   if (failed(maybeConsumerOpOperand)) {
19682b2ce50fSAbhishek Varma     return rewriter.notifyMatchFailure(candidateSliceOp,
19692b2ce50fSAbhishek Varma                                        "could not fetch consumer to fuse");
19702b2ce50fSAbhishek Varma   }
19712b2ce50fSAbhishek Varma   OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
19722b2ce50fSAbhishek Varma   Operation *consumerOp = consumerOpOperand->getOwner();
19732b2ce50fSAbhishek Varma   unsigned operandNumber = consumerOpOperand->getOperandNumber();
19742b2ce50fSAbhishek Varma   unsigned resultNumber = 0;
19752b2ce50fSAbhishek Varma   if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
19762b2ce50fSAbhishek Varma     resultNumber = producerResult.getResultNumber();
19772b2ce50fSAbhishek Varma   } else {
19782b2ce50fSAbhishek Varma     return rewriter.notifyMatchFailure(
19792b2ce50fSAbhishek Varma         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
19802b2ce50fSAbhishek Varma   }
19812b2ce50fSAbhishek Varma 
1982a9ba1b6dSYun-Fly   // There are two possible cases regarding `oldLoopOp` here:
1983a9ba1b6dSYun-Fly   // 1. single `scf.forall` or `scf.for`.
1984a9ba1b6dSYun-Fly   // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1985a9ba1b6dSYun-Fly   // top-level loop is the outer-most one of these nested loops.
1986a9ba1b6dSYun-Fly   LoopLikeOpInterface innerMostLoop =
1987a9ba1b6dSYun-Fly       candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1988a9ba1b6dSYun-Fly   SmallVector<LoopLikeOpInterface> nestedLoops;
19892b2ce50fSAbhishek Varma   if (isInsertSliceOp) {
1990a9ba1b6dSYun-Fly     nestedLoops = llvm::map_to_vector(
1991a9ba1b6dSYun-Fly         getPerfectlyNestedLoopsOutsideOf(
1992a9ba1b6dSYun-Fly             cast<scf::ForOp>(innerMostLoop.getOperation())),
1993a9ba1b6dSYun-Fly         [](scf::ForOp forOp) {
1994a9ba1b6dSYun-Fly           return cast<LoopLikeOpInterface>(forOp.getOperation());
1995a9ba1b6dSYun-Fly         });
19962b2ce50fSAbhishek Varma   } else {
1997a9ba1b6dSYun-Fly     nestedLoops = {innerMostLoop};
19982b2ce50fSAbhishek Varma   }
19992b2ce50fSAbhishek Varma 
2000a9ba1b6dSYun-Fly   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2001a9ba1b6dSYun-Fly 
20029bc3102bSYun-Fly   // Check assumption for loop with `reorderOperations` disabled.
20039bc3102bSYun-Fly   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
20042b2ce50fSAbhishek Varma     return rewriter.notifyMatchFailure(
20059bc3102bSYun-Fly         outerMostLoop, "the first user of loop should not dominate any define "
20069bc3102bSYun-Fly                        "of consumer operand(s)");
20072b2ce50fSAbhishek Varma   }
20082b2ce50fSAbhishek Varma 
20092b2ce50fSAbhishek Varma   OpBuilder::InsertionGuard g(rewriter);
20102b2ce50fSAbhishek Varma 
20112b2ce50fSAbhishek Varma   // 2. Check consumer is not using scf loop's output as init.
2012a9ba1b6dSYun-Fly   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2013a9ba1b6dSYun-Fly   if (!dstOp)
2014a9ba1b6dSYun-Fly     return rewriter.notifyMatchFailure(consumerOp,
2015a9ba1b6dSYun-Fly                                        "consumer op is not DPS operation");
20162b2ce50fSAbhishek Varma   SmallVector<Value> dpsInits =
20172b2ce50fSAbhishek Varma       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2018a9ba1b6dSYun-Fly   if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
20192b2ce50fSAbhishek Varma     return rewriter.notifyMatchFailure(
20202b2ce50fSAbhishek Varma         consumerOp,
20212b2ce50fSAbhishek Varma         "consumer op taking the result of scf.for as init is not supported");
20222b2ce50fSAbhishek Varma   }
2023a9ba1b6dSYun-Fly   SmallVector<Value> newInits = dpsInits;
20242b2ce50fSAbhishek Varma 
2025a9ba1b6dSYun-Fly   Location loc = outerMostLoop->getLoc();
20262b2ce50fSAbhishek Varma 
20279bc3102bSYun-Fly   // 3. Move the whole loop structure right before firstUserOfLoop, the
20289bc3102bSYun-Fly   // dominance should be already ensured by `checkAssumptionForLoop`.
20299bc3102bSYun-Fly   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
20309bc3102bSYun-Fly   if (failed(firstUserOfLoop)) {
20319bc3102bSYun-Fly     return rewriter.notifyMatchFailure(
20329bc3102bSYun-Fly         outerMostLoop, "could not find the first user of outer most loop");
20339bc3102bSYun-Fly   }
20349bc3102bSYun-Fly   rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
20352b2ce50fSAbhishek Varma 
2036a9ba1b6dSYun-Fly   // 4. Set insertion point before terminator op of the loop and create a new
20372b2ce50fSAbhishek Varma   // tensor.insert_slice. In the scf.for case this is a clone of the
20382b2ce50fSAbhishek Varma   // candidateSliceOp whereas in the scf.forall case this is created from the
20392b2ce50fSAbhishek Varma   // operands of tensor.parallel_insert_slice.
20402b2ce50fSAbhishek Varma   tensor::InsertSliceOp clonedInsertSliceOp;
20412b2ce50fSAbhishek Varma   if (auto sliceOp =
20422b2ce50fSAbhishek Varma           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2043a9ba1b6dSYun-Fly     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
20442b2ce50fSAbhishek Varma     rewriter.setInsertionPoint(newForallOp.getTerminator());
20452b2ce50fSAbhishek Varma     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
20462b2ce50fSAbhishek Varma         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
20472b2ce50fSAbhishek Varma         sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
20482b2ce50fSAbhishek Varma   } else {
20492b2ce50fSAbhishek Varma     rewriter.setInsertionPoint(candidateSliceOp);
20502b2ce50fSAbhishek Varma     clonedInsertSliceOp =
20512b2ce50fSAbhishek Varma         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
20522b2ce50fSAbhishek Varma   }
20532b2ce50fSAbhishek Varma 
2054a9ba1b6dSYun-Fly   // 5.a. Clone consumer op.
2055a9ba1b6dSYun-Fly   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
20562b2ce50fSAbhishek Varma 
2057a9ba1b6dSYun-Fly   // 5.b. Replace all uses of the loop result with the result of the cloned
20582b2ce50fSAbhishek Varma   // tensor.insert_slice.
20592b2ce50fSAbhishek Varma   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
20602b2ce50fSAbhishek Varma   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
20612b2ce50fSAbhishek Varma     operandToReplace.set(clonedInsertSliceOp.getResult());
20622b2ce50fSAbhishek Varma   });
20632b2ce50fSAbhishek Varma 
2064a9ba1b6dSYun-Fly   // 6. Perform tiling of the cloned consumer and replace the operand at
20652b2ce50fSAbhishek Varma   // `operandNumber` with the source of the cloned tensor.insert_slice op.
20662b2ce50fSAbhishek Varma   auto ossSliceOp =
20672b2ce50fSAbhishek Varma       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
20682b2ce50fSAbhishek Varma   FailureOr<TilingResult> tileAndFuseResult =
20692b2ce50fSAbhishek Varma       tensor::replaceInsertSliceWithTiledConsumer(
20702b2ce50fSAbhishek Varma           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
20712b2ce50fSAbhishek Varma   if (failed(tileAndFuseResult)) {
20722b2ce50fSAbhishek Varma     return failure();
20732b2ce50fSAbhishek Varma   }
2074a9ba1b6dSYun-Fly   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2075a9ba1b6dSYun-Fly   rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
20762b2ce50fSAbhishek Varma                               clonedInsertSliceOp.getSource());
20772b2ce50fSAbhishek Varma 
2078a9ba1b6dSYun-Fly   // 7. Reconstruct [nested] loop with new inits.
2079a9ba1b6dSYun-Fly   YieldTiledValuesFn newYieldValuesFn =
2080a9ba1b6dSYun-Fly       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2081a9ba1b6dSYun-Fly           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2082a9ba1b6dSYun-Fly           SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
2083a9ba1b6dSYun-Fly           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2084a9ba1b6dSYun-Fly     OpBuilder::InsertionGuard g(innerRewriter);
2085a9ba1b6dSYun-Fly     // 8. Set inner insertPoint right before tiled consumer op.
2086a9ba1b6dSYun-Fly     innerRewriter.setInsertionPoint(tiledConsumerOp);
2087a9ba1b6dSYun-Fly 
20882b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
20892b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
20902b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
20912b2ce50fSAbhishek Varma 
20922b2ce50fSAbhishek Varma     // 9. Check all insert stride is 1.
20932b2ce50fSAbhishek Varma     if (llvm::any_of(strides, [](OpFoldResult stride) {
20942b2ce50fSAbhishek Varma           return !isConstantIntValue(stride, 1);
20952b2ce50fSAbhishek Varma         })) {
20962b2ce50fSAbhishek Varma       return rewriter.notifyMatchFailure(
20972b2ce50fSAbhishek Varma           candidateSliceOp, "containingOp's result yield with stride");
20982b2ce50fSAbhishek Varma     }
20992b2ce50fSAbhishek Varma 
21008cc616bcSMax191     // 10. Try to get iter domain position from input position. Use
21014b563458SKunwar Grover     // clonedConsumerOp instead of tiledConsumerOp, because the iteration
21024b563458SKunwar Grover     // domain may require index computation based on the result size. The
21034b563458SKunwar Grover     // sizes and offsets should be the same either way, but using
21044b563458SKunwar Grover     // tiledConsumerOp could lead to some chained unnecessary extra index
21054b563458SKunwar Grover     // computation.
21062b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
21078cc616bcSMax191     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
21082b2ce50fSAbhishek Varma             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
21092b2ce50fSAbhishek Varma             iterDomainSizes))) {
21102b2ce50fSAbhishek Varma       return rewriter.notifyMatchFailure(
21118cc616bcSMax191           clonedConsumerOp,
2112a9ba1b6dSYun-Fly           "can't get iter domain position from input position");
21132b2ce50fSAbhishek Varma     }
21142b2ce50fSAbhishek Varma 
21152b2ce50fSAbhishek Varma     // 11. Try to fetch the offset and size for all results of the cloned
21162b2ce50fSAbhishek Varma     // consumer. This would then be used to form the corresponding
21172b2ce50fSAbhishek Varma     // tensor.insert_slice/parallel_insert_slice later.
2118a9ba1b6dSYun-Fly     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
21192b2ce50fSAbhishek Varma     SmallVector<SmallVector<OpFoldResult>> resultOffsets(
21202b2ce50fSAbhishek Varma         totalNumResultsOfConsumer);
2121a9ba1b6dSYun-Fly     SmallVector<SmallVector<OpFoldResult>> resultSizes(
2122a9ba1b6dSYun-Fly         totalNumResultsOfConsumer);
2123a9ba1b6dSYun-Fly     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2124a9ba1b6dSYun-Fly       if (failed(tiledConsumerOp.getResultTilePosition(
21252b2ce50fSAbhishek Varma               rewriter, idx, iterDomainOffsets, iterDomainSizes,
21262b2ce50fSAbhishek Varma               resultOffsets[idx], resultSizes[idx]))) {
21272b2ce50fSAbhishek Varma         return rewriter.notifyMatchFailure(
2128a9ba1b6dSYun-Fly             tiledConsumerOp,
21292b2ce50fSAbhishek Varma             "can't get result domain position from iter domain position");
21302b2ce50fSAbhishek Varma       }
21312b2ce50fSAbhishek Varma     }
21322b2ce50fSAbhishek Varma 
2133a9ba1b6dSYun-Fly     // 12. Create `extract_slice` for `iter_args` for DPS operation if
2134a9ba1b6dSYun-Fly     // necessary.
2135a9ba1b6dSYun-Fly     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2136a9ba1b6dSYun-Fly             tiledConsumerOp.getOperation())) {
2137a9ba1b6dSYun-Fly       rewriter.setInsertionPoint(tiledDestStyleOp);
2138a9ba1b6dSYun-Fly       for (const auto &&[index, newRegionArg] :
2139a9ba1b6dSYun-Fly            llvm::enumerate(newRegionIterArgs)) {
2140a9ba1b6dSYun-Fly         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2141a9ba1b6dSYun-Fly             loc, newRegionArg, resultOffsets[index], resultSizes[index],
2142a9ba1b6dSYun-Fly             SmallVector<OpFoldResult>(resultOffsets[index].size(),
2143a9ba1b6dSYun-Fly                                       rewriter.getIndexAttr(1)));
2144a9ba1b6dSYun-Fly         // Make a copy of index to avoid a capturing structured binding, which
2145a9ba1b6dSYun-Fly         // is a C++20 extension.
2146a9ba1b6dSYun-Fly         auto dstNumber = index;
2147a9ba1b6dSYun-Fly         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2148a9ba1b6dSYun-Fly           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2149a9ba1b6dSYun-Fly         });
2150a9ba1b6dSYun-Fly       }
21512b2ce50fSAbhishek Varma     }
21522b2ce50fSAbhishek Varma 
2153a9ba1b6dSYun-Fly     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2154a9ba1b6dSYun-Fly     // caller.
2155a9ba1b6dSYun-Fly     Block *block = rewriter.getInsertionPoint()->getBlock();
2156a9ba1b6dSYun-Fly     rewriter.setInsertionPoint(block->getTerminator());
2157a9ba1b6dSYun-Fly     for (const auto &&[index, result] :
2158a9ba1b6dSYun-Fly          llvm::enumerate(tiledConsumerOp->getResults())) {
2159a9ba1b6dSYun-Fly       tiledResult.push_back(result);
2160a9ba1b6dSYun-Fly       tiledOffset.emplace_back(resultOffsets[index]);
2161a9ba1b6dSYun-Fly       tiledSizes.emplace_back(resultSizes[index]);
2162a9ba1b6dSYun-Fly     }
2163a9ba1b6dSYun-Fly     return success();
2164a9ba1b6dSYun-Fly   };
2165a9ba1b6dSYun-Fly   // 14. Add new inits to [nested] loops.
2166a9ba1b6dSYun-Fly   if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2167a9ba1b6dSYun-Fly                                        newYieldValuesFn))) {
2168a9ba1b6dSYun-Fly     return rewriter.notifyMatchFailure(tiledConsumerOp,
2169a9ba1b6dSYun-Fly                                        "unable to add new inits to nest loop");
2170a9ba1b6dSYun-Fly   }
2171a9ba1b6dSYun-Fly 
21724b563458SKunwar Grover   // 15. Replace the result of scf loop and consumer op with new loop's
21734b563458SKunwar Grover   // results.
2174a9ba1b6dSYun-Fly 
2175a9ba1b6dSYun-Fly   for (auto &&[oldResult, newResult] : llvm::zip(
2176a9ba1b6dSYun-Fly            consumerOp->getResults(),
2177a9ba1b6dSYun-Fly            nestedLoops.front()->getResults().take_back(newInits.size()))) {
21782b2ce50fSAbhishek Varma     rewriter.replaceAllUsesWith(oldResult, newResult);
21792b2ce50fSAbhishek Varma   }
21802b2ce50fSAbhishek Varma 
2181a9ba1b6dSYun-Fly   // 16. Need to erase the old scf loop and the cloned consumer op.
21822b2ce50fSAbhishek Varma   rewriter.eraseOp(clonedConsumerOp);
21832b2ce50fSAbhishek Varma 
21842b2ce50fSAbhishek Varma   return scf::SCFFuseConsumerOfSliceResult{
21852b2ce50fSAbhishek Varma       consumerOpOperand,
21862b2ce50fSAbhishek Varma       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
21872b2ce50fSAbhishek Varma       tileAndFuseResult->tiledOps};
21882b2ce50fSAbhishek Varma }
21892b2ce50fSAbhishek Varma 
21902b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===//
219197f91982SMahesh Ravishankar // lowerToLoopsUsingSCFForOp implementation.
21926f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===//
21936f03a10eSMahesh Ravishankar 
21946f03a10eSMahesh Ravishankar FailureOr<SmallVector<scf::ForOp>>
219597f91982SMahesh Ravishankar mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
219697f91982SMahesh Ravishankar                                      TilingInterface op) {
21976f03a10eSMahesh Ravishankar   // TODO: Handle cases where the op has results if needed.
21986f03a10eSMahesh Ravishankar   if (op->getNumResults() > 0) {
21996f03a10eSMahesh Ravishankar     return rewriter.notifyMatchFailure(
22006f03a10eSMahesh Ravishankar         op, "unable to lower to loops operations with return values");
22016f03a10eSMahesh Ravishankar   }
22026f03a10eSMahesh Ravishankar 
220397f91982SMahesh Ravishankar   SmallVector<Range> domain = op.getIterationDomain(rewriter);
22046f03a10eSMahesh Ravishankar   SmallVector<Value> ivs;
22056f03a10eSMahesh Ravishankar   SmallVector<scf::ForOp> loops;
22066f03a10eSMahesh Ravishankar   Location loc = op.getLoc();
22076f03a10eSMahesh Ravishankar   for (auto loopRange : domain) {
22086f03a10eSMahesh Ravishankar     Value offsetVal =
22096f03a10eSMahesh Ravishankar         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
22106f03a10eSMahesh Ravishankar     Value sizeVal =
22116f03a10eSMahesh Ravishankar         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
22126f03a10eSMahesh Ravishankar     Value strideVal =
22136f03a10eSMahesh Ravishankar         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
22146f03a10eSMahesh Ravishankar     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
22156f03a10eSMahesh Ravishankar                                             strideVal, ValueRange{});
22166f03a10eSMahesh Ravishankar     loops.push_back(loop);
22176f03a10eSMahesh Ravishankar     ivs.push_back(loop.getInductionVar());
22186f03a10eSMahesh Ravishankar     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
22196f03a10eSMahesh Ravishankar   }
22206f03a10eSMahesh Ravishankar   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
22216f03a10eSMahesh Ravishankar     return failure();
22226f03a10eSMahesh Ravishankar   }
22236f03a10eSMahesh Ravishankar   return loops;
22246f03a10eSMahesh Ravishankar }
2225