1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2cf6a7c19SMahesh Ravishankar // 3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cf6a7c19SMahesh Ravishankar // 7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8cf6a7c19SMahesh Ravishankar // 9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface. 10cf6a7c19SMahesh Ravishankar // 11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 12cf6a7c19SMahesh Ravishankar 138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14cf6a7c19SMahesh Ravishankar 159bc3102bSYun-Fly #include "mlir/Analysis/SliceAnalysis.h" 169bc3102bSYun-Fly #include "mlir/Analysis/TopologicalSortUtils.h" 17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 19abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 20cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 21cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h" 22cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 23b1d3afc9SHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h" 242b2ce50fSAbhishek Varma #include "mlir/IR/Dominance.h" 25cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h" 26cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h" 27b169643fSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h" 28cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 299144fed3SQuinn Dawkins #include "mlir/Rewrite/FrozenRewritePatternSet.h" 309144fed3SQuinn Dawkins #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 316e3631d0SKunwar Grover #include "llvm/ADT/ScopeExit.h" 3276ead96cSMaheshRavishankar #include "llvm/ADT/TypeSwitch.h" 33cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h" 345c9013e2SKazu Hirata #include <optional> 35cf6a7c19SMahesh Ravishankar 36cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface" 37cf6a7c19SMahesh Ravishankar 38cf6a7c19SMahesh Ravishankar using namespace mlir; 39cf6a7c19SMahesh Ravishankar 40cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions & 41170a25a7SMaheshRavishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { 42cf6a7c19SMahesh Ravishankar assert(!tileSizeComputationFunction && "tile sizes already set"); 43170a25a7SMaheshRavishankar auto tileSizes = llvm::to_vector(ts); 44cf6a7c19SMahesh Ravishankar tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 45170a25a7SMaheshRavishankar return tileSizes; 46cf6a7c19SMahesh Ravishankar }; 47cf6a7c19SMahesh Ravishankar return *this; 48cf6a7c19SMahesh Ravishankar } 49cf6a7c19SMahesh Ravishankar 506740d701SMaheshRavishankar scf::SCFTilingOptions & 516740d701SMaheshRavishankar scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { 526740d701SMaheshRavishankar assert(!numThreadsComputationFunction && "num tiles already set"); 536740d701SMaheshRavishankar auto numThreads = llvm::to_vector(nt); 546740d701SMaheshRavishankar numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { 556740d701SMaheshRavishankar return numThreads; 566740d701SMaheshRavishankar }; 576740d701SMaheshRavishankar return *this; 586740d701SMaheshRavishankar } 596740d701SMaheshRavishankar 60b8a1f00dSMahesh Ravishankar /// Helper method to adjust the interchange vector to match the iteration 61b8a1f00dSMahesh Ravishankar /// domain. 6279150279SNicolas Vasilache static SmallVector<int64_t> 6379150279SNicolas Vasilache fillInterchangeVector(ArrayRef<int64_t> interchangeVector, 64b8a1f00dSMahesh Ravishankar size_t iterationDomainSize) { 6579150279SNicolas Vasilache SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); 66b8a1f00dSMahesh Ravishankar if (filledVector.size() < iterationDomainSize) { 6779150279SNicolas Vasilache auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); 68b8a1f00dSMahesh Ravishankar filledVector.append(range.begin(), range.end()); 69b8a1f00dSMahesh Ravishankar } 70b8a1f00dSMahesh Ravishankar if (filledVector.size() > iterationDomainSize) 71b8a1f00dSMahesh Ravishankar filledVector.resize(iterationDomainSize); 72b8a1f00dSMahesh Ravishankar return filledVector; 73b8a1f00dSMahesh Ravishankar } 74b8a1f00dSMahesh Ravishankar 752f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 7676ead96cSMaheshRavishankar // tileUsingSCF implementation. 772f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 782f637fe7SMahesh Ravishankar 796740d701SMaheshRavishankar /// Verify the tile size options are set in a consistent manner. 806740d701SMaheshRavishankar static LogicalResult 816740d701SMaheshRavishankar verifyTileSizeOptions(RewriterBase &rewriter, Location loc, 826740d701SMaheshRavishankar const scf::SCFTilingOptions &options) { 836740d701SMaheshRavishankar // Specifying number of threads is only supported on `scf.forall` op. 846740d701SMaheshRavishankar if (options.numThreadsComputationFunction && 856740d701SMaheshRavishankar options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { 866740d701SMaheshRavishankar return rewriter.notifyMatchFailure( 876740d701SMaheshRavishankar loc, "number of threads can only by specified when loop type is " 886740d701SMaheshRavishankar "set to use `scf.forall`"); 896740d701SMaheshRavishankar } 906740d701SMaheshRavishankar 916740d701SMaheshRavishankar // If specified, check that the interchange vector is a permutation. 926740d701SMaheshRavishankar if (!options.interchangeVector.empty()) { 936740d701SMaheshRavishankar if (!isPermutationVector(options.interchangeVector)) { 946740d701SMaheshRavishankar return rewriter.notifyMatchFailure( 956740d701SMaheshRavishankar loc, "invalid interchange vector, not a permutation of the entire " 966740d701SMaheshRavishankar "iteration space"); 976740d701SMaheshRavishankar } 986740d701SMaheshRavishankar } 996740d701SMaheshRavishankar return success(); 1006740d701SMaheshRavishankar } 1016740d701SMaheshRavishankar 1026740d701SMaheshRavishankar /// Method to instantiate the tile sizes and/or number of threads specified 1036740d701SMaheshRavishankar /// by the user. 1046740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 1056740d701SMaheshRavishankar getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, 1066740d701SMaheshRavishankar ArrayRef<Range> iterationDomain, 1076740d701SMaheshRavishankar const scf::SCFTilingOptions &options) { 1086740d701SMaheshRavishankar OpFoldResult zero = rewriter.getIndexAttr(0); 1096740d701SMaheshRavishankar SmallVector<OpFoldResult> tileSizes, numThreads; 1106740d701SMaheshRavishankar size_t numLoops = iterationDomain.size(); 1116740d701SMaheshRavishankar 1126740d701SMaheshRavishankar // Check whether the number of tiles to use is specified. 1136740d701SMaheshRavishankar if (options.numThreadsComputationFunction) { 1146740d701SMaheshRavishankar numThreads = options.numThreadsComputationFunction(rewriter, op); 1156740d701SMaheshRavishankar numThreads.resize(numLoops, zero); 1166740d701SMaheshRavishankar 1176740d701SMaheshRavishankar // If the number of tiles is also specified, use that. 1186740d701SMaheshRavishankar if (options.tileSizeComputationFunction) { 1196740d701SMaheshRavishankar tileSizes = options.tileSizeComputationFunction(rewriter, op); 1206740d701SMaheshRavishankar tileSizes.resize(numLoops, zero); 1216740d701SMaheshRavishankar return {tileSizes, numThreads}; 1226740d701SMaheshRavishankar } 1236740d701SMaheshRavishankar 1246740d701SMaheshRavishankar // Compute the tile sizes from the iteration domain and number 1256740d701SMaheshRavishankar // of tiles as follows 1266740d701SMaheshRavishankar // - niters = ceilDiv(ub - lb, step) 1276740d701SMaheshRavishankar // - tileSize = ceilDiv(niters, numThreads) 1286740d701SMaheshRavishankar AffineExpr s0, s1, s2; 1296740d701SMaheshRavishankar bindSymbols(rewriter.getContext(), s0, s1, s2); 1306740d701SMaheshRavishankar // TODO: The step here is assumed to be 1. 1316740d701SMaheshRavishankar AffineExpr numItersExpr = (s1 - s0); 1326740d701SMaheshRavishankar AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); 1336740d701SMaheshRavishankar tileSizes.resize(numLoops, zero); 1346740d701SMaheshRavishankar for (auto [index, range, nt] : 1356740d701SMaheshRavishankar llvm::enumerate(iterationDomain, numThreads)) { 1366740d701SMaheshRavishankar if (isConstantIntValue(nt, 0)) 1376740d701SMaheshRavishankar continue; 1386740d701SMaheshRavishankar 1396740d701SMaheshRavishankar tileSizes[index] = affine::makeComposedFoldedAffineApply( 1406740d701SMaheshRavishankar rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); 1416740d701SMaheshRavishankar } 1426740d701SMaheshRavishankar tileSizes.resize(numLoops, zero); 1436740d701SMaheshRavishankar return {tileSizes, numThreads}; 1446740d701SMaheshRavishankar } 1456740d701SMaheshRavishankar 1466740d701SMaheshRavishankar // Enforce the convention that "tiling by zero" 1476740d701SMaheshRavishankar // skips tiling a particular dimension. This convention is significantly 1486740d701SMaheshRavishankar // simpler to handle instead of adjusting affine maps to account for missing 1496740d701SMaheshRavishankar // dimensions. 1506740d701SMaheshRavishankar assert(options.tileSizeComputationFunction && 1516740d701SMaheshRavishankar "expected tile sizes to be specified"); 1526740d701SMaheshRavishankar tileSizes = options.tileSizeComputationFunction(rewriter, op); 1536740d701SMaheshRavishankar tileSizes.resize(numLoops, zero); 1546740d701SMaheshRavishankar 1556740d701SMaheshRavishankar return {tileSizes, numThreads}; 1566740d701SMaheshRavishankar } 1576740d701SMaheshRavishankar 1586740d701SMaheshRavishankar /// Checks if any of the tiled loops are not parallel. 1596740d701SMaheshRavishankar static void checkSafeToTileToForall(TilingInterface op, 1606740d701SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes, 1616740d701SMaheshRavishankar ArrayRef<OpFoldResult> numThreads) { 1626740d701SMaheshRavishankar auto iterators = op.getLoopIteratorTypes(); 1636740d701SMaheshRavishankar assert(iterators.size() == tileSizes.size() && 1646740d701SMaheshRavishankar "expected as many tile size values as number of loops"); 1656740d701SMaheshRavishankar assert((numThreads.empty() || (numThreads.size() == iterators.size())) && 1666740d701SMaheshRavishankar "when specified, expected number of threads to use for each loop"); 1676740d701SMaheshRavishankar 1686740d701SMaheshRavishankar for (auto [index, iterator, tileSize] : 1696740d701SMaheshRavishankar llvm::enumerate(iterators, tileSizes)) { 1706740d701SMaheshRavishankar // If num threads is specified, check that it is greater than one only for 1716740d701SMaheshRavishankar // parallel dimensions. 1726740d701SMaheshRavishankar if (!numThreads.empty()) { 1736740d701SMaheshRavishankar if (std::optional<int64_t> constNumThreads = 1746740d701SMaheshRavishankar getConstantIntValue(numThreads[index])) { 1756740d701SMaheshRavishankar if (constNumThreads.value() > 1 && 1766740d701SMaheshRavishankar iterator != utils::IteratorType::parallel) { 1776740d701SMaheshRavishankar op.emitWarning() << "tiling is not thread safe at axis #" << index; 1786740d701SMaheshRavishankar } 1796740d701SMaheshRavishankar } 1806740d701SMaheshRavishankar continue; 1816740d701SMaheshRavishankar } 1826740d701SMaheshRavishankar 1836740d701SMaheshRavishankar if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { 1846740d701SMaheshRavishankar if (constTileSize.value() > 0 && 1856740d701SMaheshRavishankar iterator != utils::IteratorType::parallel) { 1866740d701SMaheshRavishankar op.emitWarning() << "tiling is not thread safe at axis #" << index; 1876740d701SMaheshRavishankar } 1886740d701SMaheshRavishankar } 1896740d701SMaheshRavishankar } 1906740d701SMaheshRavishankar } 1916740d701SMaheshRavishankar 1926740d701SMaheshRavishankar /// Check if `stride` evenly divides the trip count `size - offset`. 193954de25aSlorenzo chelini static bool tileDividesIterationDomain(Range loopRange) { 19422426110SRamkumar Ramachandra std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 195954de25aSlorenzo chelini if (!offsetAsInt) 196954de25aSlorenzo chelini return false; 19722426110SRamkumar Ramachandra std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 198954de25aSlorenzo chelini if (!sizeAsInt) 199954de25aSlorenzo chelini return false; 20022426110SRamkumar Ramachandra std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 201954de25aSlorenzo chelini if (!strideAsInt) 202954de25aSlorenzo chelini return false; 203954de25aSlorenzo chelini return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 204954de25aSlorenzo chelini } 205954de25aSlorenzo chelini 2066740d701SMaheshRavishankar /// Returns the bounded tile size given the current `offset`, `loopRange` and 2076740d701SMaheshRavishankar /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 20871cf48a6SHanhan Wang static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, 2096740d701SMaheshRavishankar Range loopRange, OpFoldResult offset, 210d871daeaSMaheshRavishankar OpFoldResult tileSize) { 211d2b7a8e8SAdrian Kuegel std::optional<int64_t> ts = getConstantIntValue(tileSize); 212d2b7a8e8SAdrian Kuegel if (ts && ts.value() == 1) 213d871daeaSMaheshRavishankar return tileSize; 21471cf48a6SHanhan Wang 21571cf48a6SHanhan Wang if (tileDividesIterationDomain( 21671cf48a6SHanhan Wang Range{loopRange.offset, loopRange.size, tileSize})) 21771cf48a6SHanhan Wang return tileSize; 21871cf48a6SHanhan Wang 21971cf48a6SHanhan Wang // The tile size to use (to avoid out of bounds access) is minimum of 22071cf48a6SHanhan Wang // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled 22171cf48a6SHanhan Wang // loop. 22271cf48a6SHanhan Wang AffineExpr s0, s1, d0; 22371cf48a6SHanhan Wang bindDims(b.getContext(), d0); 22471cf48a6SHanhan Wang bindSymbols(b.getContext(), s0, s1); 2256740d701SMaheshRavishankar AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); 22671cf48a6SHanhan Wang Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); 2274c48f016SMatthias Springer return affine::makeComposedFoldedAffineMin( 2286740d701SMaheshRavishankar b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize}); 2296740d701SMaheshRavishankar } 2306740d701SMaheshRavishankar 2316740d701SMaheshRavishankar /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less 2326740d701SMaheshRavishankar /// than `iterationSize`. 2336740d701SMaheshRavishankar static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, 2346740d701SMaheshRavishankar OpFoldResult numThreads, 2356740d701SMaheshRavishankar OpFoldResult iterationSize) { 2366740d701SMaheshRavishankar std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize); 2376740d701SMaheshRavishankar std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads); 2386740d701SMaheshRavishankar std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize); 2396740d701SMaheshRavishankar if (!tileSizeConst || !numThreadsConst || !iterSizeConst) 2406740d701SMaheshRavishankar return false; 2416740d701SMaheshRavishankar return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; 2426740d701SMaheshRavishankar } 2436740d701SMaheshRavishankar 2446740d701SMaheshRavishankar /// Compute the `OpFoldResult`s that represents the multi-dimensional 2456740d701SMaheshRavishankar /// `offset`s and `size`s of the tile of the iteration space that the 2466740d701SMaheshRavishankar /// innermost loop body of the generated tiled loops corresponds to. 2476740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> 2486740d701SMaheshRavishankar getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, 2496740d701SMaheshRavishankar ArrayRef<Range> iterationDomain, 2506740d701SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes, 2516740d701SMaheshRavishankar ArrayRef<OpFoldResult> numThreads) { 2526740d701SMaheshRavishankar SmallVector<OpFoldResult> offsets, sizes; 2536740d701SMaheshRavishankar int materializedLoopNum = 0; 2546740d701SMaheshRavishankar 2556740d701SMaheshRavishankar if (!numThreads.empty()) { 2566740d701SMaheshRavishankar AffineExpr d0, d1, s0, s1; 2576740d701SMaheshRavishankar AffineExpr offsetExpr, residualTileSizeExpr; 2586740d701SMaheshRavishankar bindDims(rewriter.getContext(), d0, d1); 2596740d701SMaheshRavishankar bindSymbols(rewriter.getContext(), s0, s1); 2606740d701SMaheshRavishankar offsetExpr = d0 + d1 * s0; 2616740d701SMaheshRavishankar residualTileSizeExpr = s1 - (d0 + d1 * s0); 2626740d701SMaheshRavishankar 2636740d701SMaheshRavishankar for (auto [nt, tileSize, loopRange] : 2646740d701SMaheshRavishankar llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { 2656740d701SMaheshRavishankar 2666740d701SMaheshRavishankar // Non-tiled cases, set the offset and size to the 2676740d701SMaheshRavishankar // `loopRange.offset/size`. 2686740d701SMaheshRavishankar if (isConstantIntValue(nt, 0)) { 2696740d701SMaheshRavishankar offsets.push_back(loopRange.offset); 2706740d701SMaheshRavishankar sizes.push_back(loopRange.size); 2716740d701SMaheshRavishankar continue; 2726740d701SMaheshRavishankar } 2736740d701SMaheshRavishankar 2746740d701SMaheshRavishankar Value iv = ivs[materializedLoopNum++]; 2756740d701SMaheshRavishankar OpFoldResult offset = affine::makeComposedFoldedAffineApply( 2766740d701SMaheshRavishankar rewriter, loc, offsetExpr, 2776740d701SMaheshRavishankar ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); 2786740d701SMaheshRavishankar OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( 2796740d701SMaheshRavishankar rewriter, loc, residualTileSizeExpr, 2806740d701SMaheshRavishankar {loopRange.offset, nt, tileSize, loopRange.size}); 2816740d701SMaheshRavishankar 2826740d701SMaheshRavishankar OpFoldResult size = tileSize; 2836740d701SMaheshRavishankar if (!isConstantIntValue(residualTileSize, 0)) { 2846740d701SMaheshRavishankar OpFoldResult sizeMinusOffsetPerThread = 2856740d701SMaheshRavishankar affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, 2866740d701SMaheshRavishankar {offset, loopRange.size}); 2876740d701SMaheshRavishankar size = affine::makeComposedFoldedAffineMin( 2886740d701SMaheshRavishankar rewriter, loc, 2896740d701SMaheshRavishankar AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), 2906740d701SMaheshRavishankar {sizeMinusOffsetPerThread, tileSize}); 2916740d701SMaheshRavishankar } 2926740d701SMaheshRavishankar 2936740d701SMaheshRavishankar // Consider the case where the original loop was `[0, 100)`. 2946740d701SMaheshRavishankar // If number of threads are `7`, the tile size would be computed as 2956740d701SMaheshRavishankar // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) 2966740d701SMaheshRavishankar // - `offset = 0 + 6 * 15 = 105` 2976740d701SMaheshRavishankar // - `tileSize = min(15, 100 - 105) = -5` 2986740d701SMaheshRavishankar // To avoid negative tile sizes, we need to do a further 2996740d701SMaheshRavishankar // `nonNegativeTileSize = affine.max(0, tileSize)`. 3006740d701SMaheshRavishankar // This `max` can be avoided if 3016740d701SMaheshRavishankar // `offset + tileSize * (numThreads - 1) < (ub - lb)` 3026740d701SMaheshRavishankar if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { 3036740d701SMaheshRavishankar AffineMap maxMap = 3046740d701SMaheshRavishankar AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); 3056740d701SMaheshRavishankar size = affine::makeComposedFoldedAffineMax( 3066740d701SMaheshRavishankar rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); 3076740d701SMaheshRavishankar } 3086740d701SMaheshRavishankar 3096740d701SMaheshRavishankar offsets.push_back(offset); 3106740d701SMaheshRavishankar sizes.push_back(size); 3116740d701SMaheshRavishankar } 3126740d701SMaheshRavishankar return {offsets, sizes}; 3136740d701SMaheshRavishankar } else { 3146740d701SMaheshRavishankar for (auto [tileSize, loopRange] : 3156740d701SMaheshRavishankar llvm::zip_equal(tileSizes, iterationDomain)) { 3166740d701SMaheshRavishankar 3176740d701SMaheshRavishankar // Non-tiled cases, set the offset and size to the 3186740d701SMaheshRavishankar // `loopRange.offset/size`. 3196740d701SMaheshRavishankar if (isConstantIntValue(tileSize, 0)) { 3206740d701SMaheshRavishankar offsets.push_back(loopRange.offset); 3216740d701SMaheshRavishankar sizes.push_back(loopRange.size); 3226740d701SMaheshRavishankar continue; 3236740d701SMaheshRavishankar } 3246740d701SMaheshRavishankar 3256740d701SMaheshRavishankar Value iv = ivs[materializedLoopNum++]; 3266740d701SMaheshRavishankar OpFoldResult offset = getAsOpFoldResult(iv); 3276740d701SMaheshRavishankar offsets.push_back(offset); 3286740d701SMaheshRavishankar OpFoldResult size = 3296740d701SMaheshRavishankar getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); 3306740d701SMaheshRavishankar sizes.push_back(size); 3316740d701SMaheshRavishankar } 3326740d701SMaheshRavishankar return {offsets, sizes}; 3336740d701SMaheshRavishankar } 3346740d701SMaheshRavishankar } 3356740d701SMaheshRavishankar 3366740d701SMaheshRavishankar /// Function to return the bounds of the loops to be generated. 3376740d701SMaheshRavishankar static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 3386740d701SMaheshRavishankar SmallVector<OpFoldResult>> 3396740d701SMaheshRavishankar getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 3406740d701SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes) { 3416740d701SMaheshRavishankar SmallVector<OpFoldResult> lbs, ubs, steps; 3426740d701SMaheshRavishankar for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { 3436740d701SMaheshRavishankar // No loop if the tile size is 0. 3446740d701SMaheshRavishankar if (isConstantIntValue(tileSize, 0)) 3456740d701SMaheshRavishankar continue; 3466740d701SMaheshRavishankar lbs.push_back(loopRange.offset); 3476740d701SMaheshRavishankar ubs.push_back(loopRange.size); 3486740d701SMaheshRavishankar steps.push_back(tileSize); 3496740d701SMaheshRavishankar } 3506740d701SMaheshRavishankar return {lbs, ubs, steps}; 35171cf48a6SHanhan Wang } 35271cf48a6SHanhan Wang 35376ead96cSMaheshRavishankar /// A function that allows returning additional yielded values during 35476ead96cSMaheshRavishankar /// `yieldTiledValuesAndReplace`. 35576ead96cSMaheshRavishankar /// - `ivs` induction variable for the loop. 35676ead96cSMaheshRavishankar /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 35776ead96cSMaheshRavishankar /// - `tiledValues` the tiled values to return. Must be of same size as 35876ead96cSMaheshRavishankar /// `newbbArgs`, each element of this array is inserted into the corresponding 35976ead96cSMaheshRavishankar /// element in `newbbArgs`. 36076ead96cSMaheshRavishankar /// - `resultOffsets` is of the same size as `tiledValues` and represents 36176ead96cSMaheshRavishankar /// the offsets to use when inserting corresponding element from `tiledValues` 36276ead96cSMaheshRavishankar /// into the element from `newBbArgs`. 36376ead96cSMaheshRavishankar /// - `resultSizes` is of the same size as `tiledValues` and represents 36476ead96cSMaheshRavishankar /// the size of the corresponding element from `tiledValues` inserted into 36576ead96cSMaheshRavishankar /// the element from `newBbArgs`. 36676ead96cSMaheshRavishankar /// In case the method needs to return `failure()` the method is expected 36776ead96cSMaheshRavishankar /// to clean up any inserted operations. 36876ead96cSMaheshRavishankar using YieldTiledValuesFn = std::function<LogicalResult( 36976ead96cSMaheshRavishankar RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, 37076ead96cSMaheshRavishankar SmallVector<Value> &tiledValues, 37176ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 37276ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; 37376ead96cSMaheshRavishankar 374d871daeaSMaheshRavishankar /// Clones the operation and updates the destination if the operation 375d871daeaSMaheshRavishankar /// implements the `DestinationStyleOpInterface`. 376d871daeaSMaheshRavishankar static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, 377d871daeaSMaheshRavishankar Operation *op, 378d871daeaSMaheshRavishankar ValueRange newDestArgs) { 379d871daeaSMaheshRavishankar Operation *clonedOp = rewriter.clone(*op); 3804a020018SMaheshRavishankar if (newDestArgs.empty()) 3814a020018SMaheshRavishankar return clonedOp; 3824a020018SMaheshRavishankar if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) 383d871daeaSMaheshRavishankar destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); 384d871daeaSMaheshRavishankar return clonedOp; 385d871daeaSMaheshRavishankar } 386d871daeaSMaheshRavishankar 38776ead96cSMaheshRavishankar /// Generate the tile-loop nest using `scf.for` operation. 388cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 389170a25a7SMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 39076ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop. 39176ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 39276ead96cSMaheshRavishankar /// most 39376ead96cSMaheshRavishankar /// loop. 39476ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are 39576ead96cSMaheshRavishankar /// populated. 39676ead96cSMaheshRavishankar static LogicalResult generateLoopNestUsingForOp( 39776ead96cSMaheshRavishankar RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 39876ead96cSMaheshRavishankar ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, 39976ead96cSMaheshRavishankar YieldTiledValuesFn yieldTiledValuesFn, 40076ead96cSMaheshRavishankar SmallVector<LoopLikeOpInterface> &loops) { 40176ead96cSMaheshRavishankar assert(!loopRanges.empty() && "unexpected empty loop ranges"); 402170a25a7SMaheshRavishankar assert(loopRanges.size() == tileSizes.size() && 403cf6a7c19SMahesh Ravishankar "expected as many tile sizes as loop ranges"); 40476ead96cSMaheshRavishankar OpBuilder::InsertionGuard guard(rewriter); 4056740d701SMaheshRavishankar 4066740d701SMaheshRavishankar SmallVector<OpFoldResult> lbs, ubs, steps; 4076740d701SMaheshRavishankar std::tie(lbs, ubs, steps) = 4086740d701SMaheshRavishankar getLoopBounds(rewriter, loc, loopRanges, tileSizes); 4096740d701SMaheshRavishankar SmallVector<Value> lbVals = 4106740d701SMaheshRavishankar getValueOrCreateConstantIndexOp(rewriter, loc, lbs); 4116740d701SMaheshRavishankar SmallVector<Value> ubVals = 4126740d701SMaheshRavishankar getValueOrCreateConstantIndexOp(rewriter, loc, ubs); 4136740d701SMaheshRavishankar SmallVector<Value> stepVals = 4146740d701SMaheshRavishankar getValueOrCreateConstantIndexOp(rewriter, loc, steps); 4156740d701SMaheshRavishankar 41676ead96cSMaheshRavishankar SmallVector<Value> ivs; 4176740d701SMaheshRavishankar for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { 41876ead96cSMaheshRavishankar auto loop = 41976ead96cSMaheshRavishankar rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, 42076ead96cSMaheshRavishankar [](OpBuilder &bodyBuilder, Location bodyLoc, 42176ead96cSMaheshRavishankar Value iv, ValueRange /*iterArgs*/) {}); 422cf6a7c19SMahesh Ravishankar loops.push_back(loop); 42376ead96cSMaheshRavishankar ivs.push_back(loop.getInductionVar()); 42476ead96cSMaheshRavishankar rewriter.setInsertionPointToEnd(loop.getBody()); 4254a020018SMaheshRavishankar destinationTensors = loop.getRegionIterArgs(); 4264a020018SMaheshRavishankar } 4274a020018SMaheshRavishankar 42876ead96cSMaheshRavishankar SmallVector<Value> tiledResults; 42976ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 43076ead96cSMaheshRavishankar if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, 43176ead96cSMaheshRavishankar tiledResults, resultOffsets, resultSizes))) { 43276ead96cSMaheshRavishankar return rewriter.notifyMatchFailure( 43376ead96cSMaheshRavishankar loc, "failed to generate inner tile loop body"); 43476ead96cSMaheshRavishankar } 43576ead96cSMaheshRavishankar if (loops.empty()) 43676ead96cSMaheshRavishankar return success(); 43776ead96cSMaheshRavishankar 4389329b20dSKunwar Grover assert(tiledResults.size() == destinationTensors.size() && 4399329b20dSKunwar Grover "Number of results of body should be equal to number of iter args"); 4409329b20dSKunwar Grover 44176ead96cSMaheshRavishankar // 6. Yield all the results of the tiled operation. 44276ead96cSMaheshRavishankar SmallVector<Value> yieldedValues; 44376ead96cSMaheshRavishankar for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 44476ead96cSMaheshRavishankar llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 44576ead96cSMaheshRavishankar resultSizes)) { 44676ead96cSMaheshRavishankar SmallVector<OpFoldResult> resultStride(resultOffset.size(), 44776ead96cSMaheshRavishankar rewriter.getIndexAttr(1)); 44876ead96cSMaheshRavishankar auto insertSlice = rewriter.create<tensor::InsertSliceOp>( 44976ead96cSMaheshRavishankar loc, tiledValue, destinationTensor, resultOffset, resultSize, 45076ead96cSMaheshRavishankar resultStride); 45176ead96cSMaheshRavishankar yieldedValues.push_back(insertSlice); 45276ead96cSMaheshRavishankar } 45376ead96cSMaheshRavishankar rewriter.create<scf::YieldOp>(loc, yieldedValues); 45476ead96cSMaheshRavishankar 4554a020018SMaheshRavishankar // Add the scf.yield operations for all the outer loops. 4564a020018SMaheshRavishankar for (auto [outerLoop, innerLoop] : 4574a020018SMaheshRavishankar llvm::zip_equal(MutableArrayRef(loops).drop_back(), 4584a020018SMaheshRavishankar MutableArrayRef(loops).drop_front())) { 45976ead96cSMaheshRavishankar rewriter.setInsertionPointToEnd( 46076ead96cSMaheshRavishankar cast<scf::ForOp>(outerLoop.getOperation()).getBody()); 46176ead96cSMaheshRavishankar rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); 4624a020018SMaheshRavishankar } 46376ead96cSMaheshRavishankar return success(); 464cf6a7c19SMahesh Ravishankar } 46576ead96cSMaheshRavishankar 46676ead96cSMaheshRavishankar /// Generate the tile-loop nest using `scf.forall` operation. 46776ead96cSMaheshRavishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 46876ead96cSMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 46976ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop. 47076ead96cSMaheshRavishankar /// - `mappingVector` is the mapping attributes to use for loop construction. 47176ead96cSMaheshRavishankar /// Can be empty. 47276ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 47376ead96cSMaheshRavishankar /// most 47476ead96cSMaheshRavishankar /// loop. 47576ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are 47676ead96cSMaheshRavishankar /// populated. 47776ead96cSMaheshRavishankar static LogicalResult generateLoopNestUsingForallOp( 47876ead96cSMaheshRavishankar RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, 4796740d701SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, 4806740d701SMaheshRavishankar ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, 4816740d701SMaheshRavishankar YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 48276ead96cSMaheshRavishankar assert(!loopRanges.empty() && "unexpected empty loop ranges"); 48376ead96cSMaheshRavishankar assert(loopRanges.size() == tileSizes.size() && 48476ead96cSMaheshRavishankar "expected as many tile sizes as loop ranges"); 48576ead96cSMaheshRavishankar OpBuilder::InsertionGuard guard(rewriter); 48676ead96cSMaheshRavishankar SmallVector<OpFoldResult> offsets(loopRanges.size()), 48776ead96cSMaheshRavishankar sizes(loopRanges.size()); 48876ead96cSMaheshRavishankar 48976ead96cSMaheshRavishankar std::optional<ArrayAttr> mappingAttr; 49076ead96cSMaheshRavishankar if (!mappingVector.empty()) 49176ead96cSMaheshRavishankar mappingAttr = rewriter.getArrayAttr(mappingVector); 49276ead96cSMaheshRavishankar 4936740d701SMaheshRavishankar scf::ForallOp forallOp; 4946740d701SMaheshRavishankar bool useNumThreads = !numThreads.empty(); 4956740d701SMaheshRavishankar 4966740d701SMaheshRavishankar if (useNumThreads) { 4976740d701SMaheshRavishankar // Prune the zero numthreads. 4986740d701SMaheshRavishankar SmallVector<OpFoldResult> nonZeroNumThreads; 4996740d701SMaheshRavishankar for (auto nt : numThreads) { 5006740d701SMaheshRavishankar if (isConstantIntValue(nt, 0)) 5016740d701SMaheshRavishankar continue; 5026740d701SMaheshRavishankar nonZeroNumThreads.push_back(nt); 5036740d701SMaheshRavishankar } 5046740d701SMaheshRavishankar forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, 5056740d701SMaheshRavishankar destinationTensors, mappingAttr); 5066740d701SMaheshRavishankar } else { 5076740d701SMaheshRavishankar SmallVector<OpFoldResult> lbs, ubs, steps; 5086740d701SMaheshRavishankar std::tie(lbs, ubs, steps) = 5096740d701SMaheshRavishankar getLoopBounds(rewriter, loc, loopRanges, tileSizes); 5106740d701SMaheshRavishankar forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, 5116740d701SMaheshRavishankar destinationTensors, mappingAttr); 5126740d701SMaheshRavishankar } 51376ead96cSMaheshRavishankar loops.push_back(forallOp); 51476ead96cSMaheshRavishankar 51576ead96cSMaheshRavishankar rewriter.setInsertionPoint(forallOp.getTerminator()); 51676ead96cSMaheshRavishankar destinationTensors = forallOp.getRegionOutArgs(); 51776ead96cSMaheshRavishankar 51876ead96cSMaheshRavishankar SmallVector<Value> tiledResults; 51976ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 52076ead96cSMaheshRavishankar if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), 52176ead96cSMaheshRavishankar destinationTensors, tiledResults, resultOffsets, 52276ead96cSMaheshRavishankar resultSizes))) 52376ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); 52476ead96cSMaheshRavishankar 52576ead96cSMaheshRavishankar rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 52676ead96cSMaheshRavishankar for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : 52776ead96cSMaheshRavishankar llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, 52876ead96cSMaheshRavishankar resultSizes)) { 52976ead96cSMaheshRavishankar SmallVector<OpFoldResult> resultStride(resultOffset.size(), 53076ead96cSMaheshRavishankar rewriter.getIndexAttr(1)); 53176ead96cSMaheshRavishankar 53276ead96cSMaheshRavishankar rewriter.create<tensor::ParallelInsertSliceOp>( 53376ead96cSMaheshRavishankar loc, tiledValue, destinationTensor, resultOffset, resultSize, 53476ead96cSMaheshRavishankar resultStride); 53576ead96cSMaheshRavishankar } 53676ead96cSMaheshRavishankar return success(); 53776ead96cSMaheshRavishankar } 53876ead96cSMaheshRavishankar 53976ead96cSMaheshRavishankar /// Generate the tile-loop nest using the loop construct specifed in `options`. 54076ead96cSMaheshRavishankar /// - `options`: Tiling options specified. 54176ead96cSMaheshRavishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 54276ead96cSMaheshRavishankar /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. 54376ead96cSMaheshRavishankar /// - `destinationTensors` are the init values to use for the outer most loop. 54476ead96cSMaheshRavishankar /// - `yieldTiledValuesFn` is called to generated the loop body of the inner 54576ead96cSMaheshRavishankar /// most 54676ead96cSMaheshRavishankar /// loop. 54776ead96cSMaheshRavishankar /// - `loops` is an in-out parameter into which the generated loops are 54876ead96cSMaheshRavishankar /// populated. 5496740d701SMaheshRavishankar static LogicalResult generateLoopNest( 5506740d701SMaheshRavishankar RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, 5516740d701SMaheshRavishankar ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, 5526740d701SMaheshRavishankar ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, 5536740d701SMaheshRavishankar YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { 55476ead96cSMaheshRavishankar // If the tile sizes are all zero, no loops are generated. Just call the 55576ead96cSMaheshRavishankar // callback function to handle untiled case. 55676ead96cSMaheshRavishankar if (llvm::all_of(tileSizes, isZeroIndex)) { 55776ead96cSMaheshRavishankar SmallVector<Value> tiledResults; 55876ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 55976ead96cSMaheshRavishankar return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, 56076ead96cSMaheshRavishankar tiledResults, resultOffsets, resultSizes); 56176ead96cSMaheshRavishankar } 56276ead96cSMaheshRavishankar if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { 56376ead96cSMaheshRavishankar return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, 56476ead96cSMaheshRavishankar destinationTensors, tiledBodyFn, loops); 56576ead96cSMaheshRavishankar } 56676ead96cSMaheshRavishankar if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 56776ead96cSMaheshRavishankar return generateLoopNestUsingForallOp( 5686740d701SMaheshRavishankar rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, 56976ead96cSMaheshRavishankar destinationTensors, tiledBodyFn, loops); 57076ead96cSMaheshRavishankar } 57176ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loc, "unhandled loop type"); 57276ead96cSMaheshRavishankar } 57376ead96cSMaheshRavishankar 5744b563458SKunwar Grover static FailureOr<SmallVector<Value>> 5754b563458SKunwar Grover createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, 5764b563458SKunwar Grover ArrayRef<OpFoldResult> tileSizes, 5774b563458SKunwar Grover const scf::SCFTilingOptions &options) { 5784b563458SKunwar Grover SmallVector<Value> initTensors; 5794b563458SKunwar Grover Location loc = op->getLoc(); 5804b563458SKunwar Grover switch (options.reductionStrategy) { 5814b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 5824b563458SKunwar Grover if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors))) 5834b563458SKunwar Grover return failure(); 5844b563458SKunwar Grover return initTensors; 5854b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy:: 5864b563458SKunwar Grover PartialReductionOuterReduction: { 5874b563458SKunwar Grover auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 5884b563458SKunwar Grover if (!redOp) { 5894b563458SKunwar Grover return rewriter.notifyMatchFailure( 5904b563458SKunwar Grover op, "PartialReductionOuterReduction tiling strategy is only supported" 5914b563458SKunwar Grover "for operations implementing PartialReductionOpInterface"); 5924b563458SKunwar Grover } 5934b563458SKunwar Grover // Get reduction dimensions. 5944b563458SKunwar Grover // TODO: PartialReductionOpInterface should really query TilingInterface 5954b563458SKunwar Grover // itself and find reduction dimensions. 5964b563458SKunwar Grover SmallVector<int> reductionDims; 5974b563458SKunwar Grover for (auto [idx, iteratorType] : 5984b563458SKunwar Grover llvm::enumerate(op.getLoopIteratorTypes())) { 5994b563458SKunwar Grover if (iteratorType == utils::IteratorType::reduction) 6004b563458SKunwar Grover reductionDims.push_back(idx); 6014b563458SKunwar Grover } 6024b563458SKunwar Grover return redOp.generateInitialTensorForPartialReduction( 6034b563458SKunwar Grover rewriter, loc, tileSizes, reductionDims); 6044b563458SKunwar Grover } 6054b563458SKunwar Grover default: 6064b563458SKunwar Grover return rewriter.notifyMatchFailure(op, 6074b563458SKunwar Grover "unhandled reduction tiling strategy"); 6084b563458SKunwar Grover } 6094b563458SKunwar Grover } 6104b563458SKunwar Grover 6114b563458SKunwar Grover static FailureOr<TilingResult> 6124b563458SKunwar Grover getTiledImplementation(RewriterBase &rewriter, TilingInterface op, 6134b563458SKunwar Grover ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets, 6144b563458SKunwar Grover ArrayRef<OpFoldResult> sizes, 6154b563458SKunwar Grover const scf::SCFTilingOptions &options) { 6164b563458SKunwar Grover switch (options.reductionStrategy) { 6174b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 6184b563458SKunwar Grover return op.getTiledImplementation(rewriter, offsets, sizes); 6194b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy:: 6204b563458SKunwar Grover PartialReductionOuterReduction: { 6214b563458SKunwar Grover auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 6224b563458SKunwar Grover if (!redOp) { 6234b563458SKunwar Grover return rewriter.notifyMatchFailure( 6244b563458SKunwar Grover op, "PartialReductionOuterReduction tiling strategy is only " 6254b563458SKunwar Grover "supported for operations " 6264b563458SKunwar Grover "implementing PartialReductionOpInterface"); 6274b563458SKunwar Grover } 6284b563458SKunwar Grover // Get reduction dimensions. 6294b563458SKunwar Grover // TODO: PartialReductionOpInterface should really query TilingInterface 6304b563458SKunwar Grover // itself and find reduction dimensions. 6314b563458SKunwar Grover SmallVector<int> reductionDims; 6324b563458SKunwar Grover for (auto [idx, iteratorType] : 6334b563458SKunwar Grover llvm::enumerate(op.getLoopIteratorTypes())) { 6344b563458SKunwar Grover if (iteratorType == utils::IteratorType::reduction) 6354b563458SKunwar Grover reductionDims.push_back(idx); 6364b563458SKunwar Grover } 6374b563458SKunwar Grover return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg, 6384b563458SKunwar Grover offsets, sizes, reductionDims); 6394b563458SKunwar Grover } 6404b563458SKunwar Grover default: 6414b563458SKunwar Grover return rewriter.notifyMatchFailure(op, 6424b563458SKunwar Grover "unhandled reduction tiling strategy"); 6434b563458SKunwar Grover } 6444b563458SKunwar Grover } 6454b563458SKunwar Grover 6464b563458SKunwar Grover static LogicalResult 6474b563458SKunwar Grover getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, 6484b563458SKunwar Grover TilingInterface op, ArrayRef<OpFoldResult> offsets, 6494b563458SKunwar Grover ArrayRef<OpFoldResult> sizes, 6504b563458SKunwar Grover SmallVector<OpFoldResult> &resultOffset, 6514b563458SKunwar Grover SmallVector<OpFoldResult> &resultSize, 6524b563458SKunwar Grover const scf::SCFTilingOptions &options) { 6534b563458SKunwar Grover 6544b563458SKunwar Grover switch (options.reductionStrategy) { 6554b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 6564b563458SKunwar Grover return op.getResultTilePosition(rewriter, index, offsets, sizes, 6574b563458SKunwar Grover resultOffset, resultSize); 6584b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy:: 6594b563458SKunwar Grover PartialReductionOuterReduction: { 660*91bbebc7SKunwar Grover auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 661*91bbebc7SKunwar Grover if (!redOp) { 662*91bbebc7SKunwar Grover return rewriter.notifyMatchFailure( 663*91bbebc7SKunwar Grover op, "PartialReductionOuterReduction tiling strategy is only supported" 664*91bbebc7SKunwar Grover "for operations implementing PartialReductionOpInterface"); 6654b563458SKunwar Grover } 666*91bbebc7SKunwar Grover // Get reduction dimensions. 667*91bbebc7SKunwar Grover // TODO: PartialReductionOpInterface should really query TilingInterface 668*91bbebc7SKunwar Grover // itself and find reduction dimensions. 669*91bbebc7SKunwar Grover SmallVector<int> reductionDims; 670*91bbebc7SKunwar Grover for (auto [idx, iteratorType] : 671*91bbebc7SKunwar Grover llvm::enumerate(op.getLoopIteratorTypes())) { 672*91bbebc7SKunwar Grover if (iteratorType == utils::IteratorType::reduction) 673*91bbebc7SKunwar Grover reductionDims.push_back(idx); 674*91bbebc7SKunwar Grover } 675*91bbebc7SKunwar Grover return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, 676*91bbebc7SKunwar Grover resultOffset, resultSize, 677*91bbebc7SKunwar Grover reductionDims); 678*91bbebc7SKunwar Grover } 6794b563458SKunwar Grover default: 6804b563458SKunwar Grover return rewriter.notifyMatchFailure(op, 6814b563458SKunwar Grover "unhandled reduction tiling strategy"); 6824b563458SKunwar Grover } 6834b563458SKunwar Grover } 6844b563458SKunwar Grover 6854b563458SKunwar Grover static FailureOr<MergeResult> 6864b563458SKunwar Grover mergeTilingResults(RewriterBase &rewriter, TilingInterface op, 6874b563458SKunwar Grover ValueRange partialResults, 6884b563458SKunwar Grover const scf::SCFTilingOptions &options) { 6894b563458SKunwar Grover switch (options.reductionStrategy) { 6904b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: 6914b563458SKunwar Grover // No need to merge results for reduction tiling strategy. 6924b563458SKunwar Grover return MergeResult{{}, partialResults}; 6934b563458SKunwar Grover case scf::SCFTilingOptions::ReductionTilingStrategy:: 6944b563458SKunwar Grover PartialReductionOuterReduction: { 6954b563458SKunwar Grover auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); 6964b563458SKunwar Grover if (!redOp) { 6974b563458SKunwar Grover return rewriter.notifyMatchFailure( 6984b563458SKunwar Grover op, "PartialReductionOuterReduction tiling strategy is only " 6994b563458SKunwar Grover "supported for operations " 7004b563458SKunwar Grover "implementing PartialReductionOpInterface"); 7014b563458SKunwar Grover } 7024b563458SKunwar Grover // Get reduction dimensions. 7034b563458SKunwar Grover // TODO: PartialReductionOpInterface should really query TilingInterface 7044b563458SKunwar Grover // itself and find reduction dimensions. 7054b563458SKunwar Grover SmallVector<int> reductionDims; 7064b563458SKunwar Grover for (auto [idx, iteratorType] : 7074b563458SKunwar Grover llvm::enumerate(op.getLoopIteratorTypes())) { 7084b563458SKunwar Grover if (iteratorType == utils::IteratorType::reduction) 7094b563458SKunwar Grover reductionDims.push_back(idx); 7104b563458SKunwar Grover } 7114b563458SKunwar Grover return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, 7124b563458SKunwar Grover reductionDims); 7134b563458SKunwar Grover } 7144b563458SKunwar Grover default: 7154b563458SKunwar Grover return rewriter.notifyMatchFailure(op, 7164b563458SKunwar Grover "unhandled reduction tiling strategy"); 7174b563458SKunwar Grover } 7184b563458SKunwar Grover } 7194b563458SKunwar Grover 72076ead96cSMaheshRavishankar /// Append the specified additional `newInitOperands` operands to the 72176ead96cSMaheshRavishankar /// loops existing `init` operands (or similar), and replace `loopOp` with 72276ead96cSMaheshRavishankar /// the new loop that has the additional init operands. The loop body of 72376ead96cSMaheshRavishankar /// this loop is moved over to the new loop. `yieldTiledValuesFn` 72476ead96cSMaheshRavishankar /// is called to get the new tiled values returned, and the offset 72576ead96cSMaheshRavishankar /// and sizes at which the tiled value is inserted into the 72676ead96cSMaheshRavishankar /// new region iter_args that correspond to the newly added init operands. 72776ead96cSMaheshRavishankar template <typename LoopType> 72876ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> 72976ead96cSMaheshRavishankar yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, 73076ead96cSMaheshRavishankar ValueRange newInitOperands, 73176ead96cSMaheshRavishankar YieldTiledValuesFn yieldTiledValuesFn) { 73276ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 73376ead96cSMaheshRavishankar } 73476ead96cSMaheshRavishankar 73576ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. 73676ead96cSMaheshRavishankar template <> 73776ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( 73876ead96cSMaheshRavishankar scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 73976ead96cSMaheshRavishankar YieldTiledValuesFn yieldTiledValuesFn) { 74076ead96cSMaheshRavishankar OpBuilder::InsertionGuard g(rewriter); 74176ead96cSMaheshRavishankar Location loc = loopOp.getLoc(); 74276ead96cSMaheshRavishankar rewriter.setInsertionPoint(loopOp); 74376ead96cSMaheshRavishankar 74476ead96cSMaheshRavishankar auto inits = llvm::to_vector(loopOp.getInitArgs()); 74576ead96cSMaheshRavishankar inits.append(newInitOperands.begin(), newInitOperands.end()); 74676ead96cSMaheshRavishankar auto newLoop = rewriter.create<scf::ForOp>( 74776ead96cSMaheshRavishankar loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), 74876ead96cSMaheshRavishankar inits, [](OpBuilder &, Location, Value, ValueRange) {}); 74976ead96cSMaheshRavishankar 75076ead96cSMaheshRavishankar // Move the loop body to the new op. 75176ead96cSMaheshRavishankar Block *loopBody = loopOp.getBody(); 75276ead96cSMaheshRavishankar Block *newLoopBody = newLoop.getBody(); 75376ead96cSMaheshRavishankar rewriter.mergeBlocks( 75476ead96cSMaheshRavishankar loopBody, newLoopBody, 75576ead96cSMaheshRavishankar newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 75676ead96cSMaheshRavishankar 75776ead96cSMaheshRavishankar auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); 75876ead96cSMaheshRavishankar rewriter.setInsertionPoint(yieldOp); 75976ead96cSMaheshRavishankar 76076ead96cSMaheshRavishankar SmallVector<Value> tiledValues; 76176ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 76276ead96cSMaheshRavishankar ValueRange newRegionIterArgs = 76376ead96cSMaheshRavishankar newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 76476ead96cSMaheshRavishankar if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), 76576ead96cSMaheshRavishankar newRegionIterArgs, tiledValues, resultOffsets, 76676ead96cSMaheshRavishankar resultSizes))) { 76776ead96cSMaheshRavishankar rewriter.eraseOp(newLoop); 76876ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); 76976ead96cSMaheshRavishankar } 77076ead96cSMaheshRavishankar 77176ead96cSMaheshRavishankar SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); 77276ead96cSMaheshRavishankar for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : 77376ead96cSMaheshRavishankar llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, 77476ead96cSMaheshRavishankar resultSizes)) { 77576ead96cSMaheshRavishankar SmallVector<OpFoldResult> resultStride(resultOffset.size(), 77676ead96cSMaheshRavishankar rewriter.getIndexAttr(1)); 77776ead96cSMaheshRavishankar Value insert = rewriter.create<tensor::InsertSliceOp>( 77876ead96cSMaheshRavishankar yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, 77976ead96cSMaheshRavishankar resultStride); 78076ead96cSMaheshRavishankar newYieldValues.push_back(insert); 78176ead96cSMaheshRavishankar } 78276ead96cSMaheshRavishankar 78376ead96cSMaheshRavishankar rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 78476ead96cSMaheshRavishankar rewriter.replaceOp(loopOp, 78576ead96cSMaheshRavishankar newLoop->getResults().take_front(loopOp.getNumResults())); 78676ead96cSMaheshRavishankar return cast<LoopLikeOpInterface>(newLoop.getOperation()); 78776ead96cSMaheshRavishankar } 78876ead96cSMaheshRavishankar 78976ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` 79076ead96cSMaheshRavishankar template <> 79176ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( 79276ead96cSMaheshRavishankar scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, 79376ead96cSMaheshRavishankar YieldTiledValuesFn yieldTiledValuesFn) { 79476ead96cSMaheshRavishankar OpBuilder::InsertionGuard g(rewriter); 79576ead96cSMaheshRavishankar Location loc = loopOp.getLoc(); 79676ead96cSMaheshRavishankar rewriter.setInsertionPoint(loopOp); 79776ead96cSMaheshRavishankar auto inits = llvm::to_vector(loopOp.getOutputs()); 79876ead96cSMaheshRavishankar inits.append(newInitOperands.begin(), newInitOperands.end()); 79976ead96cSMaheshRavishankar auto newLoop = rewriter.create<scf::ForallOp>( 80076ead96cSMaheshRavishankar loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), 80176ead96cSMaheshRavishankar loopOp.getMixedStep(), inits, loopOp.getMapping(), 80276ead96cSMaheshRavishankar [](OpBuilder &, Location, ValueRange) {}); 80376ead96cSMaheshRavishankar 80476ead96cSMaheshRavishankar // Move the region of the current block to the newly created op. 80576ead96cSMaheshRavishankar Block *loopBody = loopOp.getBody(); 80676ead96cSMaheshRavishankar Block *newLoopBody = newLoop.getBody(); 80776ead96cSMaheshRavishankar rewriter.mergeBlocks( 80876ead96cSMaheshRavishankar loopBody, newLoopBody, 80976ead96cSMaheshRavishankar newLoopBody->getArguments().take_front(loopBody->getNumArguments())); 81076ead96cSMaheshRavishankar 81176ead96cSMaheshRavishankar auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); 81276ead96cSMaheshRavishankar rewriter.setInsertionPoint(terminator); 81376ead96cSMaheshRavishankar SmallVector<Value> tiledValues; 81476ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; 81576ead96cSMaheshRavishankar ValueRange regionIterArgs = 81676ead96cSMaheshRavishankar newLoop.getRegionIterArgs().take_back(newInitOperands.size()); 81776ead96cSMaheshRavishankar if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), 81876ead96cSMaheshRavishankar regionIterArgs, tiledValues, resultOffsets, 81976ead96cSMaheshRavishankar resultSizes))) { 82076ead96cSMaheshRavishankar rewriter.eraseOp(newLoop); 82176ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loopOp, 82276ead96cSMaheshRavishankar "failed to get yielded tiled values"); 82376ead96cSMaheshRavishankar } 82476ead96cSMaheshRavishankar 82576ead96cSMaheshRavishankar // Update the terminator. 82676ead96cSMaheshRavishankar rewriter.setInsertionPointToEnd(terminator.getBody()); 82776ead96cSMaheshRavishankar 82876ead96cSMaheshRavishankar for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( 82976ead96cSMaheshRavishankar tiledValues, regionIterArgs, resultOffsets, resultSizes)) { 83076ead96cSMaheshRavishankar SmallVector<OpFoldResult> resultStride(resultOffset.size(), 83176ead96cSMaheshRavishankar rewriter.getIndexAttr(1)); 83276ead96cSMaheshRavishankar rewriter.create<tensor::ParallelInsertSliceOp>( 83376ead96cSMaheshRavishankar terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, 83476ead96cSMaheshRavishankar resultStride); 83576ead96cSMaheshRavishankar } 83676ead96cSMaheshRavishankar 83776ead96cSMaheshRavishankar rewriter.replaceOp(loopOp, 83876ead96cSMaheshRavishankar newLoop->getResults().take_front(loopOp.getNumResults())); 83976ead96cSMaheshRavishankar return cast<LoopLikeOpInterface>(newLoop.getOperation()); 84076ead96cSMaheshRavishankar } 84176ead96cSMaheshRavishankar 84276ead96cSMaheshRavishankar /// Implementation of `yieldTiledValuesAndReplaceLoop` for 84376ead96cSMaheshRavishankar /// `LoopLikeOpInterface`, that just dispatches to the implementation for each 84476ead96cSMaheshRavishankar /// supported loop type. 84576ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( 84676ead96cSMaheshRavishankar LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, 84776ead96cSMaheshRavishankar ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { 84876ead96cSMaheshRavishankar return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( 84976ead96cSMaheshRavishankar loopLikeOp.getOperation()) 85076ead96cSMaheshRavishankar .Case<scf::ForOp, scf::ForallOp>( 85176ead96cSMaheshRavishankar [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 85276ead96cSMaheshRavishankar return yieldTiledValuesAndReplaceLoop( 85376ead96cSMaheshRavishankar loopOp, rewriter, newInitOperands, yieldTiledValuesFn); 85476ead96cSMaheshRavishankar }) 85576ead96cSMaheshRavishankar .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { 85676ead96cSMaheshRavishankar return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); 85776ead96cSMaheshRavishankar }); 858cf6a7c19SMahesh Ravishankar } 859cf6a7c19SMahesh Ravishankar 8604b563458SKunwar Grover /// Method to add new init values to a loop nest. Updates `loops` in-place 8614b563458SKunwar Grover /// with new loops that use the `newInitValues`. The outer-loops are updated 8624b563458SKunwar Grover /// to yield the new result values of the inner loop. For the innermost loop, 8634b563458SKunwar Grover /// the call back `getNewYields` is invoked to get the additional values to 8644b563458SKunwar Grover /// yield form the innermost loop. 86576ead96cSMaheshRavishankar static LogicalResult addInitOperandsToLoopNest( 86676ead96cSMaheshRavishankar RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, 86776ead96cSMaheshRavishankar ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { 8684a020018SMaheshRavishankar SmallVector<scf::ForOp> newLoops; 8694a020018SMaheshRavishankar if (loops.empty()) 87076ead96cSMaheshRavishankar return success(); 8714a020018SMaheshRavishankar OpBuilder::InsertionGuard g(rewriter); 8724a020018SMaheshRavishankar rewriter.setInsertionPoint(loops.front()); 87376ead96cSMaheshRavishankar 87476ead96cSMaheshRavishankar SmallVector<Value> ivs; 87576ead96cSMaheshRavishankar for (auto &loop : loops.drop_back()) { 8764a020018SMaheshRavishankar rewriter.setInsertionPoint(loop); 87797f91982SMahesh Ravishankar 87876ead96cSMaheshRavishankar // if loops.size() > 1 we assume that scf.for is used for the loops. 87976ead96cSMaheshRavishankar auto forLoop = cast<scf::ForOp>(loop.getOperation()); 88076ead96cSMaheshRavishankar 8814a020018SMaheshRavishankar // Create a new loop with the new init values for this loop. 88276ead96cSMaheshRavishankar SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); 8834a020018SMaheshRavishankar newInits.append(newInitValues.begin(), newInitValues.end()); 8844a020018SMaheshRavishankar auto newLoop = rewriter.create<scf::ForOp>( 88576ead96cSMaheshRavishankar forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), 88676ead96cSMaheshRavishankar forLoop.getStep(), newInits, 8874a020018SMaheshRavishankar [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); 8884a020018SMaheshRavishankar 8894a020018SMaheshRavishankar // Merge the body of the new loop with the body of the old loops. 8904a020018SMaheshRavishankar SmallVector<Value> sourceBlockArgs; 8914a020018SMaheshRavishankar sourceBlockArgs.push_back(newLoop.getInductionVar()); 8924a020018SMaheshRavishankar auto newRegionIterArgs = newLoop.getRegionIterArgs(); 8934a020018SMaheshRavishankar sourceBlockArgs.append( 8944a020018SMaheshRavishankar newRegionIterArgs.begin(), 89576ead96cSMaheshRavishankar std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); 89676ead96cSMaheshRavishankar rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); 89776ead96cSMaheshRavishankar rewriter.replaceOp( 89876ead96cSMaheshRavishankar forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); 8994a020018SMaheshRavishankar loop = newLoop; 90076ead96cSMaheshRavishankar ivs.push_back(newLoop.getInductionVar()); 9014a020018SMaheshRavishankar newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); 90297f91982SMahesh Ravishankar } 90397f91982SMahesh Ravishankar 9044a020018SMaheshRavishankar // Update the loop body of the innermost loop to get new yield values. 90576ead96cSMaheshRavishankar LoopLikeOpInterface innerMostLoop = loops.back(); 90676ead96cSMaheshRavishankar FailureOr<LoopLikeOpInterface> newInnerMostLoop = 90776ead96cSMaheshRavishankar yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, 90876ead96cSMaheshRavishankar getNewTiledYieldsFn); 90976ead96cSMaheshRavishankar 91076ead96cSMaheshRavishankar if (failed(newInnerMostLoop)) 91176ead96cSMaheshRavishankar return innerMostLoop.emitOpError("failed to return additional yields"); 91276ead96cSMaheshRavishankar loops.back() = newInnerMostLoop.value(); 9134a020018SMaheshRavishankar 9144a020018SMaheshRavishankar // Make all other loops except the innermost loops yield the values returned 9154a020018SMaheshRavishankar // by the inner loop. 9164a020018SMaheshRavishankar for (auto [outerLoop, innerLoop] : 9174a020018SMaheshRavishankar llvm::zip_equal(loops.drop_back(), loops.drop_front())) { 91876ead96cSMaheshRavishankar // Again assume that all the outer loops are scf.for operations. 91976ead96cSMaheshRavishankar auto outerForLoop = cast<scf::ForOp>(outerLoop); 9204a020018SMaheshRavishankar auto outerLoopYield = 92176ead96cSMaheshRavishankar cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); 9224a020018SMaheshRavishankar SmallVector<Value> newYields = 9234a020018SMaheshRavishankar llvm::to_vector(outerLoopYield.getOperands()); 9244a020018SMaheshRavishankar ValueRange additionalYields = 92576ead96cSMaheshRavishankar innerLoop->getResults().take_back(newInitValues.size()); 9264a020018SMaheshRavishankar newYields.append(additionalYields.begin(), additionalYields.end()); 9274a020018SMaheshRavishankar rewriter.setInsertionPoint(outerLoopYield); 9284a020018SMaheshRavishankar rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); 92994f2a6ddSMahesh Ravishankar } 93076ead96cSMaheshRavishankar return success(); 931809e3d8cSMahesh Ravishankar } 93294f2a6ddSMahesh Ravishankar 93397f91982SMahesh Ravishankar /// Implementation of tiling transformation of `op` that implements the 93497f91982SMahesh Ravishankar /// `TilingInterface` using `scf.for` to iterate over the tiles. 935cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> 93676ead96cSMaheshRavishankar mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, 9376d4baa74SMehdi Amini const scf::SCFTilingOptions &options) { 9386740d701SMaheshRavishankar if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { 9396740d701SMaheshRavishankar return failure(); 9406740d701SMaheshRavishankar } 9416740d701SMaheshRavishankar 942cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter); 943cf6a7c19SMahesh Ravishankar rewriter.setInsertionPointAfter(op); 944cf6a7c19SMahesh Ravishankar 945cf6a7c19SMahesh Ravishankar // 1. Get the range of the loops that are represented by the operation. 946cf6a7c19SMahesh Ravishankar SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 947aa2a96a2SMaheshRavishankar 9486740d701SMaheshRavishankar // 2. Materialize the tile sizes and/or number of threads; 9496740d701SMaheshRavishankar SmallVector<OpFoldResult> tileSizes, numThreads; 9506740d701SMaheshRavishankar std::tie(tileSizes, numThreads) = 9516740d701SMaheshRavishankar getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); 9526740d701SMaheshRavishankar 9536740d701SMaheshRavishankar // Check if it is safe to tile. This is hold over from previous iterations 9546740d701SMaheshRavishankar // of tile to for-all. Consider dropping it. 9556740d701SMaheshRavishankar if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { 9566740d701SMaheshRavishankar checkSafeToTileToForall(op, tileSizes, numThreads); 957cf6a7c19SMahesh Ravishankar } 958cf6a7c19SMahesh Ravishankar 95976ead96cSMaheshRavishankar // 3. If there is an interchange specified, permute the iteration domain and 960b8a1f00dSMahesh Ravishankar // the tile sizes. 96179150279SNicolas Vasilache SmallVector<int64_t> interchangeVector; 962b8a1f00dSMahesh Ravishankar if (!options.interchangeVector.empty()) { 963b8a1f00dSMahesh Ravishankar interchangeVector = fillInterchangeVector(options.interchangeVector, 964b8a1f00dSMahesh Ravishankar iterationDomain.size()); 9656740d701SMaheshRavishankar assert(isPermutationVector(interchangeVector) && 9666740d701SMaheshRavishankar "expected interchange vector to be a permutation"); 967b8a1f00dSMahesh Ravishankar 968b8a1f00dSMahesh Ravishankar applyPermutationToVector(iterationDomain, interchangeVector); 96976ead96cSMaheshRavishankar applyPermutationToVector(tileSizes, interchangeVector); 9706740d701SMaheshRavishankar if (!numThreads.empty()) 9716740d701SMaheshRavishankar applyPermutationToVector(numThreads, interchangeVector); 972b8a1f00dSMahesh Ravishankar } 973b8a1f00dSMahesh Ravishankar 97476ead96cSMaheshRavishankar FailureOr<TilingResult> tilingResult; 97576ead96cSMaheshRavishankar // 4. Define the lambda function used later to generate the body of the 97676ead96cSMaheshRavishankar // innermost tiled loop. 97776ead96cSMaheshRavishankar YieldTiledValuesFn innerYieldTiledValuesFn = 97876ead96cSMaheshRavishankar [&](RewriterBase &rewriter, Location loc, ValueRange ivs, 97976ead96cSMaheshRavishankar ValueRange regionIterArgs, SmallVector<Value> &tiledResults, 98076ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> &resultOffsets, 98176ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> &resultSizes) 98276ead96cSMaheshRavishankar -> LogicalResult { 98376ead96cSMaheshRavishankar // 4a. Compute the `offsets` and `sizes` to use for tiling. 98476ead96cSMaheshRavishankar SmallVector<OpFoldResult> offsets, sizes; 9856740d701SMaheshRavishankar std::tie(offsets, sizes) = getTileOffsetAndSizes( 9866740d701SMaheshRavishankar rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); 987cf6a7c19SMahesh Ravishankar 98876ead96cSMaheshRavishankar // 4b. If interchange was provided, apply inverse of the interchange 98976ead96cSMaheshRavishankar // to get back the offsets/sizes in the order to be specified. 990b8a1f00dSMahesh Ravishankar if (!interchangeVector.empty()) { 991b8a1f00dSMahesh Ravishankar auto inversePermutation = invertPermutationVector(interchangeVector); 992b1d3afc9SHanhan Wang applyPermutationToVector(offsets, inversePermutation); 993b1d3afc9SHanhan Wang applyPermutationToVector(sizes, inversePermutation); 994b8a1f00dSMahesh Ravishankar } 995cf6a7c19SMahesh Ravishankar 9964a020018SMaheshRavishankar // 5. Generate the tiled implementation within the inner most loop. 99797f91982SMahesh Ravishankar 9984a020018SMaheshRavishankar // 5a. Clone the operation within the loop body. 9994a020018SMaheshRavishankar auto clonedOp = cast<TilingInterface>( 100076ead96cSMaheshRavishankar cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); 10014a020018SMaheshRavishankar 10024b563458SKunwar Grover // 5b. Early return cloned op if tiling is not happening. We can not 10034b563458SKunwar Grover // return the original op because it could lead to `rewriter.replaceOp(op, 10044b563458SKunwar Grover // op->getResults())` and users would get crash. 100576ead96cSMaheshRavishankar if (llvm::all_of(tileSizes, isZeroIndex)) { 100676ead96cSMaheshRavishankar tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); 100776ead96cSMaheshRavishankar tilingResult = 1008d5f0969cSMaheshRavishankar TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), 1009d5f0969cSMaheshRavishankar /*generatedSlices=*/{}}; 101076ead96cSMaheshRavishankar return success(); 1011899c2bedSHan-Chung Wang } 1012899c2bedSHan-Chung Wang 1013899c2bedSHan-Chung Wang // 5c. Tile the cloned operation. 10144b563458SKunwar Grover tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, 10154b563458SKunwar Grover offsets, sizes, options); 101676ead96cSMaheshRavishankar if (failed(tilingResult)) { 101776ead96cSMaheshRavishankar rewriter.eraseOp(clonedOp); 101876ead96cSMaheshRavishankar return op.emitOpError("faild to tile operation"); 10194a020018SMaheshRavishankar } 10204a020018SMaheshRavishankar 1021899c2bedSHan-Chung Wang // 5d. Delete the cloned operation. 10224a020018SMaheshRavishankar rewriter.eraseOp(clonedOp); 10234a020018SMaheshRavishankar 102476ead96cSMaheshRavishankar // 5e. Compute the offsets at which the result values are to be inserted 102576ead96cSMaheshRavishankar // back into its destinations. 10264a020018SMaheshRavishankar for (auto [index, tiledValue] : 102776ead96cSMaheshRavishankar llvm::enumerate(tilingResult->tiledValues)) { 102876ead96cSMaheshRavishankar tiledResults.push_back(tiledValue); 102976ead96cSMaheshRavishankar SmallVector<OpFoldResult> resultOffset, resultSize; 10304b563458SKunwar Grover if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, 10314b563458SKunwar Grover sizes, resultOffset, resultSize, 10324b563458SKunwar Grover options))) { 103376ead96cSMaheshRavishankar for (auto op : tilingResult->tiledOps) { 103476ead96cSMaheshRavishankar rewriter.eraseOp(op); 103576ead96cSMaheshRavishankar } 103697f91982SMahesh Ravishankar return rewriter.notifyMatchFailure( 103797f91982SMahesh Ravishankar op, "failed to get slice of result produced"); 103897f91982SMahesh Ravishankar } 103976ead96cSMaheshRavishankar resultOffsets.emplace_back(std::move(resultOffset)); 104076ead96cSMaheshRavishankar resultSizes.emplace_back(std::move(resultSize)); 104197f91982SMahesh Ravishankar } 104276ead96cSMaheshRavishankar 104376ead96cSMaheshRavishankar return success(); 104476ead96cSMaheshRavishankar }; 104576ead96cSMaheshRavishankar 104676ead96cSMaheshRavishankar // 6. Find the destination tensors to use for the operation. 10474b563458SKunwar Grover FailureOr<SmallVector<Value>> maybeInits = 10484b563458SKunwar Grover createInitialTensorsForTiling(rewriter, op, tileSizes, options); 10494b563458SKunwar Grover if (failed(maybeInits)) { 10504b563458SKunwar Grover return rewriter.notifyMatchFailure( 10514b563458SKunwar Grover op, "unable to create initial tensors for tiling"); 105276ead96cSMaheshRavishankar } 10534b563458SKunwar Grover SmallVector<Value> &initTensors = maybeInits.value(); 105476ead96cSMaheshRavishankar 105576ead96cSMaheshRavishankar // 7. Generate the tiled loops nest using the callback defined above. 105676ead96cSMaheshRavishankar SmallVector<LoopLikeOpInterface> loops; 105776ead96cSMaheshRavishankar if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, 10584b563458SKunwar Grover tileSizes, numThreads, initTensors, 105976ead96cSMaheshRavishankar innerYieldTiledValuesFn, loops))) 106076ead96cSMaheshRavishankar return op.emitOpError("failed to generate tiling loops"); 106176ead96cSMaheshRavishankar assert(succeeded(tilingResult) && 106276ead96cSMaheshRavishankar "expected tiling result to be computed after loop generation"); 106376ead96cSMaheshRavishankar 10644b563458SKunwar Grover SmallVector<Value> partialResults; 106576ead96cSMaheshRavishankar if (loops.empty()) { 10664b563458SKunwar Grover // If loops are empty, the tiled op is used as the replacement for the 10674b563458SKunwar Grover // untiled op. 10684b563458SKunwar Grover partialResults = tilingResult->tiledValues; 10694b563458SKunwar Grover } else { 10704b563458SKunwar Grover partialResults = llvm::map_to_vector(loops.front()->getResults(), 10714b563458SKunwar Grover [](OpResult r) -> Value { return r; }); 10724b563458SKunwar Grover } 10734b563458SKunwar Grover 10744b563458SKunwar Grover FailureOr<MergeResult> mergeResult = 10754b563458SKunwar Grover mergeTilingResults(rewriter, op, partialResults, options); 10764b563458SKunwar Grover if (failed(mergeResult)) { 10774b563458SKunwar Grover return rewriter.notifyMatchFailure( 10784b563458SKunwar Grover op, "Failed to merge partial results from tiling"); 10794b563458SKunwar Grover } 10804b563458SKunwar Grover 10814b563458SKunwar Grover return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, 10824b563458SKunwar Grover mergeResult.value(), 1083d5f0969cSMaheshRavishankar tilingResult->generatedSlices}; 108476ead96cSMaheshRavishankar } 108597f91982SMahesh Ravishankar 10864b563458SKunwar Grover FailureOr<scf::SCFTilingResult> 10871cff4cbdSNicolas Vasilache mlir::scf::tileReductionUsingScf(RewriterBase &b, 10883310fe55SThomas Raoux PartialReductionOpInterface op, 1089170a25a7SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes) { 10904b563458SKunwar Grover SCFTilingOptions options; 10914b563458SKunwar Grover options.setLoopType(SCFTilingOptions::LoopType::ForOp); 10924b563458SKunwar Grover options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy:: 10934b563458SKunwar Grover PartialReductionOuterReduction); 10944b563458SKunwar Grover options.setTileSizes(tileSizes); 10952cc5f5d4SGroverkss 10964b563458SKunwar Grover TilingInterface tilingInterfaceOp = 10974b563458SKunwar Grover dyn_cast<TilingInterface>(op.getOperation()); 10984b563458SKunwar Grover if (!tilingInterfaceOp) { 10994b563458SKunwar Grover return b.notifyMatchFailure( 11004b563458SKunwar Grover op, 11014b563458SKunwar Grover "Operation implementing PartialReductionOpInterface should implement " 11024b563458SKunwar Grover "TilingInterface"); 11033310fe55SThomas Raoux } 1104faac8989SAlex Zinenko 11054b563458SKunwar Grover return tileUsingSCF(b, tilingInterfaceOp, options); 11063310fe55SThomas Raoux } 110793c42299SMaheshRavishankar 11082f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 110976ead96cSMaheshRavishankar // tileConsumerAndFuseProducersUsingSCF implementation. 11102f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 11112f637fe7SMahesh Ravishankar 11127ee34550SMahesh Ravishankar /// Return the untiled producer whose slice is used in a tiled consumer. The 11137ee34550SMahesh Ravishankar /// method traverses the tile loop nest (`loops`) if needed, and returns the 11144b563458SKunwar Grover /// `iter_args` of the outer most that is encountered. Traversing the 11154b563458SKunwar Grover /// iter_args indicates that this is a destination operand of the consumer. If 11164b563458SKunwar Grover /// there was no loop traversal needed, the second value of the returned tuple 11174b563458SKunwar Grover /// is empty. 111822426110SRamkumar Ramachandra static std::tuple<OpResult, std::optional<OpOperand *>> 11197ee34550SMahesh Ravishankar getUntiledProducerFromSliceSource(OpOperand *source, 112076ead96cSMaheshRavishankar ArrayRef<LoopLikeOpInterface> loops) { 112122426110SRamkumar Ramachandra std::optional<OpOperand *> destinationIterArg; 11227ee34550SMahesh Ravishankar auto loopIt = loops.rbegin(); 11235550c821STres Popp while (auto iterArg = dyn_cast<BlockArgument>(source->get())) { 112476ead96cSMaheshRavishankar auto loop = *loopIt; 11257ee34550SMahesh Ravishankar if (iterArg.getOwner()->getParentOp() != loop) 11267ee34550SMahesh Ravishankar break; 11273cd2a0bcSMatthias Springer source = loop.getTiedLoopInit(iterArg); 11287ee34550SMahesh Ravishankar loopIt++; 11292f637fe7SMahesh Ravishankar } 11307ee34550SMahesh Ravishankar if (loopIt == loops.rend()) 11317ee34550SMahesh Ravishankar destinationIterArg = source; 11325550c821STres Popp return {dyn_cast<OpResult>(source->get()), destinationIterArg}; 11332ed7c3fdSlorenzo chelini } 11342ed7c3fdSlorenzo chelini 11359db7d4edSMahesh Ravishankar /// Implementation of fusing producer of a single slice by computing the 11369db7d4edSMahesh Ravishankar /// slice of the producer in-place. 11379db7d4edSMahesh Ravishankar std::optional<scf::SCFFuseProducerOfSliceResult> 113876ead96cSMaheshRavishankar mlir::scf::tileAndFuseProducerOfSlice( 113976ead96cSMaheshRavishankar RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, 114076ead96cSMaheshRavishankar MutableArrayRef<LoopLikeOpInterface> loops) { 1141ce349ff1SMahesh Ravishankar // 1. Get the producer of the source (potentially walking through 1142ce349ff1SMahesh Ravishankar // `iter_args` of nested `scf.for`) 11436923a315SMatthias Springer auto [fusableProducer, destinationInitArg] = 11448823e961SMatthias Springer getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), 1145ce349ff1SMahesh Ravishankar loops); 1146ce349ff1SMahesh Ravishankar if (!fusableProducer) 1147ce349ff1SMahesh Ravishankar return std::nullopt; 11484a020018SMaheshRavishankar unsigned resultNumber = fusableProducer.getResultNumber(); 1149ce349ff1SMahesh Ravishankar 1150ce349ff1SMahesh Ravishankar OpBuilder::InsertionGuard g(rewriter); 1151ce349ff1SMahesh Ravishankar rewriter.setInsertionPoint(candidateSliceOp); 11524a020018SMaheshRavishankar 11534a020018SMaheshRavishankar // 2. Clone the fused producer 11544a020018SMaheshRavishankar // 2a. Compute the destination operands to use for the cloned operation. 11554a020018SMaheshRavishankar SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; 11564a020018SMaheshRavishankar Operation *fusableProducerOp = fusableProducer.getOwner(); 11574a020018SMaheshRavishankar if (isa<DestinationStyleOpInterface>(fusableProducerOp) && 11584a020018SMaheshRavishankar failed(tensor::getOrCreateDestinations( 11594a020018SMaheshRavishankar rewriter, fusableProducerOp->getLoc(), fusableProducerOp, 11604a020018SMaheshRavishankar origDestinationTensors))) 11614a020018SMaheshRavishankar return std::nullopt; 11624a020018SMaheshRavishankar 11634a020018SMaheshRavishankar clonedOpDestinationTensors = origDestinationTensors; 11644a020018SMaheshRavishankar if (destinationInitArg && 11654a020018SMaheshRavishankar isa<DestinationStyleOpInterface>(fusableProducerOp)) { 11664a020018SMaheshRavishankar // 2b. If the producer is also destination style, then to maintain the 11674a020018SMaheshRavishankar // destination passing style, update the destination of the producer to be 11684a020018SMaheshRavishankar // the source of the slice. 11694a020018SMaheshRavishankar clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); 11704a020018SMaheshRavishankar } 11714a020018SMaheshRavishankar // 2c. Clone the fused producer. 11724a020018SMaheshRavishankar Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( 11734a020018SMaheshRavishankar rewriter, fusableProducerOp, clonedOpDestinationTensors); 11744a020018SMaheshRavishankar // 2d. Update the source of the candidateSlice to be the cloned producer. 11754b563458SKunwar Grover // Easier to just clone the slice with different source since 11764b563458SKunwar Grover // replacements and DCE of cloned ops becomes easier 11774a020018SMaheshRavishankar SmallVector<Value> candidateSliceOpOperands = 11784a020018SMaheshRavishankar llvm::to_vector(candidateSliceOp->getOperands()); 11794a020018SMaheshRavishankar candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); 11804a020018SMaheshRavishankar tensor::ExtractSliceOp clonedCandidateSliceOp = 11814a020018SMaheshRavishankar mlir::clone(rewriter, candidateSliceOp, 11824a020018SMaheshRavishankar candidateSliceOp->getResultTypes(), candidateSliceOpOperands); 11834a020018SMaheshRavishankar 11844a020018SMaheshRavishankar // 3. Generate the tiled implementation of the producer of the source 1185809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tileAndFuseResult = 11864a020018SMaheshRavishankar tensor::replaceExtractSliceWithTiledProducer( 11874a020018SMaheshRavishankar rewriter, clonedCandidateSliceOp, 11884a020018SMaheshRavishankar clonedProducerOp->getResult(resultNumber)); 1189809e3d8cSMahesh Ravishankar if (failed(tileAndFuseResult)) 1190ce349ff1SMahesh Ravishankar return std::nullopt; 11914a020018SMaheshRavishankar // Note: Do not delete the candidateSliceOp, since its passed in from the 11924a020018SMaheshRavishankar // caller. 1193809e3d8cSMahesh Ravishankar rewriter.replaceAllUsesWith(candidateSliceOp, 1194809e3d8cSMahesh Ravishankar tileAndFuseResult->tiledValues[0]); 11954a020018SMaheshRavishankar rewriter.eraseOp(clonedCandidateSliceOp); 11964a020018SMaheshRavishankar rewriter.eraseOp(clonedProducerOp); 1197ce349ff1SMahesh Ravishankar 1198ce349ff1SMahesh Ravishankar // 3. If the slice is for a destination operand, for example, 1199ce349ff1SMahesh Ravishankar // 1200ce349ff1SMahesh Ravishankar // ```mlir 1201ce349ff1SMahesh Ravishankar // %0 = linalg.init 1202ce349ff1SMahesh Ravishankar // %1 = linalg.fill .. outs(%0 : ) 1203ce349ff1SMahesh Ravishankar // %2 = scf.for .. iter_args(%arg0 = %1) { 1204ce349ff1SMahesh Ravishankar // %3 = scf.for .. iter_args(%arg1 = %arg0) { 1205ce349ff1SMahesh Ravishankar // %4 = tensor.extract_slice %arg1 [..] 1206ce349ff1SMahesh Ravishankar // .. = linalg.matmul .. outs(%4 : ) 1207ce349ff1SMahesh Ravishankar // } 1208ce349ff1SMahesh Ravishankar // } 1209ce349ff1SMahesh Ravishankar // ``` 1210ce349ff1SMahesh Ravishankar // 1211ce349ff1SMahesh Ravishankar // the IR is currently 1212ce349ff1SMahesh Ravishankar // 1213ce349ff1SMahesh Ravishankar // ``` 1214ce349ff1SMahesh Ravishankar // %0 = linalg.init 1215ce349ff1SMahesh Ravishankar // %1 = linalg.fill 1216ce349ff1SMahesh Ravishankar // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { 1217ce349ff1SMahesh Ravishankar // %3 = scf.for .. iter_args(%arg1 = %arg0) { 12184a020018SMaheshRavishankar // %4 = tensor.extract_slice %arg1[..] 1219ce349ff1SMahesh Ravishankar // %5 = linalg.fill .. outs(%4 : ) 1220ce349ff1SMahesh Ravishankar // .. = linalg.matmul .. outs(%5 : ) 1221ce349ff1SMahesh Ravishankar // } 1222ce349ff1SMahesh Ravishankar // } 1223ce349ff1SMahesh Ravishankar // ``` 1224ce349ff1SMahesh Ravishankar // 1225ce349ff1SMahesh Ravishankar // The untiled `linalg.fill` is still used as the `init_value` since it 1226ce349ff1SMahesh Ravishankar // was originally a destination operand of the untiled `linalg.matmul`. 12274a020018SMaheshRavishankar // When fusing an operand that is a destination operand, the iter_arg of 12284a020018SMaheshRavishankar // the outer most loop should be changed to use the destination of the 12294a020018SMaheshRavishankar // fused operation. With this the IR will be. 1230ce349ff1SMahesh Ravishankar // 1231ce349ff1SMahesh Ravishankar // ``` 1232ce349ff1SMahesh Ravishankar // %0 = linalg.init 1233ce349ff1SMahesh Ravishankar // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { 1234ce349ff1SMahesh Ravishankar // %2 = scf.for .. iter_args(%arg1 = %arg0) { 12354a020018SMaheshRavishankar // %3 = tensor.extract_slice %arg1[..] 1236ce349ff1SMahesh Ravishankar // %4 = linalg.fill .. outs(%3 : ) 1237ce349ff1SMahesh Ravishankar // .. = linalg.matmul .. outs(%4 : ) 1238ce349ff1SMahesh Ravishankar // } 1239ce349ff1SMahesh Ravishankar // } 1240ce349ff1SMahesh Ravishankar // ``` 12416923a315SMatthias Springer if (destinationInitArg && 12424a020018SMaheshRavishankar isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { 12434a020018SMaheshRavishankar loops.front() 12444a020018SMaheshRavishankar ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] 12454a020018SMaheshRavishankar .set(origDestinationTensors[resultNumber]); 1246ce349ff1SMahesh Ravishankar } 1247d5f0969cSMaheshRavishankar return scf::SCFFuseProducerOfSliceResult{ 1248d5f0969cSMaheshRavishankar fusableProducer, tileAndFuseResult->tiledValues[0], 1249d5f0969cSMaheshRavishankar tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; 12509db7d4edSMahesh Ravishankar } 12519db7d4edSMahesh Ravishankar 12529db7d4edSMahesh Ravishankar /// Reconstruct the fused producer from within the tiled-and-fused code. 1253d5f0969cSMaheshRavishankar FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( 12549db7d4edSMahesh Ravishankar RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 12559db7d4edSMahesh Ravishankar scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 12567ef08eacSYun-Fly MutableArrayRef<LoopLikeOpInterface> loops, 12577ef08eacSYun-Fly ArrayRef<unsigned> yieldResultNumber) { 12584a020018SMaheshRavishankar if (loops.empty()) 125976ead96cSMaheshRavishankar return success(); 12604a020018SMaheshRavishankar 12617ef08eacSYun-Fly Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), 12627ef08eacSYun-Fly *tiledOwner = fusedProducerInfo.tiledOps[0]; 12637ef08eacSYun-Fly 12647ef08eacSYun-Fly Location loc = originalOwner->getLoc(); 12657ef08eacSYun-Fly // a. collect all init Value to be appended 12667ef08eacSYun-Fly SmallVector<unsigned> initNumberList = 12677ef08eacSYun-Fly yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>( 12687ef08eacSYun-Fly 0, originalOwner->getNumResults())) 12697ef08eacSYun-Fly : llvm::to_vector(yieldResultNumber); 12707ef08eacSYun-Fly SmallVector<Value> initValueList; 12717ef08eacSYun-Fly for (const auto &resultNumber : initNumberList) { 12729db7d4edSMahesh Ravishankar FailureOr<Value> initValue = tensor::getOrCreateDestination( 12737ef08eacSYun-Fly rewriter, loc, originalOwner->getResult(resultNumber)); 12749db7d4edSMahesh Ravishankar if (succeeded(initValue)) { 12757ef08eacSYun-Fly initValueList.push_back(initValue.value()); 12767ef08eacSYun-Fly } else { 12777ef08eacSYun-Fly return failure(); 12787ef08eacSYun-Fly } 12797ef08eacSYun-Fly } 12804a020018SMaheshRavishankar 1281d5f0969cSMaheshRavishankar SmallVector<Operation *> generatedSlices; 128276ead96cSMaheshRavishankar YieldTiledValuesFn newYieldValuesFn = 128376ead96cSMaheshRavishankar [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 128476ead96cSMaheshRavishankar ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 128576ead96cSMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 12867ef08eacSYun-Fly SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 12874a020018SMaheshRavishankar OpBuilder::InsertionGuard g(innerRewriter); 12887ef08eacSYun-Fly 12897ef08eacSYun-Fly // get sliceOp tile information 12907ef08eacSYun-Fly SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), 12917ef08eacSYun-Fly sliceSizes = sliceOp.getMixedSizes(); 12927ef08eacSYun-Fly 12937ef08eacSYun-Fly // expect all strides of sliceOp being 1 12947ef08eacSYun-Fly if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 12957ef08eacSYun-Fly return !isConstantIntValue(ofr, 1); 12967ef08eacSYun-Fly })) 12977ef08eacSYun-Fly return failure(); 12987ef08eacSYun-Fly 12997ef08eacSYun-Fly unsigned sliceResultNumber = 13007ef08eacSYun-Fly fusedProducerInfo.origProducer.getResultNumber(); 13017ef08eacSYun-Fly 13027ef08eacSYun-Fly auto tilableOp = cast<TilingInterface>(originalOwner); 13037ef08eacSYun-Fly // b. get iterDomain Offset and Sizes based on sliceOp tile 13047ef08eacSYun-Fly SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; 13057ef08eacSYun-Fly // skip tensor.pack/unpack/pad, which expects single opResult 13067ef08eacSYun-Fly if (tilableOp->getNumResults() > 1 && 13077ef08eacSYun-Fly failed(tilableOp.getIterationDomainTileFromResultTile( 13087ef08eacSYun-Fly rewriter, sliceResultNumber, sliceOffset, sliceSizes, 13097ef08eacSYun-Fly iterDomainOffset, iterDomainSizes))) { 13104b563458SKunwar Grover // In theory, it is unnecessary to raise an error here. Actually 13114b563458SKunwar Grover // although it fails to reconstruct the result tensor, it should not 13124b563458SKunwar Grover // broke current fusion anyway. The reason why we must return failure 13134b563458SKunwar Grover // currently is that the callback function `newYieldValuesFn` will be 13144b563458SKunwar Grover // called after new init operand(s) has already been appended. It will 13154b563458SKunwar Grover // take more refactoring to make sure the init operands are added 13164b563458SKunwar Grover // consistently in the future. For more details, please refer to: 13177ef08eacSYun-Fly // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 13187ef08eacSYun-Fly return failure(); 13197ef08eacSYun-Fly } 13207ef08eacSYun-Fly 13217ef08eacSYun-Fly // c. calculate offsets and sizes info of all OpResults respectively based 13227ef08eacSYun-Fly // on iteration Domain Tile 13237ef08eacSYun-Fly SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; 13247ef08eacSYun-Fly for (const auto &resultNumber : initNumberList) { 13257ef08eacSYun-Fly if (resultNumber == sliceResultNumber) { 13267ef08eacSYun-Fly offsetList.push_back(sliceOffset); 13277ef08eacSYun-Fly sizesList.push_back(sliceSizes); 13287ef08eacSYun-Fly } else { 13297ef08eacSYun-Fly assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); 13307ef08eacSYun-Fly // infer result tile according to the iteration domain tile 13317ef08eacSYun-Fly SmallVector<OpFoldResult> offset, sizes; 13327ef08eacSYun-Fly if (failed(tilableOp.getResultTilePosition( 13337ef08eacSYun-Fly rewriter, resultNumber, iterDomainOffset, iterDomainSizes, 13347ef08eacSYun-Fly offset, sizes))) { 13357ef08eacSYun-Fly return failure(); 13367ef08eacSYun-Fly } 13377ef08eacSYun-Fly offsetList.push_back(offset); 13387ef08eacSYun-Fly sizesList.push_back(sizes); 13397ef08eacSYun-Fly } 13407ef08eacSYun-Fly } 13417ef08eacSYun-Fly 13424b563458SKunwar Grover // d. create `extract_slice` for `iter_args` for DPS operation if 13434b563458SKunwar Grover // necessary 13444a020018SMaheshRavishankar if (auto tiledDestStyleOp = 13457ef08eacSYun-Fly dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { 13464a020018SMaheshRavishankar rewriter.setInsertionPoint(tiledDestStyleOp); 13477ef08eacSYun-Fly for (const auto &&[index, newRegionArg] : 13487ef08eacSYun-Fly llvm::enumerate(newRegionIterArgs)) { 13494a020018SMaheshRavishankar auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 13507ef08eacSYun-Fly loc, newRegionArg, offsetList[index], sizesList[index], 13517ef08eacSYun-Fly SmallVector<OpFoldResult>(offsetList[index].size(), 13527ef08eacSYun-Fly rewriter.getIndexAttr(1))); 1353d5f0969cSMaheshRavishankar generatedSlices.push_back(destSlice); 13547ef08eacSYun-Fly unsigned resultNumber = initNumberList[index]; 13555fcf907bSMatthias Springer rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 13564a020018SMaheshRavishankar tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); 13574a020018SMaheshRavishankar }); 1358ec1086f2SMaheshRavishankar } 13597ef08eacSYun-Fly } 13607ef08eacSYun-Fly 13617ef08eacSYun-Fly // e. prepare tiled offset and sizes for later `insert_slice` creation by 13627ef08eacSYun-Fly // caller 13634a020018SMaheshRavishankar Block *block = rewriter.getInsertionPoint()->getBlock(); 13644a020018SMaheshRavishankar rewriter.setInsertionPoint(block->getTerminator()); 13657ef08eacSYun-Fly for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { 13667ef08eacSYun-Fly tiledResult.push_back(tiledOwner->getResult(resultNumber)); 13677ef08eacSYun-Fly tiledOffset.emplace_back(offsetList[index]); 13687ef08eacSYun-Fly tiledSizes.emplace_back(sizesList[index]); 13697ef08eacSYun-Fly } 137076ead96cSMaheshRavishankar return success(); 13714a020018SMaheshRavishankar }; 13724a020018SMaheshRavishankar 1373d5f0969cSMaheshRavishankar if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, 1374d5f0969cSMaheshRavishankar newYieldValuesFn))) { 1375d5f0969cSMaheshRavishankar return failure(); 1376d5f0969cSMaheshRavishankar } 1377d5f0969cSMaheshRavishankar return generatedSlices; 13789db7d4edSMahesh Ravishankar } 1379ce349ff1SMahesh Ravishankar 13809144fed3SQuinn Dawkins namespace { 13819144fed3SQuinn Dawkins 13829144fed3SQuinn Dawkins //===----------------------------------------------------------------------===// 13839144fed3SQuinn Dawkins // SliceTrackingListener 13849144fed3SQuinn Dawkins //===----------------------------------------------------------------------===// 13859144fed3SQuinn Dawkins 13869144fed3SQuinn Dawkins /// This class is a listener for tracking the insertion and removal of 13879144fed3SQuinn Dawkins /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy 13889144fed3SQuinn Dawkins /// fusion algorithm to apply cleanup patterns in between fusion steps. 13899144fed3SQuinn Dawkins class SliceTrackingListener : public RewriterBase::Listener { 13909144fed3SQuinn Dawkins public: 13919144fed3SQuinn Dawkins explicit SliceTrackingListener( 13929144fed3SQuinn Dawkins std::optional<FrozenRewritePatternSet> patterns); 13939144fed3SQuinn Dawkins SliceTrackingListener() = default; 13949144fed3SQuinn Dawkins 13954b563458SKunwar Grover /// Adds the given list of operations to the worklist, and if present, 13964b563458SKunwar Grover /// applies the list of `patterns` to the newly added operations. This only 13974b563458SKunwar Grover /// processes the given operations and any newly inserted ones by the 13984b563458SKunwar Grover /// pattern set. 13999144fed3SQuinn Dawkins LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); 14009144fed3SQuinn Dawkins 14019144fed3SQuinn Dawkins /// Add to the new operation worklist if it is an extract_slice. 14029144fed3SQuinn Dawkins void notifyOperationInserted(Operation *op, 14039144fed3SQuinn Dawkins OpBuilder::InsertPoint previous) override; 14049144fed3SQuinn Dawkins 14059144fed3SQuinn Dawkins /// Shared helper for operation removal from the worklist. 14069144fed3SQuinn Dawkins void removeOp(Operation *op); 14079144fed3SQuinn Dawkins 14089144fed3SQuinn Dawkins /// Remove the operation from the worklist. 14099144fed3SQuinn Dawkins void notifyOperationErased(Operation *op) override; 14109144fed3SQuinn Dawkins 14119144fed3SQuinn Dawkins /// Remove the operation from the worklist. 14129144fed3SQuinn Dawkins void notifyOperationReplaced(Operation *op, ValueRange replacement) override; 14139144fed3SQuinn Dawkins 14149144fed3SQuinn Dawkins /// The worklist for this transformation keeps track of the slices to visit 14159144fed3SQuinn Dawkins /// next for fusion. 14169144fed3SQuinn Dawkins std::deque<tensor::ExtractSliceOp> worklist; 14179144fed3SQuinn Dawkins 14189144fed3SQuinn Dawkins private: 14194b563458SKunwar Grover /// Optional pattern set to apply when adding new operations to the 14204b563458SKunwar Grover /// worklist. 14219144fed3SQuinn Dawkins std::optional<FrozenRewritePatternSet> patterns = std::nullopt; 14229144fed3SQuinn Dawkins }; 14239144fed3SQuinn Dawkins 14249144fed3SQuinn Dawkins SliceTrackingListener::SliceTrackingListener( 14259144fed3SQuinn Dawkins std::optional<FrozenRewritePatternSet> p) { 14269144fed3SQuinn Dawkins patterns = std::move(p); 14279144fed3SQuinn Dawkins } 14289144fed3SQuinn Dawkins 14299144fed3SQuinn Dawkins LogicalResult 14309144fed3SQuinn Dawkins SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { 14319144fed3SQuinn Dawkins for (Operation *op : ops) { 14329144fed3SQuinn Dawkins if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) 14339144fed3SQuinn Dawkins worklist.push_back(slice); 14349144fed3SQuinn Dawkins } 14359144fed3SQuinn Dawkins 14369144fed3SQuinn Dawkins if (!patterns) 14379144fed3SQuinn Dawkins return success(); 14389144fed3SQuinn Dawkins 14399144fed3SQuinn Dawkins GreedyRewriteConfig config; 14409144fed3SQuinn Dawkins config.listener = this; 14419144fed3SQuinn Dawkins config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 144209dfc571SJacques Pienaar return applyOpPatternsGreedily(ops, patterns.value(), config); 14439144fed3SQuinn Dawkins } 14449144fed3SQuinn Dawkins 14459144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationInserted( 14469144fed3SQuinn Dawkins Operation *op, OpBuilder::InsertPoint previous) { 14479144fed3SQuinn Dawkins auto slice = dyn_cast<tensor::ExtractSliceOp>(op); 14489144fed3SQuinn Dawkins if (!slice) 14499144fed3SQuinn Dawkins return; 14509144fed3SQuinn Dawkins worklist.push_back(slice); 14519144fed3SQuinn Dawkins } 14529144fed3SQuinn Dawkins 14534b563458SKunwar Grover // Scan the worklist for the given op and remove it if present. The 14544b563458SKunwar Grover // expectation is for the worklist to be small and for removal to be 14554b563458SKunwar Grover // relatively rare. 14569144fed3SQuinn Dawkins void SliceTrackingListener::removeOp(Operation *op) { 14579144fed3SQuinn Dawkins if (!isa<tensor::ExtractSliceOp>(op)) 14589144fed3SQuinn Dawkins return; 14599144fed3SQuinn Dawkins auto iter = worklist.begin(); 14609144fed3SQuinn Dawkins while (iter != worklist.end()) { 14619144fed3SQuinn Dawkins if (*iter == op) 14629144fed3SQuinn Dawkins break; 14639144fed3SQuinn Dawkins iter++; 14649144fed3SQuinn Dawkins } 14659144fed3SQuinn Dawkins if (iter == worklist.end()) 14669144fed3SQuinn Dawkins return; 14679144fed3SQuinn Dawkins 14689144fed3SQuinn Dawkins worklist.erase(iter); 14699144fed3SQuinn Dawkins } 14709144fed3SQuinn Dawkins 14719144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationErased(Operation *op) { 14729144fed3SQuinn Dawkins removeOp(op); 14739144fed3SQuinn Dawkins } 14749144fed3SQuinn Dawkins 14759144fed3SQuinn Dawkins void SliceTrackingListener::notifyOperationReplaced(Operation *op, 14769144fed3SQuinn Dawkins ValueRange replacement) { 14779144fed3SQuinn Dawkins removeOp(op); 14789144fed3SQuinn Dawkins } 14796e3631d0SKunwar Grover 14806e3631d0SKunwar Grover //===----------------------------------------------------------------------===// 14816e3631d0SKunwar Grover // ReplacementListener 14826e3631d0SKunwar Grover //===----------------------------------------------------------------------===// 14836e3631d0SKunwar Grover 14846e3631d0SKunwar Grover /// Listener that tracks updates replacements for values which can be mutated. 14856e3631d0SKunwar Grover /// This listener runs on top of the existing listener for the rewriter, 14866e3631d0SKunwar Grover /// to make sure external users can still run listeners. 14876e3631d0SKunwar Grover class ReplacementListener : public RewriterBase::ForwardingListener { 14886e3631d0SKunwar Grover public: 14896e3631d0SKunwar Grover ReplacementListener(DenseMap<Value, Value> &replacements, 14906e3631d0SKunwar Grover OpBuilder::Listener *listener) 14916e3631d0SKunwar Grover : ForwardingListener(listener), replacements(replacements) {} 14926e3631d0SKunwar Grover 14936e3631d0SKunwar Grover void updateReplacementValues(ValueRange origValues, 14946e3631d0SKunwar Grover ValueRange replaceValues) { 14956e3631d0SKunwar Grover // This can probably be written better, but just iterates over the map 14966e3631d0SKunwar Grover // and the new replacements for now. 14976e3631d0SKunwar Grover for (auto &[key, val] : replacements) { 14986e3631d0SKunwar Grover for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) { 14996e3631d0SKunwar Grover if (val == orig) { 15006e3631d0SKunwar Grover val = replace; 15016e3631d0SKunwar Grover } 15026e3631d0SKunwar Grover } 15036e3631d0SKunwar Grover } 15046e3631d0SKunwar Grover } 15056e3631d0SKunwar Grover 15066e3631d0SKunwar Grover void notifyOperationReplaced(Operation *op, Operation *newOp) override { 15076e3631d0SKunwar Grover ForwardingListener::notifyOperationReplaced(op, newOp); 15086e3631d0SKunwar Grover updateReplacementValues(op->getResults(), newOp->getResults()); 15096e3631d0SKunwar Grover } 15106e3631d0SKunwar Grover 15116e3631d0SKunwar Grover void notifyOperationReplaced(Operation *op, ValueRange values) override { 15126e3631d0SKunwar Grover ForwardingListener::notifyOperationReplaced(op, values); 15136e3631d0SKunwar Grover updateReplacementValues(op->getResults(), values); 15146e3631d0SKunwar Grover } 15156e3631d0SKunwar Grover 15166e3631d0SKunwar Grover private: 15176e3631d0SKunwar Grover DenseMap<Value, Value> &replacements; 15186e3631d0SKunwar Grover }; 15196e3631d0SKunwar Grover 15209144fed3SQuinn Dawkins } // namespace 15219144fed3SQuinn Dawkins 152297f91982SMahesh Ravishankar /// Implementation of tile consumer and fuse producer greedily. 15232f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult> 152476ead96cSMaheshRavishankar mlir::scf::tileConsumerAndFuseProducersUsingSCF( 152597f91982SMahesh Ravishankar RewriterBase &rewriter, TilingInterface consumer, 15266d4baa74SMehdi Amini const scf::SCFTileAndFuseOptions &options) { 15272f637fe7SMahesh Ravishankar // This transformation is only valid for ops that return values (i.e. not 15282f637fe7SMahesh Ravishankar // valid to use with operations that have memref operands). 152997f91982SMahesh Ravishankar if (!consumer->getNumResults()) { 15302f637fe7SMahesh Ravishankar return rewriter.notifyMatchFailure( 153197f91982SMahesh Ravishankar consumer, "invalid pattern for op with no results"); 15322f637fe7SMahesh Ravishankar } 15332f637fe7SMahesh Ravishankar 15342f637fe7SMahesh Ravishankar // 1. First tile the consumer. 153593c42299SMaheshRavishankar SetVector<Operation *> fusedProducers, tiledAndFusedOps; 15364435ced9SMaheshRavishankar llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; 153776ead96cSMaheshRavishankar 153897f91982SMahesh Ravishankar FailureOr<scf::SCFTilingResult> tilingResult = 153976ead96cSMaheshRavishankar tileUsingSCF(rewriter, consumer, options.tilingOptions); 154076ead96cSMaheshRavishankar 154197f91982SMahesh Ravishankar if (failed(tilingResult)) 154297f91982SMahesh Ravishankar return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); 1543fbfca43eSMehdi Amini for (auto *tiledOp : tilingResult->tiledOps) 154493c42299SMaheshRavishankar tiledAndFusedOps.insert(tiledOp); 154597f91982SMahesh Ravishankar 15464435ced9SMaheshRavishankar DenseMap<Value, Value> replacements; 15474b563458SKunwar Grover for (auto [origVal, replacement] : llvm::zip_equal( 15484b563458SKunwar Grover consumer->getResults(), tilingResult->mergeResult.replacements)) { 15494435ced9SMaheshRavishankar replacements[origVal] = replacement; 15504435ced9SMaheshRavishankar } 15516e3631d0SKunwar Grover 15526e3631d0SKunwar Grover // If there are no loops generated, fusion is immaterial. 15536e3631d0SKunwar Grover auto &loops = tilingResult->loops; 15546e3631d0SKunwar Grover if (loops.empty()) { 155576ead96cSMaheshRavishankar return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 155676ead96cSMaheshRavishankar replacements}; 155793c42299SMaheshRavishankar } 15582f637fe7SMahesh Ravishankar 15596e3631d0SKunwar Grover // Since the loop gets potentially replaced during fusion, we need to track 15606e3631d0SKunwar Grover // the mutation of replacement values. To do this, we attach a listener to 15616e3631d0SKunwar Grover // update the replacements as they happen. 15626e3631d0SKunwar Grover OpBuilder::Listener *previousListener = rewriter.getListener(); 15636e3631d0SKunwar Grover auto resetListener = 15646e3631d0SKunwar Grover llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); 15656e3631d0SKunwar Grover ReplacementListener replaceListener(replacements, previousListener); 15666e3631d0SKunwar Grover rewriter.setListener(&replaceListener); 15674435ced9SMaheshRavishankar 15682f637fe7SMahesh Ravishankar // 2. Typically, the operands of the tiled operation are slices of the 15692f637fe7SMahesh Ravishankar // operands of the untiled operation. These are expressed in IR using 15704b563458SKunwar Grover // `tensor.extract_slice` operations with source being the operands of 15714b563458SKunwar Grover // the untiled operation. Create a worklist of these 15724b563458SKunwar Grover // `tensor.extract_slice` operations. If the producers of the source of 15734b563458SKunwar Grover // the `tensor.extract_slice` can be tiled such that the tiled value is 15744b563458SKunwar Grover // generated in-place, that effectively tiles + fuses the operations. 1575d5f0969cSMaheshRavishankar struct WorklistItem { 1576d5f0969cSMaheshRavishankar tensor::ExtractSliceOp candidateSlice; 1577d5f0969cSMaheshRavishankar SCFTileAndFuseOptions::ControlFnResult controlFnResult; 15782f637fe7SMahesh Ravishankar }; 15799144fed3SQuinn Dawkins 15809144fed3SQuinn Dawkins SliceTrackingListener sliceTracker = 15819144fed3SQuinn Dawkins SliceTrackingListener(options.cleanupPatterns); 15829144fed3SQuinn Dawkins 15839144fed3SQuinn Dawkins if (failed( 15849144fed3SQuinn Dawkins sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { 15859144fed3SQuinn Dawkins return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 15869144fed3SQuinn Dawkins } 15879144fed3SQuinn Dawkins OpBuilder::InsertionGuard g(rewriter); 15889144fed3SQuinn Dawkins while (!sliceTracker.worklist.empty()) { 15899144fed3SQuinn Dawkins auto candidateSlice = sliceTracker.worklist.front(); 15909144fed3SQuinn Dawkins sliceTracker.worklist.pop_front(); 15912f637fe7SMahesh Ravishankar 15924435ced9SMaheshRavishankar auto [fusableProducer, destinationInitArg] = 15939144fed3SQuinn Dawkins getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), 15949144fed3SQuinn Dawkins loops); 15954435ced9SMaheshRavishankar if (!fusableProducer) 15964435ced9SMaheshRavishankar continue; 15979144fed3SQuinn Dawkins 1598d5f0969cSMaheshRavishankar std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = 15999144fed3SQuinn Dawkins options.fusionControlFn(candidateSlice, fusableProducer, 1600d5f0969cSMaheshRavishankar destinationInitArg.has_value()); 1601d5f0969cSMaheshRavishankar if (!controlFnResult) 16024435ced9SMaheshRavishankar continue; 1603d5f0969cSMaheshRavishankar 16049144fed3SQuinn Dawkins WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; 16054435ced9SMaheshRavishankar 1606ce349ff1SMahesh Ravishankar // The operands of the fused producer might themselved be slices of 16072f637fe7SMahesh Ravishankar // values produced by operations that implement the `TilingInterface`. 16082f637fe7SMahesh Ravishankar // Add these operations to the worklist. 160993c42299SMaheshRavishankar std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = 1610d5f0969cSMaheshRavishankar tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, 1611d5f0969cSMaheshRavishankar loops); 161293c42299SMaheshRavishankar if (!fusedResult) 1613ce349ff1SMahesh Ravishankar continue; 16142f637fe7SMahesh Ravishankar 16159144fed3SQuinn Dawkins SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; 16169144fed3SQuinn Dawkins 1617d5f0969cSMaheshRavishankar if (worklistItem.controlFnResult.yieldProducerReplacement) { 16184b563458SKunwar Grover // Reconstruct and yield all opResult of fusableProducerOp by default. 16194b563458SKunwar Grover // The caller can specific which one to yield by designating optional 16204b563458SKunwar Grover // argument named `yieldResultNumber` of 16214b563458SKunwar Grover // `yieldReplacementForFusedProducer`. 1622d5f0969cSMaheshRavishankar Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); 1623d5f0969cSMaheshRavishankar FailureOr<SmallVector<Operation *>> newSlices = 1624d5f0969cSMaheshRavishankar yieldReplacementForFusedProducer(rewriter, 1625d5f0969cSMaheshRavishankar worklistItem.candidateSlice, 1626d5f0969cSMaheshRavishankar fusedResult.value(), loops); 1627d5f0969cSMaheshRavishankar if (failed(newSlices)) { 162876ead96cSMaheshRavishankar return rewriter.notifyMatchFailure( 16297ef08eacSYun-Fly fusableProducerOp, "failed to replacement value for this " 16307ef08eacSYun-Fly "operation from within the tiled loop"); 163176ead96cSMaheshRavishankar } 16329144fed3SQuinn Dawkins worklistCandidates.append(newSlices.value()); 16337ef08eacSYun-Fly for (auto [index, result] : 16347ef08eacSYun-Fly llvm::enumerate(fusableProducerOp->getResults())) { 16356e3631d0SKunwar Grover replacements[result] = loops.front()->getResult( 16366e3631d0SKunwar Grover loops.front()->getNumResults() - 16376e3631d0SKunwar Grover fusableProducerOp->getNumResults() + index); 16387ef08eacSYun-Fly } 16394435ced9SMaheshRavishankar } 16409db7d4edSMahesh Ravishankar if (Operation *tiledAndFusedOp = 164193c42299SMaheshRavishankar fusedResult->tiledAndFusedProducer.getDefiningOp()) { 164293c42299SMaheshRavishankar fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); 164393c42299SMaheshRavishankar tiledAndFusedOps.insert(tiledAndFusedOp); 16449db7d4edSMahesh Ravishankar } 16459144fed3SQuinn Dawkins 16469144fed3SQuinn Dawkins if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { 16479144fed3SQuinn Dawkins return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); 16489144fed3SQuinn Dawkins } 16492f637fe7SMahesh Ravishankar } 16504435ced9SMaheshRavishankar 165176ead96cSMaheshRavishankar return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, 165276ead96cSMaheshRavishankar replacements}; 1653d871daeaSMaheshRavishankar } 1654d871daeaSMaheshRavishankar 1655d871daeaSMaheshRavishankar //===----------------------------------------------------------------------===// 16562b2ce50fSAbhishek Varma // tileAndFuseConsumerUsingSCF implementation. 16572b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===// 16582b2ce50fSAbhishek Varma 16592b2ce50fSAbhishek Varma /// A utility function that checks whether the only use of the result of a 16602b2ce50fSAbhishek Varma /// tensor.insert_slice op is in a scf.yield op. 16612b2ce50fSAbhishek Varma static LogicalResult 16622b2ce50fSAbhishek Varma checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { 16632b2ce50fSAbhishek Varma Value result = candidateSliceOp.getResult(); 16642b2ce50fSAbhishek Varma Value::use_range uses = result.getUses(); 16652b2ce50fSAbhishek Varma if (!llvm::hasSingleElement(uses)) { 16662b2ce50fSAbhishek Varma LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); 16672b2ce50fSAbhishek Varma return failure(); 16682b2ce50fSAbhishek Varma } 16692b2ce50fSAbhishek Varma OpOperand &operandUse = (*uses.begin()); 16702b2ce50fSAbhishek Varma Operation *userOp = operandUse.getOwner(); 16712b2ce50fSAbhishek Varma if (!isa<scf::YieldOp>(userOp)) { 16722b2ce50fSAbhishek Varma LLVM_DEBUG(llvm::dbgs() 16732b2ce50fSAbhishek Varma << "Expected scf.yield to be the only user, but got -> " 16742b2ce50fSAbhishek Varma << (*userOp)); 16752b2ce50fSAbhishek Varma return failure(); 16762b2ce50fSAbhishek Varma } 16772b2ce50fSAbhishek Varma if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { 16782b2ce50fSAbhishek Varma LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " 16792b2ce50fSAbhishek Varma "be in the same block\n"); 16802b2ce50fSAbhishek Varma return failure(); 16812b2ce50fSAbhishek Varma } 16822b2ce50fSAbhishek Varma return success(); 16832b2ce50fSAbhishek Varma } 16842b2ce50fSAbhishek Varma 16854b563458SKunwar Grover /// An utility to get the first user of the given loopOp. If any of user stay 16864b563458SKunwar Grover /// in different block of loopOp, return failure. 16879bc3102bSYun-Fly static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) { 16889bc3102bSYun-Fly if (!isa<LoopLikeOpInterface>(loopOp)) 16892b2ce50fSAbhishek Varma return failure(); 16909bc3102bSYun-Fly Operation *firstUserOfLoop = nullptr; 16919bc3102bSYun-Fly for (Operation *userOp : loopOp->getUsers()) { 16929bc3102bSYun-Fly // `ParallelInsertSlice` located inside `InParallelOp` has no same parent 16939bc3102bSYun-Fly // block with any other types of operation. Thus, just redirecting to its 16949bc3102bSYun-Fly // parent `InParallelOp`. E.g. 16959bc3102bSYun-Fly // 16969bc3102bSYun-Fly // ``` 16979bc3102bSYun-Fly // %1 = scf.for { 16989bc3102bSYun-Fly // ... 16999bc3102bSYun-Fly // } 17009bc3102bSYun-Fly // %2 = consumerOp ins(%1, ...) 17019bc3102bSYun-Fly // scf.forall.in_parallel { 17029bc3102bSYun-Fly // tensor.parallel_insert_slice %1 17039bc3102bSYun-Fly // } 17049bc3102bSYun-Fly // ``` 17059bc3102bSYun-Fly // where `InParallelOp` but not `ParallelInsertSlice` stays in the same 17069bc3102bSYun-Fly // same block with `consumerOp`. 17079bc3102bSYun-Fly if (isa<tensor::ParallelInsertSliceOp>(userOp)) 17089bc3102bSYun-Fly userOp = userOp->getParentOfType<scf::InParallelOp>(); 17099bc3102bSYun-Fly 17109bc3102bSYun-Fly if (loopOp->getBlock() != userOp->getBlock()) 17112b2ce50fSAbhishek Varma return failure(); 17129bc3102bSYun-Fly 17139bc3102bSYun-Fly if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop)) 17149bc3102bSYun-Fly firstUserOfLoop = userOp; 17159bc3102bSYun-Fly } 17169bc3102bSYun-Fly return firstUserOfLoop; 1717b8c974f0SAbhishek Varma } 1718b8c974f0SAbhishek Varma 17194b563458SKunwar Grover /// This utility currently checks whether the first userOp of loop is NOT 17204b563458SKunwar Grover /// before the last defineOp of consumer operand. Because that we need to move 17214b563458SKunwar Grover /// the whole loop structure right before the `firstUserOfLoop`. This utility 17224b563458SKunwar Grover /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice 17234b563458SKunwar Grover /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that: 17249bc3102bSYun-Fly /// 17259bc3102bSYun-Fly /// ``` 17269bc3102bSYun-Fly /// %0 = scf.for() { 17279bc3102bSYun-Fly /// ... 17289bc3102bSYun-Fly /// } 17299bc3102bSYun-Fly /// ... 17309bc3102bSYun-Fly /// %1 = firstUserOfLoop(%0) 17319bc3102bSYun-Fly /// ... 17329bc3102bSYun-Fly /// %2 = lastDefOfConsumerOperand 17339bc3102bSYun-Fly /// ... 17349bc3102bSYun-Fly /// %3 = consumerOp(%2) 17359bc3102bSYun-Fly /// ``` 17369bc3102bSYun-Fly /// 17374b563458SKunwar Grover /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it 17384b563458SKunwar Grover /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`, 17394b563458SKunwar Grover /// a.k.a. use-def chain violation: 17409bc3102bSYun-Fly /// 17419bc3102bSYun-Fly /// ``` 17429bc3102bSYun-Fly /// %0:2 = scf.for() { 17439bc3102bSYun-Fly /// // use before define error 17449bc3102bSYun-Fly /// %3 = tiledConsumerOp(%2) 17459bc3102bSYun-Fly /// } 17469bc3102bSYun-Fly /// %1 = firstUserOfLoop(%0) 17479bc3102bSYun-Fly /// ... 17489bc3102bSYun-Fly /// %2 = lastDefOfConsumerOperand 17499bc3102bSYun-Fly /// ``` 17509bc3102bSYun-Fly /// 17519bc3102bSYun-Fly /// @param loopOp: loop operation 17529bc3102bSYun-Fly /// @param consumerOp: consumer operation 17534b563458SKunwar Grover /// @param reorderOperations: the flag controls whether to reorder the 17544b563458SKunwar Grover /// backward slice w.r.t. the defineOp of `consumerOp` operands. 17554b563458SKunwar Grover /// @return: computed backward slice of consumerOp, but excluding those 17564b563458SKunwar Grover /// already dominates `firstUserOfLoop`. 17579bc3102bSYun-Fly static FailureOr<llvm::SetVector<Operation *>> 17589bc3102bSYun-Fly checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, 17599bc3102bSYun-Fly bool reorderOperations) { 17609bc3102bSYun-Fly FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); 17619bc3102bSYun-Fly if (failed(firstUserOfLoop)) 17629bc3102bSYun-Fly return failure(); 17639bc3102bSYun-Fly 17649bc3102bSYun-Fly BackwardSliceOptions options; 17659bc3102bSYun-Fly DominanceInfo dominanceInfo; 17669bc3102bSYun-Fly options.inclusive = true; 17679bc3102bSYun-Fly options.omitBlockArguments = true; 17689bc3102bSYun-Fly bool includeLoopOp = false; 17699bc3102bSYun-Fly options.filter = [&](Operation *op) { 17709bc3102bSYun-Fly if (op == loopOp) { 17719bc3102bSYun-Fly includeLoopOp = true; 17729bc3102bSYun-Fly return false; 17739bc3102bSYun-Fly } 17749bc3102bSYun-Fly // Cut off the slice to not include any operation that already dominates 17759bc3102bSYun-Fly // firstUserOfLoop. 17769bc3102bSYun-Fly return !dominanceInfo.properlyDominates(op, *firstUserOfLoop); 17779bc3102bSYun-Fly }; 17789bc3102bSYun-Fly llvm::SetVector<Operation *> slice; 17799bc3102bSYun-Fly for (auto operand : consumerOp->getOperands()) { 17809bc3102bSYun-Fly getBackwardSlice(operand, &slice, options); 17819bc3102bSYun-Fly } 17829bc3102bSYun-Fly 17839bc3102bSYun-Fly if (!slice.empty()) { 17849bc3102bSYun-Fly // If consumerOp has one producer, which is also the user of loopOp. 17859bc3102bSYun-Fly // E.g. 17869bc3102bSYun-Fly // ``` 17879bc3102bSYun-Fly // %0 = %loopOp 17889bc3102bSYun-Fly // %1 = consumerOp1 ins(%0) 17899bc3102bSYun-Fly // %2 = consumerOp2 ins(%0, %1) 17909bc3102bSYun-Fly // ``` 17919bc3102bSYun-Fly // We can not fuse consumerOp2 into loopOp due to UD chain, unless 17929bc3102bSYun-Fly // consumerOp1 has already been fused into loopOp before. 17939bc3102bSYun-Fly if (includeLoopOp || !reorderOperations) 17949bc3102bSYun-Fly return failure(); 17959bc3102bSYun-Fly } 17969bc3102bSYun-Fly 17979bc3102bSYun-Fly return slice; 17989bc3102bSYun-Fly } 17999bc3102bSYun-Fly 18009bc3102bSYun-Fly /// Fetches the OpOperand of the first valid user (and use) of the value `val` 18019bc3102bSYun-Fly /// which implements `TilingInterface` and `DestinationStyleOpInterface`. 18029bc3102bSYun-Fly /// Returns failure otherwise. 18039bc3102bSYun-Fly static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, 18049bc3102bSYun-Fly Operation *loopOp, 18059bc3102bSYun-Fly unsigned resultNumber) { 18069bc3102bSYun-Fly if (!isa<LoopLikeOpInterface>(loopOp)) 18079bc3102bSYun-Fly return failure(); 18089bc3102bSYun-Fly Value val = loopOp->getResult(resultNumber); 18099bc3102bSYun-Fly Block *loopBlock = loopOp->getBlock(); 18109bc3102bSYun-Fly for (OpOperand &opOperand : val.getUses()) { 18119bc3102bSYun-Fly Operation *consumerOp = opOperand.getOwner(); 18129bc3102bSYun-Fly // Step 1. Check if the user is tilable. 181354ae9e7bSQuinn Dawkins if (!isa<TilingInterface>(consumerOp) || 181454ae9e7bSQuinn Dawkins !isa<DestinationStyleOpInterface>(consumerOp)) { 18159bc3102bSYun-Fly // TODO: We have to init result of consumer before scf.for, use 18164b563458SKunwar Grover // DestinationStyleOpInterface to get result shape from init for now. 18174b563458SKunwar Grover // Add support for other op such as op has InferTypeOpInterface. 18189bc3102bSYun-Fly continue; 18199bc3102bSYun-Fly } 18209bc3102bSYun-Fly // Step 2. Check if user stay in the same block. 18219bc3102bSYun-Fly if (loopBlock != consumerOp->getBlock()) 18229bc3102bSYun-Fly continue; 18239bc3102bSYun-Fly // Step 3. Check if user has succeeding user. Otherwise, it usually 18249bc3102bSYun-Fly // represents already tiled. 18259bc3102bSYun-Fly if (consumerOp->use_empty()) 18269bc3102bSYun-Fly continue; 18279bc3102bSYun-Fly // Step 4. Check assumption for loop with `reorderOperations` enabled. 18289bc3102bSYun-Fly FailureOr<llvm::SetVector<Operation *>> slice = 18299bc3102bSYun-Fly checkAssumptionForLoop(loopOp, consumerOp, true); 18309bc3102bSYun-Fly if (failed(slice)) 18319bc3102bSYun-Fly continue; 18324b563458SKunwar Grover // Step 5. If backward sice is not empty, move them before 18334b563458SKunwar Grover // firstUserOfLoop. 18349bc3102bSYun-Fly if (!slice->empty()) { 18359bc3102bSYun-Fly mlir::topologicalSort(*slice); 18369bc3102bSYun-Fly FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); 18379bc3102bSYun-Fly assert(succeeded(firstUserOfLoop) && "First user of loop is not found"); 18389bc3102bSYun-Fly for (auto op : *slice) { 18399bc3102bSYun-Fly rewriter.moveOpBefore(op, *firstUserOfLoop); 18409bc3102bSYun-Fly } 18419bc3102bSYun-Fly } 18429bc3102bSYun-Fly return &opOperand; 18439bc3102bSYun-Fly } 1844b8c974f0SAbhishek Varma return failure(); 18452b2ce50fSAbhishek Varma } 18462b2ce50fSAbhishek Varma 18474b563458SKunwar Grover /// Find the perfectly nested loops outside of given loop(included) sorted 18484b563458SKunwar Grover /// from outer to inner. 1849a9ba1b6dSYun-Fly /// 1850a9ba1b6dSYun-Fly /// E.g. 1851a9ba1b6dSYun-Fly /// 1852a9ba1b6dSYun-Fly /// ``` 1853a9ba1b6dSYun-Fly /// %0 = scf.for() 1854a9ba1b6dSYun-Fly /// %1 = scf.for() 1855a9ba1b6dSYun-Fly /// %2 = scf.for() 1856a9ba1b6dSYun-Fly /// %3 = ... 1857a9ba1b6dSYun-Fly /// yield %3 1858a9ba1b6dSYun-Fly /// yield %2 1859a9ba1b6dSYun-Fly /// yield %1 1860a9ba1b6dSYun-Fly /// ``` 1861a9ba1b6dSYun-Fly /// 1862a9ba1b6dSYun-Fly /// This function will return three perfectly nested loops: %0 + %1 + %2, when 1863a9ba1b6dSYun-Fly /// target inner loop is %2. 1864a9ba1b6dSYun-Fly static SmallVector<scf::ForOp> 1865a9ba1b6dSYun-Fly getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { 1866a9ba1b6dSYun-Fly SmallVector<scf::ForOp> nestLoops = {loop}; 1867a9ba1b6dSYun-Fly auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp()); 1868a9ba1b6dSYun-Fly 1869a9ba1b6dSYun-Fly // Check if it is the ForOp that yield the result of inner loop. 1870a9ba1b6dSYun-Fly auto isForOpYieldResultOfInnerLoop = 1871a9ba1b6dSYun-Fly [](scf::ForOp outerLoop) -> LogicalResult { 1872a9ba1b6dSYun-Fly Block *body = outerLoop.getBody(); 1873a9ba1b6dSYun-Fly if (!llvm::hasSingleElement(body->without_terminator())) 1874a9ba1b6dSYun-Fly return failure(); 1875a9ba1b6dSYun-Fly auto yieldOp = cast<scf::YieldOp>(body->getTerminator()); 1876a9ba1b6dSYun-Fly auto innerForOp = dyn_cast<scf::ForOp>(body->front()); 1877a9ba1b6dSYun-Fly if (!innerForOp) 1878a9ba1b6dSYun-Fly return failure(); 1879a9ba1b6dSYun-Fly // All of innerForOp results should be yielded. 1880a9ba1b6dSYun-Fly return success(innerForOp->getNumResults() == yieldOp->getNumOperands()); 1881a9ba1b6dSYun-Fly }; 1882a9ba1b6dSYun-Fly 1883a9ba1b6dSYun-Fly while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) { 1884a9ba1b6dSYun-Fly nestLoops.push_back(outerLoop); 1885a9ba1b6dSYun-Fly outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp()); 1886a9ba1b6dSYun-Fly } 1887a9ba1b6dSYun-Fly // sorted from outer to inner 1888a9ba1b6dSYun-Fly return {nestLoops.rbegin(), nestLoops.rend()}; 1889a9ba1b6dSYun-Fly } 1890a9ba1b6dSYun-Fly 18912b2ce50fSAbhishek Varma /// Fetch the untiled consumer of a scf.for's result which is yielded by a 18922b2ce50fSAbhishek Varma /// tensor.insert_slice. This function makes the following assumptions : 18932b2ce50fSAbhishek Varma /// 1. tensor.insert_slice has scf.yield as its only user. 18942b2ce50fSAbhishek Varma /// 2. scf.for's corresponding result has only one use. 18952b2ce50fSAbhishek Varma static FailureOr<OpOperand *> 18969bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter, 18979bc3102bSYun-Fly tensor::InsertSliceOp candidateSliceOp) { 18982b2ce50fSAbhishek Varma if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) 18992b2ce50fSAbhishek Varma return failure(); 19002b2ce50fSAbhishek Varma Value sliceResult = candidateSliceOp.getResult(); 19012b2ce50fSAbhishek Varma // Step 1. Fetch the corresponding output. 19022b2ce50fSAbhishek Varma OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); 19032b2ce50fSAbhishek Varma unsigned resultNumber = yieldOpOperand.getOperandNumber(); 19042b2ce50fSAbhishek Varma // Step 2. Check containing op is scf.for. 19052b2ce50fSAbhishek Varma Operation *containingOp = candidateSliceOp->getParentOp(); 19062b2ce50fSAbhishek Varma auto forOp = dyn_cast<scf::ForOp>(containingOp); 19072b2ce50fSAbhishek Varma if (!forOp) 19082b2ce50fSAbhishek Varma return failure(); 1909a9ba1b6dSYun-Fly scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); 19102b2ce50fSAbhishek Varma 19119bc3102bSYun-Fly return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); 19122b2ce50fSAbhishek Varma } 19132b2ce50fSAbhishek Varma 19142b2ce50fSAbhishek Varma /// Fetch the first untiled consumer of a scf.forall's result which is yielded 19152b2ce50fSAbhishek Varma /// by a tensor.parallel_insert_slice. 19162b2ce50fSAbhishek Varma static FailureOr<OpOperand *> 19179bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter, 19189bc3102bSYun-Fly tensor::ParallelInsertSliceOp candidateSliceOp) { 19192b2ce50fSAbhishek Varma // Step 1. Fetch the corresponding output 19202b2ce50fSAbhishek Varma Value sliceDest = candidateSliceOp.getDest(); 19212b2ce50fSAbhishek Varma auto iterArg = dyn_cast<BlockArgument>(sliceDest); 19222b2ce50fSAbhishek Varma if (!iterArg) 19232b2ce50fSAbhishek Varma return failure(); 19242b2ce50fSAbhishek Varma Operation *containingOp = iterArg.getOwner()->getParentOp(); 19252b2ce50fSAbhishek Varma if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) 19262b2ce50fSAbhishek Varma return failure(); 19272b2ce50fSAbhishek Varma // Step 2. Check that the containing op is scf.forall. 19282b2ce50fSAbhishek Varma auto forallOp = dyn_cast<scf::ForallOp>(containingOp); 19292b2ce50fSAbhishek Varma if (!forallOp) 19302b2ce50fSAbhishek Varma return failure(); 19319bc3102bSYun-Fly unsigned resultNumber = 19329bc3102bSYun-Fly forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) 19339bc3102bSYun-Fly .getResultNumber(); 19342b2ce50fSAbhishek Varma 19359bc3102bSYun-Fly return getConsumerFromLoopUses(rewriter, containingOp, resultNumber); 19362b2ce50fSAbhishek Varma } 19372b2ce50fSAbhishek Varma 19382b2ce50fSAbhishek Varma /// A utility to fetch an untiled consumer of 19392b2ce50fSAbhishek Varma /// tensor.insert_slice/tensor.parallel_insert_slice. 19409bc3102bSYun-Fly static FailureOr<OpOperand *> 19419bc3102bSYun-Fly getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { 19422b2ce50fSAbhishek Varma if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) { 19439bc3102bSYun-Fly return getUntiledConsumerFromSlice(rewriter, insertSlice); 19442b2ce50fSAbhishek Varma } else if (auto parallelInsertSlice = 19452b2ce50fSAbhishek Varma dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) { 19469bc3102bSYun-Fly return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice); 19472b2ce50fSAbhishek Varma } else { 19482b2ce50fSAbhishek Varma return failure(); 19492b2ce50fSAbhishek Varma } 19502b2ce50fSAbhishek Varma } 19512b2ce50fSAbhishek Varma 19522b2ce50fSAbhishek Varma /// Implementation of fusing consumer of a single slice by computing the 19532b2ce50fSAbhishek Varma /// slice of the consumer in-place for scf loop. 19542b2ce50fSAbhishek Varma FailureOr<scf::SCFFuseConsumerOfSliceResult> 19552b2ce50fSAbhishek Varma mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, 19562b2ce50fSAbhishek Varma Operation *candidateSliceOp) { 19572b2ce50fSAbhishek Varma if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( 19582b2ce50fSAbhishek Varma candidateSliceOp)) 19592b2ce50fSAbhishek Varma return failure(); 19602b2ce50fSAbhishek Varma 19612b2ce50fSAbhishek Varma bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp); 19622b2ce50fSAbhishek Varma 19632b2ce50fSAbhishek Varma // 1. Get the consumer of scf.for for the result yielded by 19642b2ce50fSAbhishek Varma // tensor.insert_slice/parallel_insert_slice. 19652b2ce50fSAbhishek Varma FailureOr<OpOperand *> maybeConsumerOpOperand = 19669bc3102bSYun-Fly getUntiledConsumerFromSlice(rewriter, candidateSliceOp); 19672b2ce50fSAbhishek Varma if (failed(maybeConsumerOpOperand)) { 19682b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure(candidateSliceOp, 19692b2ce50fSAbhishek Varma "could not fetch consumer to fuse"); 19702b2ce50fSAbhishek Varma } 19712b2ce50fSAbhishek Varma OpOperand *consumerOpOperand = *maybeConsumerOpOperand; 19722b2ce50fSAbhishek Varma Operation *consumerOp = consumerOpOperand->getOwner(); 19732b2ce50fSAbhishek Varma unsigned operandNumber = consumerOpOperand->getOperandNumber(); 19742b2ce50fSAbhishek Varma unsigned resultNumber = 0; 19752b2ce50fSAbhishek Varma if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) { 19762b2ce50fSAbhishek Varma resultNumber = producerResult.getResultNumber(); 19772b2ce50fSAbhishek Varma } else { 19782b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 19792b2ce50fSAbhishek Varma consumerOp, "consumer op's operand doesn't seem to be an OpResult"); 19802b2ce50fSAbhishek Varma } 19812b2ce50fSAbhishek Varma 1982a9ba1b6dSYun-Fly // There are two possible cases regarding `oldLoopOp` here: 1983a9ba1b6dSYun-Fly // 1. single `scf.forall` or `scf.for`. 1984a9ba1b6dSYun-Fly // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the 1985a9ba1b6dSYun-Fly // top-level loop is the outer-most one of these nested loops. 1986a9ba1b6dSYun-Fly LoopLikeOpInterface innerMostLoop = 1987a9ba1b6dSYun-Fly candidateSliceOp->getParentOfType<LoopLikeOpInterface>(); 1988a9ba1b6dSYun-Fly SmallVector<LoopLikeOpInterface> nestedLoops; 19892b2ce50fSAbhishek Varma if (isInsertSliceOp) { 1990a9ba1b6dSYun-Fly nestedLoops = llvm::map_to_vector( 1991a9ba1b6dSYun-Fly getPerfectlyNestedLoopsOutsideOf( 1992a9ba1b6dSYun-Fly cast<scf::ForOp>(innerMostLoop.getOperation())), 1993a9ba1b6dSYun-Fly [](scf::ForOp forOp) { 1994a9ba1b6dSYun-Fly return cast<LoopLikeOpInterface>(forOp.getOperation()); 1995a9ba1b6dSYun-Fly }); 19962b2ce50fSAbhishek Varma } else { 1997a9ba1b6dSYun-Fly nestedLoops = {innerMostLoop}; 19982b2ce50fSAbhishek Varma } 19992b2ce50fSAbhishek Varma 2000a9ba1b6dSYun-Fly LoopLikeOpInterface outerMostLoop = nestedLoops.front(); 2001a9ba1b6dSYun-Fly 20029bc3102bSYun-Fly // Check assumption for loop with `reorderOperations` disabled. 20039bc3102bSYun-Fly if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { 20042b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 20059bc3102bSYun-Fly outerMostLoop, "the first user of loop should not dominate any define " 20069bc3102bSYun-Fly "of consumer operand(s)"); 20072b2ce50fSAbhishek Varma } 20082b2ce50fSAbhishek Varma 20092b2ce50fSAbhishek Varma OpBuilder::InsertionGuard g(rewriter); 20102b2ce50fSAbhishek Varma 20112b2ce50fSAbhishek Varma // 2. Check consumer is not using scf loop's output as init. 2012a9ba1b6dSYun-Fly auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); 2013a9ba1b6dSYun-Fly if (!dstOp) 2014a9ba1b6dSYun-Fly return rewriter.notifyMatchFailure(consumerOp, 2015a9ba1b6dSYun-Fly "consumer op is not DPS operation"); 20162b2ce50fSAbhishek Varma SmallVector<Value> dpsInits = 20172b2ce50fSAbhishek Varma llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); 2018a9ba1b6dSYun-Fly if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) { 20192b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 20202b2ce50fSAbhishek Varma consumerOp, 20212b2ce50fSAbhishek Varma "consumer op taking the result of scf.for as init is not supported"); 20222b2ce50fSAbhishek Varma } 2023a9ba1b6dSYun-Fly SmallVector<Value> newInits = dpsInits; 20242b2ce50fSAbhishek Varma 2025a9ba1b6dSYun-Fly Location loc = outerMostLoop->getLoc(); 20262b2ce50fSAbhishek Varma 20279bc3102bSYun-Fly // 3. Move the whole loop structure right before firstUserOfLoop, the 20289bc3102bSYun-Fly // dominance should be already ensured by `checkAssumptionForLoop`. 20299bc3102bSYun-Fly FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop); 20309bc3102bSYun-Fly if (failed(firstUserOfLoop)) { 20319bc3102bSYun-Fly return rewriter.notifyMatchFailure( 20329bc3102bSYun-Fly outerMostLoop, "could not find the first user of outer most loop"); 20339bc3102bSYun-Fly } 20349bc3102bSYun-Fly rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop); 20352b2ce50fSAbhishek Varma 2036a9ba1b6dSYun-Fly // 4. Set insertion point before terminator op of the loop and create a new 20372b2ce50fSAbhishek Varma // tensor.insert_slice. In the scf.for case this is a clone of the 20382b2ce50fSAbhishek Varma // candidateSliceOp whereas in the scf.forall case this is created from the 20392b2ce50fSAbhishek Varma // operands of tensor.parallel_insert_slice. 20402b2ce50fSAbhishek Varma tensor::InsertSliceOp clonedInsertSliceOp; 20412b2ce50fSAbhishek Varma if (auto sliceOp = 20422b2ce50fSAbhishek Varma dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) { 2043a9ba1b6dSYun-Fly auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation()); 20442b2ce50fSAbhishek Varma rewriter.setInsertionPoint(newForallOp.getTerminator()); 20452b2ce50fSAbhishek Varma clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>( 20462b2ce50fSAbhishek Varma loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), 20472b2ce50fSAbhishek Varma sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); 20482b2ce50fSAbhishek Varma } else { 20492b2ce50fSAbhishek Varma rewriter.setInsertionPoint(candidateSliceOp); 20502b2ce50fSAbhishek Varma clonedInsertSliceOp = 20512b2ce50fSAbhishek Varma cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)); 20522b2ce50fSAbhishek Varma } 20532b2ce50fSAbhishek Varma 2054a9ba1b6dSYun-Fly // 5.a. Clone consumer op. 2055a9ba1b6dSYun-Fly auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp)); 20562b2ce50fSAbhishek Varma 2057a9ba1b6dSYun-Fly // 5.b. Replace all uses of the loop result with the result of the cloned 20582b2ce50fSAbhishek Varma // tensor.insert_slice. 20592b2ce50fSAbhishek Varma OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); 20602b2ce50fSAbhishek Varma rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { 20612b2ce50fSAbhishek Varma operandToReplace.set(clonedInsertSliceOp.getResult()); 20622b2ce50fSAbhishek Varma }); 20632b2ce50fSAbhishek Varma 2064a9ba1b6dSYun-Fly // 6. Perform tiling of the cloned consumer and replace the operand at 20652b2ce50fSAbhishek Varma // `operandNumber` with the source of the cloned tensor.insert_slice op. 20662b2ce50fSAbhishek Varma auto ossSliceOp = 20672b2ce50fSAbhishek Varma cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation()); 20682b2ce50fSAbhishek Varma FailureOr<TilingResult> tileAndFuseResult = 20692b2ce50fSAbhishek Varma tensor::replaceInsertSliceWithTiledConsumer( 20702b2ce50fSAbhishek Varma rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber)); 20712b2ce50fSAbhishek Varma if (failed(tileAndFuseResult)) { 20722b2ce50fSAbhishek Varma return failure(); 20732b2ce50fSAbhishek Varma } 2074a9ba1b6dSYun-Fly auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]); 2075a9ba1b6dSYun-Fly rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), 20762b2ce50fSAbhishek Varma clonedInsertSliceOp.getSource()); 20772b2ce50fSAbhishek Varma 2078a9ba1b6dSYun-Fly // 7. Reconstruct [nested] loop with new inits. 2079a9ba1b6dSYun-Fly YieldTiledValuesFn newYieldValuesFn = 2080a9ba1b6dSYun-Fly [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, 2081a9ba1b6dSYun-Fly ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, 2082a9ba1b6dSYun-Fly SmallVector<SmallVector<OpFoldResult>> &tiledOffset, 2083a9ba1b6dSYun-Fly SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { 2084a9ba1b6dSYun-Fly OpBuilder::InsertionGuard g(innerRewriter); 2085a9ba1b6dSYun-Fly // 8. Set inner insertPoint right before tiled consumer op. 2086a9ba1b6dSYun-Fly innerRewriter.setInsertionPoint(tiledConsumerOp); 2087a9ba1b6dSYun-Fly 20882b2ce50fSAbhishek Varma SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets(); 20892b2ce50fSAbhishek Varma SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes(); 20902b2ce50fSAbhishek Varma SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides(); 20912b2ce50fSAbhishek Varma 20922b2ce50fSAbhishek Varma // 9. Check all insert stride is 1. 20932b2ce50fSAbhishek Varma if (llvm::any_of(strides, [](OpFoldResult stride) { 20942b2ce50fSAbhishek Varma return !isConstantIntValue(stride, 1); 20952b2ce50fSAbhishek Varma })) { 20962b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 20972b2ce50fSAbhishek Varma candidateSliceOp, "containingOp's result yield with stride"); 20982b2ce50fSAbhishek Varma } 20992b2ce50fSAbhishek Varma 21008cc616bcSMax191 // 10. Try to get iter domain position from input position. Use 21014b563458SKunwar Grover // clonedConsumerOp instead of tiledConsumerOp, because the iteration 21024b563458SKunwar Grover // domain may require index computation based on the result size. The 21034b563458SKunwar Grover // sizes and offsets should be the same either way, but using 21044b563458SKunwar Grover // tiledConsumerOp could lead to some chained unnecessary extra index 21054b563458SKunwar Grover // computation. 21062b2ce50fSAbhishek Varma SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes; 21078cc616bcSMax191 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( 21082b2ce50fSAbhishek Varma rewriter, operandNumber, offsets, sizes, iterDomainOffsets, 21092b2ce50fSAbhishek Varma iterDomainSizes))) { 21102b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 21118cc616bcSMax191 clonedConsumerOp, 2112a9ba1b6dSYun-Fly "can't get iter domain position from input position"); 21132b2ce50fSAbhishek Varma } 21142b2ce50fSAbhishek Varma 21152b2ce50fSAbhishek Varma // 11. Try to fetch the offset and size for all results of the cloned 21162b2ce50fSAbhishek Varma // consumer. This would then be used to form the corresponding 21172b2ce50fSAbhishek Varma // tensor.insert_slice/parallel_insert_slice later. 2118a9ba1b6dSYun-Fly unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults(); 21192b2ce50fSAbhishek Varma SmallVector<SmallVector<OpFoldResult>> resultOffsets( 21202b2ce50fSAbhishek Varma totalNumResultsOfConsumer); 2121a9ba1b6dSYun-Fly SmallVector<SmallVector<OpFoldResult>> resultSizes( 2122a9ba1b6dSYun-Fly totalNumResultsOfConsumer); 2123a9ba1b6dSYun-Fly for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { 2124a9ba1b6dSYun-Fly if (failed(tiledConsumerOp.getResultTilePosition( 21252b2ce50fSAbhishek Varma rewriter, idx, iterDomainOffsets, iterDomainSizes, 21262b2ce50fSAbhishek Varma resultOffsets[idx], resultSizes[idx]))) { 21272b2ce50fSAbhishek Varma return rewriter.notifyMatchFailure( 2128a9ba1b6dSYun-Fly tiledConsumerOp, 21292b2ce50fSAbhishek Varma "can't get result domain position from iter domain position"); 21302b2ce50fSAbhishek Varma } 21312b2ce50fSAbhishek Varma } 21322b2ce50fSAbhishek Varma 2133a9ba1b6dSYun-Fly // 12. Create `extract_slice` for `iter_args` for DPS operation if 2134a9ba1b6dSYun-Fly // necessary. 2135a9ba1b6dSYun-Fly if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>( 2136a9ba1b6dSYun-Fly tiledConsumerOp.getOperation())) { 2137a9ba1b6dSYun-Fly rewriter.setInsertionPoint(tiledDestStyleOp); 2138a9ba1b6dSYun-Fly for (const auto &&[index, newRegionArg] : 2139a9ba1b6dSYun-Fly llvm::enumerate(newRegionIterArgs)) { 2140a9ba1b6dSYun-Fly auto destSlice = rewriter.create<tensor::ExtractSliceOp>( 2141a9ba1b6dSYun-Fly loc, newRegionArg, resultOffsets[index], resultSizes[index], 2142a9ba1b6dSYun-Fly SmallVector<OpFoldResult>(resultOffsets[index].size(), 2143a9ba1b6dSYun-Fly rewriter.getIndexAttr(1))); 2144a9ba1b6dSYun-Fly // Make a copy of index to avoid a capturing structured binding, which 2145a9ba1b6dSYun-Fly // is a C++20 extension. 2146a9ba1b6dSYun-Fly auto dstNumber = index; 2147a9ba1b6dSYun-Fly rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { 2148a9ba1b6dSYun-Fly tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice); 2149a9ba1b6dSYun-Fly }); 2150a9ba1b6dSYun-Fly } 21512b2ce50fSAbhishek Varma } 21522b2ce50fSAbhishek Varma 2153a9ba1b6dSYun-Fly // 13. Prepare tiled offset and sizes for later `insert_slice` creation by 2154a9ba1b6dSYun-Fly // caller. 2155a9ba1b6dSYun-Fly Block *block = rewriter.getInsertionPoint()->getBlock(); 2156a9ba1b6dSYun-Fly rewriter.setInsertionPoint(block->getTerminator()); 2157a9ba1b6dSYun-Fly for (const auto &&[index, result] : 2158a9ba1b6dSYun-Fly llvm::enumerate(tiledConsumerOp->getResults())) { 2159a9ba1b6dSYun-Fly tiledResult.push_back(result); 2160a9ba1b6dSYun-Fly tiledOffset.emplace_back(resultOffsets[index]); 2161a9ba1b6dSYun-Fly tiledSizes.emplace_back(resultSizes[index]); 2162a9ba1b6dSYun-Fly } 2163a9ba1b6dSYun-Fly return success(); 2164a9ba1b6dSYun-Fly }; 2165a9ba1b6dSYun-Fly // 14. Add new inits to [nested] loops. 2166a9ba1b6dSYun-Fly if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits, 2167a9ba1b6dSYun-Fly newYieldValuesFn))) { 2168a9ba1b6dSYun-Fly return rewriter.notifyMatchFailure(tiledConsumerOp, 2169a9ba1b6dSYun-Fly "unable to add new inits to nest loop"); 2170a9ba1b6dSYun-Fly } 2171a9ba1b6dSYun-Fly 21724b563458SKunwar Grover // 15. Replace the result of scf loop and consumer op with new loop's 21734b563458SKunwar Grover // results. 2174a9ba1b6dSYun-Fly 2175a9ba1b6dSYun-Fly for (auto &&[oldResult, newResult] : llvm::zip( 2176a9ba1b6dSYun-Fly consumerOp->getResults(), 2177a9ba1b6dSYun-Fly nestedLoops.front()->getResults().take_back(newInits.size()))) { 21782b2ce50fSAbhishek Varma rewriter.replaceAllUsesWith(oldResult, newResult); 21792b2ce50fSAbhishek Varma } 21802b2ce50fSAbhishek Varma 2181a9ba1b6dSYun-Fly // 16. Need to erase the old scf loop and the cloned consumer op. 21822b2ce50fSAbhishek Varma rewriter.eraseOp(clonedConsumerOp); 21832b2ce50fSAbhishek Varma 21842b2ce50fSAbhishek Varma return scf::SCFFuseConsumerOfSliceResult{ 21852b2ce50fSAbhishek Varma consumerOpOperand, 21862b2ce50fSAbhishek Varma &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)), 21872b2ce50fSAbhishek Varma tileAndFuseResult->tiledOps}; 21882b2ce50fSAbhishek Varma } 21892b2ce50fSAbhishek Varma 21902b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===// 219197f91982SMahesh Ravishankar // lowerToLoopsUsingSCFForOp implementation. 21926f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===// 21936f03a10eSMahesh Ravishankar 21946f03a10eSMahesh Ravishankar FailureOr<SmallVector<scf::ForOp>> 219597f91982SMahesh Ravishankar mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, 219697f91982SMahesh Ravishankar TilingInterface op) { 21976f03a10eSMahesh Ravishankar // TODO: Handle cases where the op has results if needed. 21986f03a10eSMahesh Ravishankar if (op->getNumResults() > 0) { 21996f03a10eSMahesh Ravishankar return rewriter.notifyMatchFailure( 22006f03a10eSMahesh Ravishankar op, "unable to lower to loops operations with return values"); 22016f03a10eSMahesh Ravishankar } 22026f03a10eSMahesh Ravishankar 220397f91982SMahesh Ravishankar SmallVector<Range> domain = op.getIterationDomain(rewriter); 22046f03a10eSMahesh Ravishankar SmallVector<Value> ivs; 22056f03a10eSMahesh Ravishankar SmallVector<scf::ForOp> loops; 22066f03a10eSMahesh Ravishankar Location loc = op.getLoc(); 22076f03a10eSMahesh Ravishankar for (auto loopRange : domain) { 22086f03a10eSMahesh Ravishankar Value offsetVal = 22096f03a10eSMahesh Ravishankar getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 22106f03a10eSMahesh Ravishankar Value sizeVal = 22116f03a10eSMahesh Ravishankar getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 22126f03a10eSMahesh Ravishankar Value strideVal = 22136f03a10eSMahesh Ravishankar getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 22146f03a10eSMahesh Ravishankar auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 22156f03a10eSMahesh Ravishankar strideVal, ValueRange{}); 22166f03a10eSMahesh Ravishankar loops.push_back(loop); 22176f03a10eSMahesh Ravishankar ivs.push_back(loop.getInductionVar()); 22186f03a10eSMahesh Ravishankar rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 22196f03a10eSMahesh Ravishankar } 22206f03a10eSMahesh Ravishankar if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 22216f03a10eSMahesh Ravishankar return failure(); 22226f03a10eSMahesh Ravishankar } 22236f03a10eSMahesh Ravishankar return loops; 22246f03a10eSMahesh Ravishankar } 2225