14a661602SNicolas Vasilache //===- Transforms.cpp - Linalg transformations as patterns ----------------===// 2307cfdf5SNicolas Vasilache // 3307cfdf5SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4307cfdf5SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 5307cfdf5SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6307cfdf5SNicolas Vasilache // 7307cfdf5SNicolas Vasilache //===----------------------------------------------------------------------===// 8307cfdf5SNicolas Vasilache // 9307cfdf5SNicolas Vasilache // This file implements logic and helpers to expose Linalg transforms as rewrite 10307cfdf5SNicolas Vasilache // patterns. 11307cfdf5SNicolas Vasilache // 12307cfdf5SNicolas Vasilache //===----------------------------------------------------------------------===// 13307cfdf5SNicolas Vasilache 14307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 19307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h" 208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 21060208b4SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 220edb4127SLei Zhang #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" 23644f0f83SHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h" 24644f0f83SHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h" 25d624c1b5SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h" 26307cfdf5SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 2799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 28307cfdf5SNicolas Vasilache #include 
"mlir/IR/AffineExpr.h" 29307cfdf5SNicolas Vasilache #include "mlir/IR/Matchers.h" 30307cfdf5SNicolas Vasilache #include "mlir/Pass/Pass.h" 31307cfdf5SNicolas Vasilache #include "mlir/Support/LLVM.h" 32b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 333747eb9cSNicolas Vasilache #include "llvm/ADT/ScopeExit.h" 348faf35c0SMatthias Springer #include "llvm/ADT/TypeSwitch.h" 35307cfdf5SNicolas Vasilache #include "llvm/Support/Debug.h" 36307cfdf5SNicolas Vasilache #include "llvm/Support/raw_ostream.h" 37307cfdf5SNicolas Vasilache #include <type_traits> 381fc096afSMehdi Amini #include <utility> 39307cfdf5SNicolas Vasilache 40307cfdf5SNicolas Vasilache #define DEBUG_TYPE "linalg-transforms" 41307cfdf5SNicolas Vasilache 42307cfdf5SNicolas Vasilache using namespace mlir; 43307cfdf5SNicolas Vasilache using namespace mlir::linalg; 44307cfdf5SNicolas Vasilache 4556ce65e2SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 4602371c5dSNicolas Vasilache #define DBGSNL() (llvm::dbgs() << "\n") 473110e7b0SNicolas Vasilache 48bb2ae985SNicolas Vasilache //===----------------------------------------------------------------------===// 49bb2ae985SNicolas Vasilache // Transformations exposed as functional-style API calls. 50bb2ae985SNicolas Vasilache //===----------------------------------------------------------------------===// 51bb2ae985SNicolas Vasilache 52bb2ae985SNicolas Vasilache //===----------------------------------------------------------------------===// 53bb2ae985SNicolas Vasilache // peelLoop transformation. 54bb2ae985SNicolas Vasilache //===----------------------------------------------------------------------===// 55bb2ae985SNicolas Vasilache 56bb2ae985SNicolas Vasilache /// Try to peel and canonicalize loop `op` and return the new result. 57bb2ae985SNicolas Vasilache /// Also applies affine_min/max bounds simplification on the fly where relevant. 
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        // If peeling succeeded, the transform populated `partialIteration`;
        // return its results so callers observe the peeled loop.
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        // On failure the transform must not have created a partial iteration.
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      // Ops other than scf.for are left untouched; return them unchanged.
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel 'loops' and applies affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
87bb2ae985SNicolas Vasilache static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { 88bb2ae985SNicolas Vasilache bool found = false; 89bb2ae985SNicolas Vasilache for (AffineExpr e : map.getResults()) { 90bb2ae985SNicolas Vasilache if (!e.isFunctionOfDim(dim)) 91bb2ae985SNicolas Vasilache continue; 92bb2ae985SNicolas Vasilache if (found) 93bb2ae985SNicolas Vasilache return false; 94bb2ae985SNicolas Vasilache found = true; 95bb2ae985SNicolas Vasilache } 96bb2ae985SNicolas Vasilache return true; 97bb2ae985SNicolas Vasilache } 98bb2ae985SNicolas Vasilache #endif // NDEBUG 99bb2ae985SNicolas Vasilache 100bb2ae985SNicolas Vasilache /// Return the index of the first result of `map` that is a function of 101bb2ae985SNicolas Vasilache /// AffineDimExpr(dim), std::nullopt otherwise. 102bb2ae985SNicolas Vasilache static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map, 103bb2ae985SNicolas Vasilache int64_t dim) { 104bb2ae985SNicolas Vasilache for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 105bb2ae985SNicolas Vasilache AffineExpr expr = map.getResult(i); 106bb2ae985SNicolas Vasilache if (!expr.isFunctionOfDim(dim)) 107bb2ae985SNicolas Vasilache continue; 108bb2ae985SNicolas Vasilache return i; 109bb2ae985SNicolas Vasilache } 110bb2ae985SNicolas Vasilache return std::nullopt; 111bb2ae985SNicolas Vasilache } 112bb2ae985SNicolas Vasilache 113bb2ae985SNicolas Vasilache /// Perform one step of packing of a LinalgOp's metadata along `dim` into the 114bb2ae985SNicolas Vasilache /// `newDim` at `iteratorTypes.size()` by: 115bb2ae985SNicolas Vasilache /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. 116bb2ae985SNicolas Vasilache /// 2. Appending a `newDim` to the domain of every indexing map. 117bb2ae985SNicolas Vasilache /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing 118bb2ae985SNicolas Vasilache /// by potentially adding a `newDim` result to `map`. 
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
/// Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into higher dimensional access is not possible for general
/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// And would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  // The packed dimension is appended at the end of the iteration space.
  int64_t newDim = iteratorTypes.size();
  // 1. `newDim` iterates the same way `dim` does.
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      // This operand does not access `dim`: keep its (dim-shifted) map and
      // leave it unpacked (std::nullopt entry).
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  // Commit the updated maps only once all operands have been processed.
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
198bb2ae985SNicolas Vasilache struct PackedOperandsDim { 199bb2ae985SNicolas Vasilache OpFoldResult packedSize; 200bb2ae985SNicolas Vasilache SmallVector<std::optional<int64_t>> packedDimForEachOperand; 201bb2ae985SNicolas Vasilache }; 202bb2ae985SNicolas Vasilache 203bb2ae985SNicolas Vasilache /// Helper struct to encode packing along all dimensions of a LinalgOp. 204bb2ae985SNicolas Vasilache struct PackedOperandsDimList { 2053af5ab21SMehdi Amini void pushBack(PackedOperandsDim &&packedOperandsDims) { 206bb2ae985SNicolas Vasilache spec.emplace_back(packedOperandsDims); 207bb2ae985SNicolas Vasilache } 208bb2ae985SNicolas Vasilache /// Return all the dims that have been packed for operand @ `operandPos`. 209bb2ae985SNicolas Vasilache SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos); 210bb2ae985SNicolas Vasilache /// Return all the pack sizes by which an operand @ `operandPos` is packed. 211bb2ae985SNicolas Vasilache SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos); 212bb2ae985SNicolas Vasilache 213bb2ae985SNicolas Vasilache private: 214bb2ae985SNicolas Vasilache SmallVector<PackedOperandsDim> spec; 215bb2ae985SNicolas Vasilache }; 216bb2ae985SNicolas Vasilache 217bb2ae985SNicolas Vasilache } // namespace 218bb2ae985SNicolas Vasilache 219ddcc5072SHanhan Wang FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, 220d6590c1bSZhuoran Yin tensor::PackOp packOp, 221d6590c1bSZhuoran Yin bool lowerPadLikeWithInsertSlice) { 222ddcc5072SHanhan Wang // 1. Filter out NYI cases. 
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  // Dynamic inner tile sizes cannot currently be expressed through the
  // pad + expand_shape sequence built below.
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    // High padding = outerSize * innerSize - origSize, with operands bound as
    // d0 = outerSize, d1 = origSize, s0 = innerSize.
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  // Default to zero padding of the element type when the pack op does not
  // carry an explicit padding value.
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  // Invert the strip-mine permutation to go from strip-mined layout back to
  // the packed layout expected by packOp's destination.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  // Allocate an empty tensor as the transpose destination (DPS init).
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
449d2f2ef84SLorenzo Chelini rewriter.replaceOp(unPackOp, copyOp->getResults()); 450ddcc5072SHanhan Wang 451ddcc5072SHanhan Wang return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; 452ddcc5072SHanhan Wang } 453ddcc5072SHanhan Wang 454bb2ae985SNicolas Vasilache SmallVector<int64_t> 455bb2ae985SNicolas Vasilache PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { 456bb2ae985SNicolas Vasilache SmallVector<int64_t> res; 457c0fe2b89SMehdi Amini for (auto &i : spec) { 458c0fe2b89SMehdi Amini if (!i.packedDimForEachOperand[operandPos].has_value()) 459bb2ae985SNicolas Vasilache continue; 460c0fe2b89SMehdi Amini res.push_back(i.packedDimForEachOperand[operandPos].value()); 461bb2ae985SNicolas Vasilache } 462bb2ae985SNicolas Vasilache return res; 463bb2ae985SNicolas Vasilache } 464bb2ae985SNicolas Vasilache 465bb2ae985SNicolas Vasilache SmallVector<OpFoldResult> 466bb2ae985SNicolas Vasilache PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { 467bb2ae985SNicolas Vasilache SmallVector<OpFoldResult> res; 468c0fe2b89SMehdi Amini for (auto &i : spec) { 469c0fe2b89SMehdi Amini if (!i.packedDimForEachOperand[operandPos].has_value()) 470bb2ae985SNicolas Vasilache continue; 471c0fe2b89SMehdi Amini res.push_back(i.packedSize); 472bb2ae985SNicolas Vasilache } 473bb2ae985SNicolas Vasilache return res; 474bb2ae985SNicolas Vasilache } 475bb2ae985SNicolas Vasilache 476bb2ae985SNicolas Vasilache /// Implement packing of a single LinalgOp by performing packing by 477bb2ae985SNicolas Vasilache /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator. 478bb2ae985SNicolas Vasilache /// Return the packed Linalg op on success, failure otherwise. 
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  // One pack size is required per iterator (loop) of the op; a constant 0
  // entry means "do not pack this dimension" (see Step 1 below).
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  // Work on local copies of the metadata; packLinalgMetadataOnce mutates them
  // in place, one packed dimension at a time.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    // Rewrite maps/iterators for packing iterator `i`; records, per operand,
    // which dimension (if any) got packed.
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  // Inputs first, then inits: inputsAndInits keeps the DPS operand order so
  // Step 3 can split it with take_front/take_back.
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(
          DBGS() << "operand: " << operand << "\n";
          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
          DBGSNL(););
      // Operand not touched by any packed dimension: forward it unchanged.
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      // Only omit the padding value when we can statically prove no padding
      // is needed; otherwise conservatively pad with zero.
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  // Move (not clone) the original body: the payload is unchanged by packing.
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    // Init was not packed: the result keeps its original type.
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  // Builder keeps element type and encoding; only the shape changes.
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  // Composing with the permutation re-routes the op's access to the operand
  // through the transposed layout.
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  // Rebuild the op as a GenericOp: only the one indexing map and the one
  // operand changed; body, iterators and all other operands are reused.
  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  // The transform only supports a pack result consumed exactly once, by
  // `linalgOp` itself; otherwise bail out as a match failure.
  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  // When an unpack op is provided, it must unpack the result tied to the
  // init operand that consumes the pack.
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // Leading part just reuse the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    // `linalgOp` was just replaced: re-fetch the operand/result via the saved
    // operand number on the new op.
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  // Note: transposedUnPackOp stays null when no unpack op was supplied.
  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`.
/// The order of the packed dimensions is customizable: the `mnkOrder` is a
/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 8 possible
/// forms. The outer dimensions of the operands are not permuted at this time,
/// this is left for future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  // A matmul needs at least one m, one n and one k iterator.
  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  // mmnnkkPos are the last `numPackedDims` loop positions, reordered by
  // `mnkOrder`; packedSizes/paddedSizesNextMultipleOf are scattered to match.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where it would have
  // to plug a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalized the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
  // post interchange.
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    // Round the loop range up to the next multiple: size = ceil(d0/s0) * s0.
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Capture the sizes by value; the callback materializes them as constants
  // at the start of the enclosing function, at call time.
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

// Thin pattern adapter around vectorizeCopy.
LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
/// Materialize the padding value across `dest`. If the pad op carries a
/// constant padding value, emit a cheap linalg.fill; otherwise fall back to a
/// tensor.generate that clones the pad op's (possibly index-dependent) region.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

/// Decompose tensor.pad into:
///   1. tensor.empty of the padded (result) shape,
///   2. a fill (or tensor.generate) producing the padding values,
///   3. tensor.insert_slice copying the original source into place at the
///      low-pad offsets.
LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    // Static case: materialize the attribute as an arith.constant index.
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value: result dim = src + low + high.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate a InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source lands at the low-pad offsets inside the filled tensor.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

/// Swap extract_slice(pad(x)) into pad(extract_slice(x)) via
/// tensor::bubbleUpPadSlice. Only unit-stride slices are handled.
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  // `controlFn` (if provided) both gates the rewrite and decides whether the
  // zero-slice guard is generated; std::nullopt means "do not rewrite".
  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}

/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    // No padding value set - nothing to pad; hand back the source unchanged.
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles (passed through to createPadHighOp so it can
  // size the dynamic dims of the pad result).
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  // Scatter each perm entry's index into a rank-sized vector; slots not hit
  // by `perm` keep the marker and are filtered out below.
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  // Partition dims into inner (tiled) and non-unit outer dims, assigning
  // consecutive positions in the rank-reduced space. Unit outer dims are
  // dropped entirely (rank reduction).
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner most dims after packing.
  perm.append(innerDims);

  return perm;
}

/// Decompose a tensor.pack whose outer dims are all 1 into a transpose of the
/// (possibly padded) source followed by a tensor.insert_slice into the dest.
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: support the case that outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  // Each inner tile adds one dim to the dest relative to the source.
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");

  // 1. Extract the inner tile sizes.
  // Where possible, values are replaced with constant attributes (to match the
  // behaviour of `getPackOpSourceOrPaddedSource`).
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      // Rather than taking the tile size as is, extract the actual constant
      // value Attribute where possible, e.g.:
      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }

  // 2. Transpose the input to match the inner tile order:
  //    %init = tensor.empty()
  //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
  //                                        outs(%init)
  // Two assumptions are made:
  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
  //  2. Inner dims position correspond to the trailing `numTiles` dims.
  // NOTE(review): `tilesPermNormalized` does not appear to be read below -
  // consider removing; confirm against the full file.
  SmallVector<int64_t> tilesPermNormalized =
      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
  // Identity over the non-tiled leading dims, then the inner dims positions -
  // this moves the tiled dims to the back in inner-tile order.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);

  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
             DBGSNL(););

  // 2.1 Create tensor.empty (init value for TransposeOp).
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);

  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2 Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
                                                           srcPermForTranspose);

  // 3. Insert the inner tile to the destination:
  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  // Outer dims are all 1s!
  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
                                       oneIdxAttr);
  // NOTE(review): `writeShape` is populated but not referenced afterwards -
  // consider removing; confirm against the full file.
  SmallVector<int64_t> writeShape;

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }

  // 4. Replace tensor.packOp with tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}

/// Decompose a tensor.unpack whose tiled outer dims are all 1 into
/// extract_slice + transpose + extract_slice (truncation) + insert_slice.
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
  // dims:
  //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
  SmallVector<int64_t> readShapeForExtractSlice;
  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
  // outer-tiled-dims being all 1), this will be
  //    [ outer-untiled-dims, tile-sizes ]
  SmallVector<OpFoldResult> extractSliceSizes;
  // The offset and strides attributes for ExtractSliceOp.
  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);

  // Shape for EmptyOp that's used as the init value for TransposeOp below.
  // This should be:
  //    [ outer-untiled-dims, tile-sizes ]
  // However, skip unit dims - TransposeOp (below) applies rank-reduced
  // permutation.
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
    //
    // As all outer tiled dims are 1, so the corresponding
    // slice size to read will also 1. As this will be rank-reducing "extract
    // slice" (i.e. the unit dims will be "collapsed"), there's no need to
    // update:
    //  * the output shape for ExtractSliceOp, nor
    //  * the shape for EmptyOp.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute sizes attribute for ExtractSliceOp + EmptyOp -
    // outer-untiled-dims.
    if (ShapedType::isDynamic(srcShape[i])) {
      // Dynamic size: materialize a tensor.dim to query it at runtime.
      OpFoldResult dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
    // into account rank-reducing).
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
  // shape for EmptyOp.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the outer corresponding tile order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle in-complete tiles if needed. It truncates trailing data from the
  // transposed tile.
  int numLoops = shapeForEmptyOp.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Only non-unit (or tiled) dest dims survive the rank-reduced transpose.
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result to the destination tensor.
13633ebc6beeSHanhan Wang SmallVector<OpFoldResult> writeSizes; 13643ebc6beeSHanhan Wang SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); 13653ebc6beeSHanhan Wang SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); 13663ebc6beeSHanhan Wang for (int i = 0, idx = 0; i < destRank; ++i) { 1367009c053eSQuinn Dawkins if (dimAndTileMapping.count(i) || destShape[i] != 1) 13683ebc6beeSHanhan Wang writeSizes.push_back(tileSizes[idx++]); 13693ebc6beeSHanhan Wang else 13703ebc6beeSHanhan Wang writeSizes.push_back(oneIdxAttr); 13713ebc6beeSHanhan Wang } 13723ebc6beeSHanhan Wang auto insert = rewriter.create<tensor::InsertSliceOp>( 13733ebc6beeSHanhan Wang loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, 13743ebc6beeSHanhan Wang writeStrides); 13753ebc6beeSHanhan Wang rewriter.replaceOp(unpackOp, insert.getResult()); 13763ebc6beeSHanhan Wang 13773ebc6beeSHanhan Wang return success(); 13783ebc6beeSHanhan Wang } 13793ebc6beeSHanhan Wang 13807b615a87SLei Zhang // The following are patterns for downscaling convolution ops with size-1 13817b615a87SLei Zhang // window dimensions. 13827b615a87SLei Zhang // 13837b615a87SLei Zhang // Note that we'd eventually want to write such transformations in a generic 13847b615a87SLei Zhang // way, e.g., converting to linalg.generic, removing the size-1 dimensions, 13857b615a87SLei Zhang // and then turning back to named ops. But for now it's fine to have a few 13867b615a87SLei Zhang // patterns matching special ops to get started. 13877b615a87SLei Zhang 13888e484b52SStanley Winata template <typename Conv2DOp, typename Conv1DOp> 13898e484b52SStanley Winata FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>:: 13908e484b52SStanley Winata returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { 13910a8e3dd4SMatthias Springer if (convOp.hasPureBufferSemantics()) 1392ce2e198bSAlex Zinenko return failure(); // To be implemented. 

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on conv2D layout: (kernel-height, kernel-width,
  // output-height, output-width) dim positions for each supported op.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

// Explicit instantiations for the supported 2D conv/pool -> 1D downscalings.
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                              Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                              Conv1DNcwFcwOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
                                                              PoolingNwcSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
                                                              PoolingNcwSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
                                                              PoolingNwcMaxOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
                                                              PoolingNwcMinOp>;
template struct
linalg::DownscaleSizeOneWindowed2DConvolution< 15081a151fdcSMurali Vijayaraghavan PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; 15091a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, 15101a151fdcSMurali Vijayaraghavan PoolingNcwMaxOp>; 15112b882f84SBenjamin Kramer 1512ce2e198bSAlex Zinenko FailureOr<DepthwiseConv1DNwcWcOp> 1513ce2e198bSAlex Zinenko DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( 1514ce2e198bSAlex Zinenko DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { 15150a8e3dd4SMatthias Springer if (convOp.hasPureBufferSemantics()) 1516ce2e198bSAlex Zinenko return failure(); // To be implemented. 1517b828506eSNicolas Vasilache 1518d3b3f765SJacques Pienaar Value input = convOp.getInputs().front(); 1519d3b3f765SJacques Pienaar Value kernel = convOp.getInputs().back(); 1520d3b3f765SJacques Pienaar Value output = convOp.getOutputs().front(); 1521b828506eSNicolas Vasilache 15225550c821STres Popp auto inputType = dyn_cast<RankedTensorType>(input.getType()); 15235550c821STres Popp auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); 15245550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 1525b828506eSNicolas Vasilache 1526b828506eSNicolas Vasilache auto kernelShape = kernelType.getShape(); 1527b828506eSNicolas Vasilache auto outputShape = outputType.getShape(); 1528b828506eSNicolas Vasilache 1529b828506eSNicolas Vasilache // Only handle the case where at least one of the window dimensions is 1530b828506eSNicolas Vasilache // of size 1. Other cases can rely on tiling to reduce to such cases. 
1531b828506eSNicolas Vasilache int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1532b828506eSNicolas Vasilache int64_t ohSize = outputShape[1], owSize = outputShape[2]; 1533b828506eSNicolas Vasilache bool removeH = (khSize == 1 && ohSize == 1); 1534b828506eSNicolas Vasilache bool removeW = (kwSize == 1 && owSize == 1); 1535b828506eSNicolas Vasilache if (!removeH && !removeW) 1536b828506eSNicolas Vasilache return failure(); 1537b828506eSNicolas Vasilache 1538b828506eSNicolas Vasilache // Get new shapes and types for all operands by removing the size-1 1539b828506eSNicolas Vasilache // dimension. 1540b828506eSNicolas Vasilache using RTTBuilder = RankedTensorType::Builder; 1541789c88e8SNicolas Vasilache RankedTensorType newInputType = 1542789c88e8SNicolas Vasilache RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); 1543789c88e8SNicolas Vasilache RankedTensorType newKernelType = 1544789c88e8SNicolas Vasilache RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1545789c88e8SNicolas Vasilache RankedTensorType newOutputType = 1546789c88e8SNicolas Vasilache RTTBuilder(outputType).dropDim(removeH ? 1 : 2); 1547b828506eSNicolas Vasilache 1548b828506eSNicolas Vasilache // Rank-reduce operands. 1549b828506eSNicolas Vasilache Location loc = convOp.getLoc(); 1550b828506eSNicolas Vasilache Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1551b828506eSNicolas Vasilache rewriter, loc, input, newInputType); 1552b828506eSNicolas Vasilache Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1553b828506eSNicolas Vasilache rewriter, loc, kernel, newKernelType); 1554b828506eSNicolas Vasilache Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1555b828506eSNicolas Vasilache rewriter, loc, output, newOutputType); 1556b828506eSNicolas Vasilache 1557b828506eSNicolas Vasilache // Rank-reduce strides and dilations too. 1558b828506eSNicolas Vasilache // TODO: dropDim 1-liner helper. 
1559d3b3f765SJacques Pienaar auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>()); 1560b828506eSNicolas Vasilache strides.erase(strides.begin() + (removeH ? 0 : 1)); 1561b828506eSNicolas Vasilache auto stridesAttr = rewriter.getI64VectorAttr(strides); 1562b828506eSNicolas Vasilache 1563d3b3f765SJacques Pienaar auto dilations = 1564d3b3f765SJacques Pienaar llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>()); 1565b828506eSNicolas Vasilache dilations.erase(dilations.begin() + (removeH ? 0 : 1)); 1566b828506eSNicolas Vasilache auto dilationsAttr = rewriter.getI64VectorAttr(dilations); 1567b828506eSNicolas Vasilache 1568b828506eSNicolas Vasilache auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( 1569b828506eSNicolas Vasilache loc, newOutputType, ValueRange{newInput, newKernel}, 1570b828506eSNicolas Vasilache ValueRange{newOutput}, stridesAttr, dilationsAttr); 1571b828506eSNicolas Vasilache 1572b828506eSNicolas Vasilache // Insert back. 1573b828506eSNicolas Vasilache Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1574b828506eSNicolas Vasilache rewriter, loc, conv1DOp.getResult(0), output); 1575b828506eSNicolas Vasilache rewriter.replaceOp(convOp, inserted); 1576b828506eSNicolas Vasilache 1577ce2e198bSAlex Zinenko return conv1DOp; 1578ce2e198bSAlex Zinenko } 15797b615a87SLei Zhang 1580991945f4SDevajith Valaparambil Sreeramaswamy FailureOr<Conv1DOp> 1581991945f4SDevajith Valaparambil Sreeramaswamy DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, 1582991945f4SDevajith Valaparambil Sreeramaswamy PatternRewriter &rewriter) const { 15830a8e3dd4SMatthias Springer if (convOp.hasPureBufferSemantics()) 1584991945f4SDevajith Valaparambil Sreeramaswamy return failure(); // To be implemented. 
1585991945f4SDevajith Valaparambil Sreeramaswamy 1586991945f4SDevajith Valaparambil Sreeramaswamy Value input = convOp.getInputs().front(); 1587991945f4SDevajith Valaparambil Sreeramaswamy Value kernel = convOp.getInputs().back(); 1588991945f4SDevajith Valaparambil Sreeramaswamy Value output = convOp.getOutputs().front(); 1589991945f4SDevajith Valaparambil Sreeramaswamy 15905550c821STres Popp auto inputType = dyn_cast<RankedTensorType>(input.getType()); 15915550c821STres Popp auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); 15925550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 1593991945f4SDevajith Valaparambil Sreeramaswamy 1594991945f4SDevajith Valaparambil Sreeramaswamy auto kernelShape = kernelType.getShape(); 1595991945f4SDevajith Valaparambil Sreeramaswamy auto outputShape = outputType.getShape(); 1596991945f4SDevajith Valaparambil Sreeramaswamy 1597991945f4SDevajith Valaparambil Sreeramaswamy // Only handle the case where at least one of the window dimensions is 1598991945f4SDevajith Valaparambil Sreeramaswamy // of size 1. Other cases can rely on tiling to reduce to such cases. 1599991945f4SDevajith Valaparambil Sreeramaswamy int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1600991945f4SDevajith Valaparambil Sreeramaswamy int64_t ohSize = outputShape[0], owSize = outputShape[1]; 1601991945f4SDevajith Valaparambil Sreeramaswamy bool removeH = (khSize == 1 && ohSize == 1); 1602991945f4SDevajith Valaparambil Sreeramaswamy bool removeW = (kwSize == 1 && owSize == 1); 1603991945f4SDevajith Valaparambil Sreeramaswamy if (!removeH && !removeW) 1604991945f4SDevajith Valaparambil Sreeramaswamy return failure(); 1605991945f4SDevajith Valaparambil Sreeramaswamy 1606991945f4SDevajith Valaparambil Sreeramaswamy // Get new shapes and types for all operands by removing the size-1 1607991945f4SDevajith Valaparambil Sreeramaswamy // dimension. 
1608991945f4SDevajith Valaparambil Sreeramaswamy using RTTBuilder = RankedTensorType::Builder; 1609991945f4SDevajith Valaparambil Sreeramaswamy RankedTensorType newInputType = 1610991945f4SDevajith Valaparambil Sreeramaswamy RTTBuilder(inputType).dropDim((removeH ? 0 : 1)); 1611991945f4SDevajith Valaparambil Sreeramaswamy RankedTensorType newKernelType = 1612991945f4SDevajith Valaparambil Sreeramaswamy RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1613991945f4SDevajith Valaparambil Sreeramaswamy RankedTensorType newOutputType = 1614991945f4SDevajith Valaparambil Sreeramaswamy RTTBuilder(outputType).dropDim(removeH ? 0 : 1); 1615991945f4SDevajith Valaparambil Sreeramaswamy 1616991945f4SDevajith Valaparambil Sreeramaswamy // Rank-reduce operands. 1617991945f4SDevajith Valaparambil Sreeramaswamy Location loc = convOp.getLoc(); 1618991945f4SDevajith Valaparambil Sreeramaswamy Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1619991945f4SDevajith Valaparambil Sreeramaswamy rewriter, loc, input, newInputType); 1620991945f4SDevajith Valaparambil Sreeramaswamy Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1621991945f4SDevajith Valaparambil Sreeramaswamy rewriter, loc, kernel, newKernelType); 1622991945f4SDevajith Valaparambil Sreeramaswamy Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1623991945f4SDevajith Valaparambil Sreeramaswamy rewriter, loc, output, newOutputType); 1624991945f4SDevajith Valaparambil Sreeramaswamy 1625991945f4SDevajith Valaparambil Sreeramaswamy auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType, 1626991945f4SDevajith Valaparambil Sreeramaswamy ValueRange{newInput, newKernel}, 1627991945f4SDevajith Valaparambil Sreeramaswamy ValueRange{newOutput}); 1628991945f4SDevajith Valaparambil Sreeramaswamy 1629991945f4SDevajith Valaparambil Sreeramaswamy // Insert back. 
1630991945f4SDevajith Valaparambil Sreeramaswamy Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1631991945f4SDevajith Valaparambil Sreeramaswamy rewriter, loc, conv1DOp.getResult(0), output); 1632991945f4SDevajith Valaparambil Sreeramaswamy rewriter.replaceOp(convOp, inserted); 1633991945f4SDevajith Valaparambil Sreeramaswamy 1634991945f4SDevajith Valaparambil Sreeramaswamy return conv1DOp; 1635991945f4SDevajith Valaparambil Sreeramaswamy } 1636991945f4SDevajith Valaparambil Sreeramaswamy 1637ad1efb51SNicolas Vasilache void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, 16387b615a87SLei Zhang PatternBenefit benefit) { 16398e484b52SStanley Winata patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, 16408e484b52SStanley Winata Conv1DNwcWcfOp>, 16418e484b52SStanley Winata DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, 16428e484b52SStanley Winata Conv1DNcwFcwOp>, 1643991945f4SDevajith Valaparambil Sreeramaswamy DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>( 1644991945f4SDevajith Valaparambil Sreeramaswamy patterns.getContext(), benefit); 16451a151fdcSMurali Vijayaraghavan patterns.add< 16461a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>, 16471a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>, 16481a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>, 16491a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp, 16501a151fdcSMurali Vijayaraghavan PoolingNwcMaxUnsignedOp>, 16511a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>, 16521a151fdcSMurali Vijayaraghavan DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp, 16531a151fdcSMurali Vijayaraghavan PoolingNwcMinUnsignedOp>, 16541a151fdcSMurali Vijayaraghavan 
DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>( 16551a151fdcSMurali Vijayaraghavan patterns.getContext(), benefit); 16567b615a87SLei Zhang } 165763b926afSAndrzej Warzyński 165807750882SAndrzej Warzyński void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { 165907750882SAndrzej Warzyński patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext()); 1660*25825682SAndrzej Warzyński patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext()); 166163b926afSAndrzej Warzyński } 16621b2c8f10SAndrzej Warzyński 16631b2c8f10SAndrzej Warzyński void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { 16641b2c8f10SAndrzej Warzyński patterns.add<DecomposePadOpPattern>(patterns.getContext()); 16651b2c8f10SAndrzej Warzyński } 1666