//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also apply affine_min/max bounds simplification on the fly where relevant.
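/// Illustrative sketch (conceptual IR; bounds chosen for exposition): peeling
///   scf.for %iv = %c0 to %c10 step %c4 { ... }
/// yields a main loop whose trip count now divides evenly, plus a partial
/// last iteration:
///   scf.for %iv = %c0 to %c8 step %c4 { ... }
///   scf.for %iv = %c8 to %c10 step %c4 { ... }
/// after which affine.min/max expressions involving %iv are simplified in
/// both loops.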
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has 0 or 1 result that is a function of
/// AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}
#endif // NDEBUG

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), or std::nullopt if there is none.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
///   2. Appending a `newDim` to the domain of every indexing map.
///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
///      by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` in place as one step of the
/// update. Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into a higher-dimensional access is not possible for a general
/// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// And would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
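///
/// Illustrative sketch (maps chosen for exposition): packing `dim` = 1 (the n
/// dimension) of a matmul whose indexing maps are
///   (d0, d1, d2) -> (d0, d2)   // A
///   (d0, d1, d2) -> (d2, d1)   // B
///   (d0, d1, d2) -> (d0, d1)   // C
/// appends a new iterator d3 (parallel, like d1) and yields
///   (d0, d1, d2, d3) -> (d0, d2)       // A: no result uses d1, unchanged
///   (d0, d1, d2, d3) -> (d2, d1, d3)   // B: packed
///   (d0, d1, d2, d3) -> (d0, d1, d3)   // C: packed
/// with a returned packing of {std::nullopt, 1, 1}.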
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to the map's domain in all cases.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` into the map,
    // which logically chunks dimension `dim` into `K * dim + newDim`, where
    // the packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr at the moment.
    if (!isa<AffineDimExpr>(
            map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that operand `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

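/// Lower tensor.pack as pad + expand_shape + transpose. Illustrative sketch
/// (conceptual IR; shapes, SSA names and exact reshape syntax are chosen for
/// exposition):
///   %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
///       into %dest : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
/// becomes, roughly:
///   %padded = tensor.pad %src ...   // no-op here since 8 and 16 divide the
///                                   // source shape evenly
///   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
///       : tensor<128x256xf32> into tensor<16x8x16x16xf32>
///   %transposed = linalg.transpose ins(%expanded) outs(%dest)
///       permutation = [0, 2, 1, 3]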
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             tensor::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
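    // The amount of high padding is `outerSize * innerSize - origSize`, i.e.
    // what is needed to round the source dimension up to the packed domain.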
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher-ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}

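/// Lower tensor.unpack as transpose + collapse_shape + extract_slice + copy.
/// Illustrative sketch (conceptual IR; shapes, SSA names and exact reshape
/// syntax are chosen for exposition):
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1]
///       inner_tiles = [8, 16] into %dest
///       : tensor<16x16x8x16xf32> -> tensor<128x256xf32>
/// becomes, roughly:
///   %transposed = linalg.transpose ins(%src) outs(%empty)
///       permutation = [0, 2, 1, 3]
///   %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
///       : tensor<16x8x16x16xf32> into tensor<128x256xf32>
///   %slice = tensor.extract_slice %collapsed[0, 0] [128, 256] [1, 1]
///   %copy = linalg.copy ins(%slice) outs(%dest)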
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher-ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get the dynamic dims from the input tensor, permuted by the
  // packedToStripMinedShapePerm permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by `packedSizes`. There must be one
/// entry in `packedSizes` per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
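///
/// Illustrative sketch (shapes and sizes chosen for exposition): packing a
/// linalg.matmul with A: tensor<128x512xf32>, B: tensor<512x256xf32> and
/// C: tensor<128x256xf32> by packedSizes = [8, 16, 32] (one per m, n, k
/// iterator) creates
///   a tensor.pack of A into tensor<16x16x8x32xf32>,
///   a tensor.pack of B into tensor<16x16x32x16xf32>,
///   a tensor.pack of C into tensor<16x16x8x16xf32>,
/// a linalg.generic on the packed operands with iterators
/// (par, par, red, par, par, red), and a tensor.unpack of the packed result
/// back to tensor<128x256xf32>.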
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(
          DBGS() << "operand: " << operand << "\n";
          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
          DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder,
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
///   - the corresponding indexing map is transposed by `permutation`
///   - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected matching tensor types");

  // Compute the transposed indexing map.
  // AffineMap::getPermutationMap takes unsigned indices, hence the copy.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

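/// Illustrative sketch (permutations chosen for exposition): given a
/// pack/compute/unpack chain and innerPerm = [1, 0], packTranspose replaces
/// the tensor.pack by a transposed clone, composes the matching indexing map
/// of `linalgOp` with the corresponding full-rank permutation and, when
/// `maybeUnPackOp` is provided, creates the matching transposed tensor.unpack
/// so that the chain as a whole computes the same result.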
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible forms. The outer
/// dimensions of the operands are not permuted at this time; this is left
/// for future work.
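///
/// Illustrative sketch (values chosen for exposition): with
/// mnkPackedSizes = {8, 16, 32} and the identity mnkOrder = {0, 1, 2}, the
/// inferred m, n and k iterators are interchanged into the three most-minor
/// positions and packed by 8, 16 and 32 respectively; a non-empty
/// mnkPaddedSizesNextMultipleOf requests that each of the m, n and k sizes
/// additionally be padded to the given next multiple (0 meaning no such
/// padding) before packing.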
7674d74c845SLorenzo Chelini FailureOr<PackResult>
7684d74c845SLorenzo Chelini linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
7694d74c845SLorenzo Chelini                            ArrayRef<OpFoldResult> mnkPackedSizes,
7704d74c845SLorenzo Chelini                            ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
7714d74c845SLorenzo Chelini                            ArrayRef<int64_t> mnkOrder) {
7724d74c845SLorenzo Chelini   assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
7734d74c845SLorenzo Chelini   assert((mnkPaddedSizesNextMultipleOf.empty() ||
7744d74c845SLorenzo Chelini           mnkPaddedSizesNextMultipleOf.size() == 3) &&
7754d74c845SLorenzo Chelini          "num of packing sizes next multiple should be empty or of size 3");
7764d74c845SLorenzo Chelini   assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
7774d74c845SLorenzo Chelini   assert(isPermutationVector(mnkOrder) && "expected a permutation");
7784d74c845SLorenzo Chelini 
7794d74c845SLorenzo Chelini   int64_t numLoops = linalgOp.getNumLoops();
7804d74c845SLorenzo Chelini   if (numLoops <= 2) {
7814d74c845SLorenzo Chelini     LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
7824d74c845SLorenzo Chelini                       << numLoops << "\nin: " << linalgOp << "\n");
7834d74c845SLorenzo Chelini     return rewriter.notifyMatchFailure(
7844d74c845SLorenzo Chelini         linalgOp, "need 3+ loops to find a matmul to pack");
7854d74c845SLorenzo Chelini   }
7864d74c845SLorenzo Chelini 
7874d74c845SLorenzo Chelini   // Locally adjust the desired iterator position of mnk and packing sizes.
7884d74c845SLorenzo Chelini   int64_t numPackedDims = mnkPackedSizes.size();
7894d74c845SLorenzo Chelini   SmallVector<int64_t> mmnnkkPos(numPackedDims);
7904d74c845SLorenzo Chelini   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
7914d74c845SLorenzo Chelini     mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
7924d74c845SLorenzo Chelini   SmallVector<OpFoldResult> packedSizes(numPackedDims);
7934d74c845SLorenzo Chelini   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
7944d74c845SLorenzo Chelini     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
7954d74c845SLorenzo Chelini   SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
7964d74c845SLorenzo Chelini   for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
7974d74c845SLorenzo Chelini     paddedSizesNextMultipleOf[mnkOrder[i]] =
7984d74c845SLorenzo Chelini         mnkPaddedSizesNextMultipleOf.empty() ? 0
7994d74c845SLorenzo Chelini                                              : mnkPaddedSizesNextMultipleOf[i];
8004d74c845SLorenzo Chelini   }
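  // For illustration (values assumed, not taken from any caller): with
  // numLoops = 4 and mnkOrder = {1, 2, 0}, mmnnkkPos becomes {2, 3, 1} and
  // packedSizes = {k-size, m-size, n-size}, i.e. the three trailing
  // iterators end up ordered as (k, m, n).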
8014d74c845SLorenzo Chelini 
8024d74c845SLorenzo Chelini   // 1. Infer dims that are important for matmul.
8034d74c845SLorenzo Chelini   FailureOr<ContractionDimensions> maybeDimensions =
8044d74c845SLorenzo Chelini       inferContractionDims(linalgOp);
8054d74c845SLorenzo Chelini   if (failed(maybeDimensions)) {
8064d74c845SLorenzo Chelini     LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
8074d74c845SLorenzo Chelini                       << "\n");
8084d74c845SLorenzo Chelini     return rewriter.notifyMatchFailure(linalgOp,
8094d74c845SLorenzo Chelini                                        "couldn't infer matmul iterators");
8104d74c845SLorenzo Chelini   }
8114d74c845SLorenzo Chelini 
8124d74c845SLorenzo Chelini   // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] as
8134d74c845SLorenzo Chelini   // the most minor iterators. In cases with multiple options for m, n, k,
8144d74c845SLorenzo Chelini   // bias towards the most minor embedding.
8154d74c845SLorenzo Chelini   // If we wanted a different normalization order, this is where a heuristic
8164d74c845SLorenzo Chelini   // would have to plug in.
8174d74c845SLorenzo Chelini   int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
8184d74c845SLorenzo Chelini           kPos = maybeDimensions->k.back();
8194d74c845SLorenzo Chelini   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
8204d74c845SLorenzo Chelini              DBGS() << "Start packing generic op greedily with (m@" << mPos
8214d74c845SLorenzo Chelini                     << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
8224d74c845SLorenzo Chelini                     << "\n";);
8234d74c845SLorenzo Chelini 
8244d74c845SLorenzo Chelini   // 2.a. Rewrite as a generic.
8254d74c845SLorenzo Chelini   auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
8264d74c845SLorenzo Chelini   if (!genericOp) {
8274d74c845SLorenzo Chelini     FailureOr<GenericOp> generalizeResult =
8284d74c845SLorenzo Chelini         generalizeNamedOp(rewriter, linalgOp);
8294d74c845SLorenzo Chelini     assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
8304d74c845SLorenzo Chelini     genericOp = *generalizeResult;
8314d74c845SLorenzo Chelini   }
8324d74c845SLorenzo Chelini 
8334d74c845SLorenzo Chelini   // 2.b. Interchange to move the dimensions (k, m, n) into the most-minor
8344d74c845SLorenzo Chelini   // positions. Note that this only normalizes the iteration order and does
8354d74c845SLorenzo Chelini   // not change the indexing of any operand.
8364d74c845SLorenzo Chelini   SmallVector<int64_t> permutation =
8374d74c845SLorenzo Chelini       computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
8384d74c845SLorenzo Chelini   LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
8394d74c845SLorenzo Chelini   // Convert from signed to unsigned; interchangeGenericOp takes unsigned.
8404d74c845SLorenzo Chelini   SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
8414d74c845SLorenzo Chelini   FailureOr<GenericOp> interchangeResult =
8424d74c845SLorenzo Chelini       interchangeGenericOp(rewriter, genericOp, unsignedPerm);
8434d74c845SLorenzo Chelini   assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
8444d74c845SLorenzo Chelini   genericOp = *interchangeResult;
8454d74c845SLorenzo Chelini   LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
8464d74c845SLorenzo Chelini 
8474d74c845SLorenzo Chelini   // At this point, the op iterators are normalized to {leading, k, m, n}.
8484d74c845SLorenzo Chelini   // The layouts induced by packing will always be:
8494d74c845SLorenzo Chelini   //   - LHS{leading_lhs, kk, mm}
8504d74c845SLorenzo Chelini   //   - RHS{leading_rhs, kk, nn}
8514d74c845SLorenzo Chelini   //   - RES{leading_res, mm, nn}
8524d74c845SLorenzo Chelini   // If we wanted to change the packed order, we would reorder (k, m, n) to
8534d74c845SLorenzo Chelini   // something else above.
8544d74c845SLorenzo Chelini   //
8554d74c845SLorenzo Chelini   // Additional permutations of the outer dims of the operands (i.e.
8564d74c845SLorenzo Chelini   // leading_lhs, leading_rhs and leading_res) could follow by computing the
8574d74c845SLorenzo Chelini   // desired outerPerm for each operand.
8584d74c845SLorenzo Chelini   // This is left for future work.
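  //
  // For illustration (shapes assumed): with mm = 8, nn = 16, kk = 32, an LHS
  // of type tensor<128x256xf32> (M = 128, K = 256) packs to
  // tensor<16x8x32x8xf32>, i.e. outer dims {M/8, K/32} followed by the inner
  // {kk, mm} tile.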
8594d74c845SLorenzo Chelini 
8604d74c845SLorenzo Chelini   // TODO: this creates too much IR, go use reifyResultShapes.
8614d74c845SLorenzo Chelini   SmallVector<Range, 4> loopRanges =
8624d74c845SLorenzo Chelini       cast<LinalgOp>(genericOp.getOperation())
8634d74c845SLorenzo Chelini           .createLoopRanges(rewriter, genericOp.getLoc());
8644d74c845SLorenzo Chelini 
8654d74c845SLorenzo Chelini   // Add leading zeros to match numLoops; we only pack the last 3 dimensions
8664d74c845SLorenzo Chelini   // post interchange.
8674d74c845SLorenzo Chelini   LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
8684d74c845SLorenzo Chelini                                    DBGS() << "paddedSizesNextMultipleOf: ");
8694d74c845SLorenzo Chelini              DBGSNL(););
8704d74c845SLorenzo Chelini   LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
8714d74c845SLorenzo Chelini                                    [](Range r) { llvm::dbgs() << r.size; });
8724d74c845SLorenzo Chelini              DBGSNL(););
8734d74c845SLorenzo Chelini   SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
8744d74c845SLorenzo Chelini                                                 rewriter.getIndexAttr(0));
8754d74c845SLorenzo Chelini   for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
8764d74c845SLorenzo Chelini     if (paddedSizesNextMultipleOf[i] == 0) {
8774d74c845SLorenzo Chelini       adjustedPackedSizes.push_back(packedSizes[i]);
8784d74c845SLorenzo Chelini       continue;
8794d74c845SLorenzo Chelini     }
8804d74c845SLorenzo Chelini     AffineExpr d0, s0;
8814d74c845SLorenzo Chelini     bindDims(rewriter.getContext(), d0);
8824d74c845SLorenzo Chelini     bindSymbols(rewriter.getContext(), s0);
8834d74c845SLorenzo Chelini     adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
8844d74c845SLorenzo Chelini         rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
8854d74c845SLorenzo Chelini         {loopRanges[adjustedPackedSizes.size()].size,
8864d74c845SLorenzo Chelini          rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
8874d74c845SLorenzo Chelini   }
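  // For example (numbers assumed): a loop range of 60 with
  // paddedSizesNextMultipleOf = 16 yields ceilDiv(60, 16) * 16 = 64.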
8884d74c845SLorenzo Chelini   LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
8894d74c845SLorenzo Chelini                                    DBGS() << "adjustedPackedSizes: ");
8904d74c845SLorenzo Chelini              DBGSNL(););
8914d74c845SLorenzo Chelini 
8924d74c845SLorenzo Chelini   // TODO: If we wanted to give the genericOp a name after packing, after
8934d74c845SLorenzo Chelini   // calling `pack` would be a good time. One would still need to check that
8944d74c845SLorenzo Chelini   // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
8954d74c845SLorenzo Chelini   // also allow degenerate matmul cases (i.e. matvec, dot).
8964d74c845SLorenzo Chelini   return pack(rewriter, genericOp, adjustedPackedSizes);
8974d74c845SLorenzo Chelini }
8984d74c845SLorenzo Chelini 
8994d74c845SLorenzo Chelini //===----------------------------------------------------------------------===//
900bb2ae985SNicolas Vasilache // Transformations exposed as rewrite patterns.
901bb2ae985SNicolas Vasilache //===----------------------------------------------------------------------===//
902bb2ae985SNicolas Vasilache 
903bb2ae985SNicolas Vasilache LinalgTilingOptions &
904bb2ae985SNicolas Vasilache mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
905bb2ae985SNicolas Vasilache   assert(!tileSizeComputationFunction && "tile sizes already set");
9065262865aSKazu Hirata   SmallVector<int64_t, 4> tileSizes(ts);
907bb2ae985SNicolas Vasilache   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
908bb2ae985SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
909bb2ae985SNicolas Vasilache     b.setInsertionPointToStart(
910bb2ae985SNicolas Vasilache         &op->getParentOfType<func::FuncOp>().getBody().front());
911bb2ae985SNicolas Vasilache     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
912bb2ae985SNicolas Vasilache       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
913bb2ae985SNicolas Vasilache       return v;
914bb2ae985SNicolas Vasilache     }));
915bb2ae985SNicolas Vasilache   };
916bb2ae985SNicolas Vasilache   return *this;
917bb2ae985SNicolas Vasilache }
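// A minimal usage sketch (tile sizes are illustrative); by convention a tile
// size of 0 leaves the corresponding loop untiled:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0});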
918bb2ae985SNicolas Vasilache 
919ebc81537SAlexander Belyaev LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
920ebc81537SAlexander Belyaev     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
921ebc81537SAlexander Belyaev   return vectorizeCopy(rewriter, copyOp);
922ebc81537SAlexander Belyaev }
923ebc81537SAlexander Belyaev 
92435df2f6fSYi Zhang /// Fill `dest` using a FillOp with the constant padding value if possible.
92535df2f6fSYi Zhang /// Otherwise, generate a tensor::GenerateOp.
9261b2c8f10SAndrzej Warzyński Value DecomposePadOpPattern::createFillOrGenerateOp(
9271cff4cbdSNicolas Vasilache     RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
92835df2f6fSYi Zhang     const SmallVector<Value> &dynSizes) const {
92935df2f6fSYi Zhang   auto padValue = padOp.getConstantPaddingValue();
93035df2f6fSYi Zhang   if (padValue)
93135df2f6fSYi Zhang     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
93235df2f6fSYi Zhang 
93335df2f6fSYi Zhang   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
93435df2f6fSYi Zhang   auto generateOp = rewriter.create<tensor::GenerateOp>(
93535df2f6fSYi Zhang       padOp.getLoc(), padOp.getResultType(), dynSizes);
93635df2f6fSYi Zhang   // Copy region to new op.
9374d67b278SJeff Niu   IRMapping bvm;
93804235d07SJacques Pienaar   padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
93935df2f6fSYi Zhang   return generateOp;
94035df2f6fSYi Zhang }
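// Sketch of the two lowerings above (IR is illustrative). With a constant
// padding value the fill is:
//   %0 = linalg.fill ins(%cst : f32) outs(%dest : tensor<8x8xf32>)
//          -> tensor<8x8xf32>
// Otherwise the pad region is cloned into:
//   %0 = tensor.generate %d0, %d1 {
//   ^bb0(%i: index, %j: index):
//     ...
//     tensor.yield %v : f32
//   } : tensor<?x?xf32>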
94135df2f6fSYi Zhang 
94235df2f6fSYi Zhang LogicalResult
9431b2c8f10SAndrzej Warzyński DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
94435df2f6fSYi Zhang                                        PatternRewriter &rewriter) const {
94535df2f6fSYi Zhang   // Given an OpFoldResult, return an index-typed value.
94635df2f6fSYi Zhang   auto getIdxValue = [&](OpFoldResult ofr) {
94768f58812STres Popp     if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
94835df2f6fSYi Zhang       return val;
94935df2f6fSYi Zhang     return rewriter
950a54f4eaeSMogball         .create<arith::ConstantIndexOp>(
9514f279a57SKazu Hirata             padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
95235df2f6fSYi Zhang         .getResult();
95335df2f6fSYi Zhang   };
95435df2f6fSYi Zhang 
95535df2f6fSYi Zhang   auto resultType = padOp.getResultType();
95681ca5aa4SMatthias Springer   // Compute size of EmptyOp. Any combination of static/dynamic is supported.
95735df2f6fSYi Zhang   SmallVector<Value> dynSizes;
95835df2f6fSYi Zhang   SmallVector<int64_t> staticSizes;
95935df2f6fSYi Zhang   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
96035df2f6fSYi Zhang     if (resultType.isDynamicDim(dim)) {
9616596b0ddSMatthias Springer       auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
9626596b0ddSMatthias Springer                                                       padOp.getSource(), dim));
96335df2f6fSYi Zhang       // Add low and high padding value.
964a54f4eaeSMogball       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
96535df2f6fSYi Zhang           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
966a54f4eaeSMogball       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
96735df2f6fSYi Zhang           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
96835df2f6fSYi Zhang       dynSizes.push_back(plusHigh);
96935df2f6fSYi Zhang     }
97035df2f6fSYi Zhang     staticSizes.push_back(resultType.getDimSize(dim));
97135df2f6fSYi Zhang   }
97235df2f6fSYi Zhang 
97335df2f6fSYi Zhang   // Init tensor and fill it with padding.
97481ca5aa4SMatthias Springer   Value emptyTensor = rewriter.create<tensor::EmptyOp>(
97581ca5aa4SMatthias Springer       padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
97681ca5aa4SMatthias Springer   Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
97735df2f6fSYi Zhang 
97839ad84e4SAndrzej Warzyński   // Generate an InsertSliceOp for copying the PadOp source.
97935df2f6fSYi Zhang   auto sourceType = padOp.getSourceType();
980fd0c6f53SAlexander Belyaev   // Compute size of source of tensor::PadOp.
9816596b0ddSMatthias Springer   SmallVector<OpFoldResult> srcSizes =
9826596b0ddSMatthias Springer       tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
98335df2f6fSYi Zhang   // Strides of InsertSliceOp are all 1.
98435df2f6fSYi Zhang   SmallVector<OpFoldResult> strides(sourceType.getRank(),
98535df2f6fSYi Zhang                                     rewriter.getIndexAttr(1));
98635df2f6fSYi Zhang   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
98704235d07SJacques Pienaar       padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
98804235d07SJacques Pienaar       strides);
98935df2f6fSYi Zhang 
99035df2f6fSYi Zhang   return success();
99135df2f6fSYi Zhang }
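// For illustration (shapes assumed), the rewrite above turns
//   %p = tensor.pad %src low[3, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<5x8xf32> to tensor<8x8xf32>
// into
//   %e = tensor.empty() : tensor<8x8xf32>
//   %f = linalg.fill ins(%cst : f32) outs(%e : tensor<8x8xf32>)
//          -> tensor<8x8xf32>
//   %r = tensor.insert_slice %src into %f[3, 0] [5, 8] [1, 1]
//          : tensor<5x8xf32> into tensor<8x8xf32>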
99235df2f6fSYi Zhang 
993060208b4SMatthias Springer LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
994060208b4SMatthias Springer     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
995060208b4SMatthias Springer   if (!sliceOp.hasUnitStride())
99624199f53SMatthias Springer     return failure();
99724199f53SMatthias Springer 
99804235d07SJacques Pienaar   auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
9990edb4127SLei Zhang   if (!padOp)
10000edb4127SLei Zhang     return failure();
10010edb4127SLei Zhang 
10020edb4127SLei Zhang   bool zeroSliceGuard = true;
10030edb4127SLei Zhang   if (controlFn) {
100422426110SRamkumar Ramachandra     if (std::optional<bool> control = controlFn(sliceOp))
10056d5fc1e3SKazu Hirata       zeroSliceGuard = *control;
10060edb4127SLei Zhang     else
10070edb4127SLei Zhang       return failure();
10080edb4127SLei Zhang   }
10090edb4127SLei Zhang 
1010809e3d8cSMahesh Ravishankar   FailureOr<TilingResult> tilingResult =
10110edb4127SLei Zhang       tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
10120edb4127SLei Zhang                                sliceOp.getMixedSizes(), zeroSliceGuard);
1013809e3d8cSMahesh Ravishankar   if (failed(tilingResult))
1014809e3d8cSMahesh Ravishankar     return failure();
101524199f53SMatthias Springer   // The slice was successfully bubbled up through the pad. Rewrite
10160edb4127SLei Zhang   // extract_slice(pad(x)) into pad(extract_slice(x)) using the tiled values.
1017809e3d8cSMahesh Ravishankar   rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
101824199f53SMatthias Springer   return success();
101924199f53SMatthias Springer }
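// For illustration (shapes assumed): given
//   %p = tensor.pad %x low[0] high[5] ... : tensor<10xf32> to tensor<15xf32>
//   %s = tensor.extract_slice %p[6] [8] [1] : tensor<15xf32> to tensor<8xf32>
// the slice bubbles up to
//   %s0 = tensor.extract_slice %x[6] [4] [1] : tensor<10xf32> to tensor<4xf32>
//   %r  = tensor.pad %s0 low[0] high[4] ... : tensor<4xf32> to tensor<8xf32>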
10207b615a87SLei Zhang 
102166f84c8bSAndrzej Warzyński /// If padding value is set, returns a tensor.pad Op for the source tensor,
102266f84c8bSAndrzej Warzyński /// with the output shape matching the output of `packOp`. Otherwise, returns
102366f84c8bSAndrzej Warzyński /// the source directly.
102466f84c8bSAndrzej Warzyński ///
102566f84c8bSAndrzej Warzyński /// This method assumes that all outer dims for this pack Op are 1.
1026644f0f83SHanhan Wang static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1027644f0f83SHanhan Wang                                            tensor::PackOp packOp) {
1028644f0f83SHanhan Wang   Value input = packOp.getSource();
1029644f0f83SHanhan Wang   if (!packOp.getPaddingValue()) {
1030644f0f83SHanhan Wang     return input;
1031644f0f83SHanhan Wang   }
1032644f0f83SHanhan Wang 
1033c1826aeeSAndrzej Warzyński   assert(llvm::all_of(packOp.getAllOuterDims(),
1034c1826aeeSAndrzej Warzyński                       [](int64_t val) { return val == 1; }) &&
1035c1826aeeSAndrzej Warzyński          "some outer dims are != 1");
1036c1826aeeSAndrzej Warzyński 
1037644f0f83SHanhan Wang   Location loc = packOp.getLoc();
1038644f0f83SHanhan Wang   ShapedType inputType = packOp.getSourceType();
1039644f0f83SHanhan Wang   int64_t inputRank = inputType.getRank();
1040644f0f83SHanhan Wang 
1041644f0f83SHanhan Wang   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1042644f0f83SHanhan Wang       packOp.getDimAndTileMapping();
104366f84c8bSAndrzej Warzyński 
104466f84c8bSAndrzej Warzyński   // The sizes of dynamic tiles.
104566f84c8bSAndrzej Warzyński   SmallVector<Value> dynamicTileSizes;
104666f84c8bSAndrzej Warzyński 
104766f84c8bSAndrzej Warzyński   // Collect dims for the padded shape.
104866f84c8bSAndrzej Warzyński   SmallVector<int64_t> paddedShape;
104966f84c8bSAndrzej Warzyński   for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
105066f84c8bSAndrzej Warzyński     // 1. Non-tiled outer dims.
105166f84c8bSAndrzej Warzyński     // These dims should be 1 and we simply preserve them.
105266f84c8bSAndrzej Warzyński     if (!tileAndPosMapping.count(dimIdx)) {
105366f84c8bSAndrzej Warzyński       int64_t inputDimSize = inputType.getDimSize(dimIdx);
105466f84c8bSAndrzej Warzyński       assert(inputDimSize == 1 &&
105566f84c8bSAndrzej Warzyński              "with all outer dims == 1, this non-tiled input dim should be 1!");
105666f84c8bSAndrzej Warzyński       paddedShape.push_back(inputDimSize);
1057644f0f83SHanhan Wang       continue;
1058644f0f83SHanhan Wang     }
1059644f0f83SHanhan Wang 
106066f84c8bSAndrzej Warzyński     // 2. Tiled outer dims
106166f84c8bSAndrzej Warzyński     // As all outer dims == 1, it is safe to use the tile size for the padded
106266f84c8bSAndrzej Warzyński     // shape.
106366f84c8bSAndrzej Warzyński     OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
106466f84c8bSAndrzej Warzyński 
106566f84c8bSAndrzej Warzyński     // 2.1 Static tile sizes
106666f84c8bSAndrzej Warzyński     std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
106766f84c8bSAndrzej Warzyński     if (cstTileSize.has_value()) {
106866f84c8bSAndrzej Warzyński       paddedShape.push_back(cstTileSize.value());
106966f84c8bSAndrzej Warzyński       continue;
107066f84c8bSAndrzej Warzyński     }
107166f84c8bSAndrzej Warzyński 
107266f84c8bSAndrzej Warzyński     // 2.2 Dynamic tile sizes
107366f84c8bSAndrzej Warzyński     paddedShape.push_back(ShapedType::kDynamic);
107466f84c8bSAndrzej Warzyński 
107566f84c8bSAndrzej Warzyński     // Get the value that holds the dynamic size.
107666f84c8bSAndrzej Warzyński     dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1077644f0f83SHanhan Wang   }
1078644f0f83SHanhan Wang   auto resultType =
1079644f0f83SHanhan Wang       RankedTensorType::get(paddedShape, inputType.getElementType());
1080644f0f83SHanhan Wang   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
108166f84c8bSAndrzej Warzyński                                  /*nofold=*/false, loc, builder,
108266f84c8bSAndrzej Warzyński                                  dynamicTileSizes);
1083644f0f83SHanhan Wang }
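// For illustration (shapes assumed): for a pack with all outer dims 1, e.g.
//   tensor.pack %src padding_value(%cst : f32) inner_dims_pos = [1]
//       inner_tiles = [8] into %dest : tensor<1x5xf32> -> tensor<1x1x8xf32>
// the helper above pads %src high to tensor<1x8xf32> before packing.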
1084644f0f83SHanhan Wang 
1085009c053eSQuinn Dawkins // Normalizes a permutation on a higher rank space to its actual size, e.g.
1086009c053eSQuinn Dawkins //   perm = [1, 4, 2]
1087009c053eSQuinn Dawkins // becomes
1088009c053eSQuinn Dawkins //   norm = [0, 2, 1]
10893ebc6beeSHanhan Wang static SmallVector<int64_t>
1090009c053eSQuinn Dawkins getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
10913ebc6beeSHanhan Wang   constexpr int64_t kNonTiledMarker = -1;
10923ebc6beeSHanhan Wang   SmallVector<int64_t> vec(rank, kNonTiledMarker);
1093009c053eSQuinn Dawkins   for (auto [index, value] : llvm::enumerate(perm))
10943ebc6beeSHanhan Wang     vec[value] = index;
1095f4d75863SJakub Kuderski   SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1096f4d75863SJakub Kuderski       vec, [&](int64_t v) { return v != kNonTiledMarker; });
1097009c053eSQuinn Dawkins   // This inverts the permutation in addition to normalizing, so invert back.
1098009c053eSQuinn Dawkins   return invertPermutationVector(normalizedPerm);
1099009c053eSQuinn Dawkins }
1100009c053eSQuinn Dawkins 
1101009c053eSQuinn Dawkins // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1102009c053eSQuinn Dawkins // assuming rank reduction of unit outer dims.
1103009c053eSQuinn Dawkins static SmallVector<int64_t>
1104009c053eSQuinn Dawkins getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1105009c053eSQuinn Dawkins                              ArrayRef<int64_t> innerDimsPos,
1106009c053eSQuinn Dawkins                              ArrayRef<int64_t> outerDimsPerm) {
1107009c053eSQuinn Dawkins   SmallVector<int64_t> rankReducedOuterDimsPerm;
1108009c053eSQuinn Dawkins   SmallVector<int64_t> outerDims;
1109009c053eSQuinn Dawkins   SmallVector<int64_t> innerDims;
1110009c053eSQuinn Dawkins   int64_t dim = 0;
1111009c053eSQuinn Dawkins   int64_t unpackedRank = shape.size();
1112009c053eSQuinn Dawkins   for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1113009c053eSQuinn Dawkins     if (llvm::is_contained(innerDimsPos, i)) {
1114009c053eSQuinn Dawkins       innerDims.push_back(dim++);
1115009c053eSQuinn Dawkins       continue;
1116009c053eSQuinn Dawkins     }
1117009c053eSQuinn Dawkins     if (shape[i] == 1)
1118009c053eSQuinn Dawkins       continue;
1119009c053eSQuinn Dawkins     outerDims.push_back(dim++);
1120009c053eSQuinn Dawkins     if (!outerDimsPerm.empty())
1121009c053eSQuinn Dawkins       rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1122009c053eSQuinn Dawkins   }
1123009c053eSQuinn Dawkins 
1124009c053eSQuinn Dawkins   // Get the position of the inner dims after permutation.
1125009c053eSQuinn Dawkins   SmallVector<int64_t> innerPerm =
1126009c053eSQuinn Dawkins       getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1127009c053eSQuinn Dawkins   applyPermutationToVector<int64_t>(innerDims, innerPerm);
1128009c053eSQuinn Dawkins 
1129009c053eSQuinn Dawkins   // Ditto for the outer dims.
1130009c053eSQuinn Dawkins   SmallVector<int64_t> perm = outerDims;
1131009c053eSQuinn Dawkins 
1132009c053eSQuinn Dawkins   rankReducedOuterDimsPerm =
1133009c053eSQuinn Dawkins       getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1134009c053eSQuinn Dawkins   if (!rankReducedOuterDimsPerm.empty())
1135009c053eSQuinn Dawkins     applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1136009c053eSQuinn Dawkins 
1137009c053eSQuinn Dawkins   // The tiles always end up as the innermost dims after packing.
1138009c053eSQuinn Dawkins   perm.append(innerDims);
1139009c053eSQuinn Dawkins 
11403ebc6beeSHanhan Wang   return perm;
11413ebc6beeSHanhan Wang }
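// For illustration (values assumed): with shape = [1, 8, 4],
// innerDimsPos = [2, 1] and an empty outerDimsPerm, the unit dim is dropped,
// the two inner dims get rank-reduced positions [0, 1], and the inverted
// inner permutation [1, 0] yields perm = [1, 0].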
11423ebc6beeSHanhan Wang 
114307750882SAndrzej Warzyński LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1144644f0f83SHanhan Wang     tensor::PackOp packOp, PatternRewriter &rewriter) const {
1145009c053eSQuinn Dawkins   // TODO: support the case where the outer dimensions are not all 1s. A
1146009c053eSQuinn Dawkins   // tensor.expand_shape will be generated in this case.
1147e9bafa35SAndrzej Warzyński   if (llvm::any_of(packOp.getAllOuterDims(),
1148c1826aeeSAndrzej Warzyński                    [](int64_t dim) { return dim != 1; })) {
1149009c053eSQuinn Dawkins     return rewriter.notifyMatchFailure(
1150e9bafa35SAndrzej Warzyński         packOp, "not all outer dimensions of the result are 1s");
1151009c053eSQuinn Dawkins   }
1152009c053eSQuinn Dawkins 
1153e9bafa35SAndrzej Warzyński   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1154e9bafa35SAndrzej Warzyński   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1155644f0f83SHanhan Wang   Location loc = packOp.getLoc();
1156e9bafa35SAndrzej Warzyński 
1157009c053eSQuinn Dawkins   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1158009c053eSQuinn Dawkins   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1159009c053eSQuinn Dawkins       packOp.getDimAndTileMapping();
1160c1826aeeSAndrzej Warzyński   int64_t srcRank = packOp.getSourceRank();
1161e9bafa35SAndrzej Warzyński   int64_t destRank = packOp.getDestRank();
11627ebfbf9cSAndrzej Warzyński   int64_t numTiles = destRank - srcRank;
1163e9bafa35SAndrzej Warzyński 
11647ebfbf9cSAndrzej Warzyński   if (!llvm::all_of(packOp.getInnerDimsPos(),
11657ebfbf9cSAndrzej Warzyński                     [&srcRank, &numTiles](int64_t dimPos) {
11667ebfbf9cSAndrzej Warzyński                       return dimPos >= (srcRank - numTiles - 1);
11677ebfbf9cSAndrzej Warzyński                     }))
11687ebfbf9cSAndrzej Warzyński     return rewriter.notifyMatchFailure(
11697ebfbf9cSAndrzej Warzyński         packOp, "Attempting to tile non-trailing source dims!");
1170e9bafa35SAndrzej Warzyński 
11717ebfbf9cSAndrzej Warzyński   // 1. Extract the inner tile sizes.
11727ebfbf9cSAndrzej Warzyński   // Where possible, values are replaced with constant attributes (to match the
11737ebfbf9cSAndrzej Warzyński   // behaviour of `getPackOpSourceOrPaddedSource`).
11747ebfbf9cSAndrzej Warzyński   SmallVector<OpFoldResult> tileSizes;
1175644f0f83SHanhan Wang   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1176009c053eSQuinn Dawkins     if (dimAndTileMapping.count(i)) {
11777ebfbf9cSAndrzej Warzyński       // Rather than taking the tile size as is, extract the actual constant
11787ebfbf9cSAndrzej Warzyński       // value Attribute where possible, e.g.:
11797ebfbf9cSAndrzej Warzyński       //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
11807ebfbf9cSAndrzej Warzyński       auto [_, tileSize] =
1181e9bafa35SAndrzej Warzyński           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
11827ebfbf9cSAndrzej Warzyński       tileSizes.push_back(tileSize);
118366f84c8bSAndrzej Warzyński     }
1184009c053eSQuinn Dawkins   }
1185009c053eSQuinn Dawkins 
11867ebfbf9cSAndrzej Warzyński   // 2. Transpose the input to match the inner tile order:
1187e9bafa35SAndrzej Warzyński   //    %init = tensor.empty()
11887ebfbf9cSAndrzej Warzyński   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11897ebfbf9cSAndrzej Warzyński   //                                        outs(%init)
11907ebfbf9cSAndrzej Warzyński   // Two assumptions are made:
11917ebfbf9cSAndrzej Warzyński   //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
11927ebfbf9cSAndrzej Warzyński   //  2. Inner dim positions correspond to the trailing `numTiles` dims.
11937ebfbf9cSAndrzej Warzyński   SmallVector<int64_t> tilesPermNormalized =
11947ebfbf9cSAndrzej Warzyński       getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
11957ebfbf9cSAndrzej Warzyński   SmallVector<int64_t> srcPermForTranspose;
11967ebfbf9cSAndrzej Warzyński   for (int64_t i = 0; i < (srcRank - numTiles); i++)
11977ebfbf9cSAndrzej Warzyński     srcPermForTranspose.push_back(i);
11987ebfbf9cSAndrzej Warzyński 
11997ebfbf9cSAndrzej Warzyński   srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1200bbf1d80dSQuinn Dawkins 
1201bbf1d80dSQuinn Dawkins   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
12027ebfbf9cSAndrzej Warzyński              llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
12037ebfbf9cSAndrzej Warzyński              DBGSNL(););
1204bbf1d80dSQuinn Dawkins 
1205e9bafa35SAndrzej Warzyński   // 2.1 Create tensor.empty (init value for TransposeOp)
12067ebfbf9cSAndrzej Warzyński   SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
12077ebfbf9cSAndrzej Warzyński                                                  oneIdxAttr);
12087ebfbf9cSAndrzej Warzyński   transShapeForEmptyOp.append(tileSizes);
1209644f0f83SHanhan Wang 
12107ebfbf9cSAndrzej Warzyński   applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
12117ebfbf9cSAndrzej Warzyński                                          srcPermForTranspose);
12127ebfbf9cSAndrzej Warzyński   Value empty = rewriter.create<tensor::EmptyOp>(
12137ebfbf9cSAndrzej Warzyński       loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1214e9bafa35SAndrzej Warzyński 
1215e9bafa35SAndrzej Warzyński   // 2.2 Create linalg.transpose
12167ebfbf9cSAndrzej Warzyński   auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
12177ebfbf9cSAndrzej Warzyński                                                            srcPermForTranspose);
1218644f0f83SHanhan Wang 
1219e9bafa35SAndrzej Warzyński   // 3. Insert the inner tile into the destination:
1220e9bafa35SAndrzej Warzyński   //  %inserted_tile = tensor.insert_slice(%transposed_tile)
1221644f0f83SHanhan Wang   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1222644f0f83SHanhan Wang   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1223e9bafa35SAndrzej Warzyński   // Outer dims are all 1s!
1224e9bafa35SAndrzej Warzyński   SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1225e9bafa35SAndrzej Warzyński                                        oneIdxAttr);
1226e9bafa35SAndrzej Warzyński   SmallVector<int64_t> writeShape;
1227644f0f83SHanhan Wang 
1228e9bafa35SAndrzej Warzyński   for (auto tileSize : packOp.getMixedTiles()) {
1229e9bafa35SAndrzej Warzyński     auto [tileSizeStatic, tileSizeOfr] =
1230e9bafa35SAndrzej Warzyński         getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1231e9bafa35SAndrzej Warzyński     writeSizes.push_back(tileSizeOfr);
1232e9bafa35SAndrzej Warzyński     writeShape.push_back(tileSizeStatic);
1233e9bafa35SAndrzej Warzyński   }
1234e9bafa35SAndrzej Warzyński 
1235e9bafa35SAndrzej Warzyński   // 4. Replace tensor.packOp with the tensor.insert_slice created above.
1236644f0f83SHanhan Wang   auto insert = rewriter.create<tensor::InsertSliceOp>(
1237644f0f83SHanhan Wang       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1238644f0f83SHanhan Wang       writeSizes, writeStrides);
1239644f0f83SHanhan Wang   rewriter.replaceOp(packOp, insert.getResult());
1240644f0f83SHanhan Wang 
1241644f0f83SHanhan Wang   return success();
1242644f0f83SHanhan Wang }
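// For illustration (shapes assumed), the decomposition above turns
//   %p = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 8]
//       into %dest : tensor<8x16xf32> -> tensor<1x1x16x8xf32>
// into roughly
//   %e = tensor.empty() : tensor<16x8xf32>
//   %t = linalg.transpose ins(%src : tensor<8x16xf32>)
//          outs(%e : tensor<16x8xf32>) permutation = [1, 0]
//   %r = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 16, 8]
//          [1, 1, 1, 1] : tensor<16x8xf32> into tensor<1x1x16x8xf32>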
1243644f0f83SHanhan Wang 
124407750882SAndrzej Warzyński LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12453ebc6beeSHanhan Wang     tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
12463ebc6beeSHanhan Wang   int64_t srcRank = unpackOp.getSourceRank();
12473ebc6beeSHanhan Wang   int64_t destRank = unpackOp.getDestRank();
12483ebc6beeSHanhan Wang   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1249009c053eSQuinn Dawkins   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1250c1826aeeSAndrzej Warzyński   if (llvm::any_of(unpackOp.getTiledOuterDims(),
1251c1826aeeSAndrzej Warzyński                    [](int64_t dim) { return dim != 1; })) {
12523ebc6beeSHanhan Wang     return rewriter.notifyMatchFailure(
1253009c053eSQuinn Dawkins         unpackOp,
1254009c053eSQuinn Dawkins         "require the tiled outer dimensions of the result are all 1s");
12553ebc6beeSHanhan Wang   }
12563ebc6beeSHanhan Wang 
125758da789eSAndrzej Warzyński   // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
125858da789eSAndrzej Warzyński   //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
12593ebc6beeSHanhan Wang   Location loc = unpackOp.getLoc();
1260009c053eSQuinn Dawkins   Value source = unpackOp.getSource();
1261009c053eSQuinn Dawkins   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1262009c053eSQuinn Dawkins       unpackOp.getDimAndTileMapping();
12633ebc6beeSHanhan Wang   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
12643ebc6beeSHanhan Wang   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
126558da789eSAndrzej Warzyński 
126658da789eSAndrzej Warzyński   // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
126758da789eSAndrzej Warzyński   // dims:
126858da789eSAndrzej Warzyński   //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
126958da789eSAndrzej Warzyński   SmallVector<int64_t> readShapeForExtractSlice;
127058da789eSAndrzej Warzyński   // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
127158da789eSAndrzej Warzyński   // outer-tiled-dims being all 1), this will be
127258da789eSAndrzej Warzyński   //    [ outer-untiled-dims, tile-sizes ]
127358da789eSAndrzej Warzyński   SmallVector<OpFoldResult> extractSliceSizes;
127458da789eSAndrzej Warzyński   // The offset and strides attributes for ExtractSliceOp.
127558da789eSAndrzej Warzyński   SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
127658da789eSAndrzej Warzyński   SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
127758da789eSAndrzej Warzyński 
127858da789eSAndrzej Warzyński   // Shape for EmptyOp that's used as the init value for TransposeOp below.
127958da789eSAndrzej Warzyński   // This should be:
128058da789eSAndrzej Warzyński   //    [ outer-untiled-dims, tile-sizes ]
128158da789eSAndrzej Warzyński   // However, skip unit dims - TransposeOp (below) applies rank-reduced
128258da789eSAndrzej Warzyński   // permutation.
128358da789eSAndrzej Warzyński   SmallVector<OpFoldResult> shapeForEmptyOp;
128458da789eSAndrzej Warzyński 
1285009c053eSQuinn Dawkins   for (auto i : llvm::seq<unsigned>(0, destRank)) {
128658da789eSAndrzej Warzyński     // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
128758da789eSAndrzej Warzyński     //
128858da789eSAndrzej Warzyński     // As all outer tiled dims are 1, the corresponding slice size to read
128958da789eSAndrzej Warzyński     // will also be 1. As this will be a rank-reducing "extract slice" (i.e.
129058da789eSAndrzej Warzyński     // the unit dims will be "collapsed"), there's no need to
129158da789eSAndrzej Warzyński     // update:
129258da789eSAndrzej Warzyński     //  * the output shape for ExtractSliceOp, nor
129358da789eSAndrzej Warzyński     //  * the shape for EmptyOp.
1294009c053eSQuinn Dawkins     if (dimAndTileMapping.count(i)) {
129558da789eSAndrzej Warzyński       extractSliceSizes.push_back(oneIdxAttr);
1296009c053eSQuinn Dawkins       continue;
1297009c053eSQuinn Dawkins     }
12983ebc6beeSHanhan Wang 
129958da789eSAndrzej Warzyński     // Compute sizes attribute for ExtractSliceOp + EmptyOp -
130058da789eSAndrzej Warzyński     // outer-untiled-dims
1301009c053eSQuinn Dawkins     if (ShapedType::isDynamic(srcShape[i])) {
130258da789eSAndrzej Warzyński       OpFoldResult dynamicDim =
1303a44b787eSqcolombet           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
130458da789eSAndrzej Warzyński       extractSliceSizes.push_back(dynamicDim);
130558da789eSAndrzej Warzyński       shapeForEmptyOp.push_back(dynamicDim);
1306009c053eSQuinn Dawkins     } else {
130758da789eSAndrzej Warzyński       extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1308009c053eSQuinn Dawkins       if (srcShape[i] != 1)
130958da789eSAndrzej Warzyński         shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1310009c053eSQuinn Dawkins     }
131158da789eSAndrzej Warzyński     // Compute the output shape for ExtractSliceOp - outer-untiled-dims
131258da789eSAndrzej Warzyński     // (taking rank reduction into account).
131358da789eSAndrzej Warzyński     if (srcShape[i] != 1) {
131458da789eSAndrzej Warzyński       readShapeForExtractSlice.push_back(srcShape[i]);
131558da789eSAndrzej Warzyński     }
131658da789eSAndrzej Warzyński   }
131758da789eSAndrzej Warzyński   // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
131858da789eSAndrzej Warzyński   // shape for EmptyOp.
13193ebc6beeSHanhan Wang   auto mixedTiles = unpackOp.getMixedTiles();
132058da789eSAndrzej Warzyński   extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
132158da789eSAndrzej Warzyński   shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
13223ebc6beeSHanhan Wang 
13233ebc6beeSHanhan Wang   // Explicitly create the type for extract_slice op because the inner tile
13243ebc6beeSHanhan Wang   // size could be 1. We want to represent the whole inner tile in this case.
1325009c053eSQuinn Dawkins   auto tileShape = srcShape.drop_front(destRank);
1326009c053eSQuinn Dawkins   // Append the inner tile shape to the permuted and rank-reduced outer shape.
132758da789eSAndrzej Warzyński   readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
13283ebc6beeSHanhan Wang   Type elemType = unpackOp.getSourceType().getElementType();
132958da789eSAndrzej Warzyński   auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
13303ebc6beeSHanhan Wang   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
133158da789eSAndrzej Warzyński       loc, readType, unpackOp.getSource(), extractSliceOffsets,
133258da789eSAndrzej Warzyński       extractSliceSizes, extractSliceStrides);
13333ebc6beeSHanhan Wang 
13343ebc6beeSHanhan Wang   // 2. Transpose the tile to match the corresponding outer tile order.
1335009c053eSQuinn Dawkins   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1336009c053eSQuinn Dawkins       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1337009c053eSQuinn Dawkins   // Unpack is a transition out of packed space so we invert the permutation.
1338009c053eSQuinn Dawkins   // Unpack is a transition out of packed space, so we invert the permutation.
133958da789eSAndrzej Warzyński   applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
13403ebc6beeSHanhan Wang 
1341a44b787eSqcolombet   Value empty =
134258da789eSAndrzej Warzyński       rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
13433ebc6beeSHanhan Wang   auto transposedOp =
13443ebc6beeSHanhan Wang       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
13453ebc6beeSHanhan Wang 
13463ebc6beeSHanhan Wang   // 3. Handle incomplete tiles if needed. This truncates trailing data from the
13473ebc6beeSHanhan Wang   // transposed tile.
134858da789eSAndrzej Warzyński   int numLoops = shapeForEmptyOp.size();
13493ebc6beeSHanhan Wang   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
13503ebc6beeSHanhan Wang   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
13513ebc6beeSHanhan Wang   SmallVector<OpFoldResult> tileSizes;
1352009c053eSQuinn Dawkins   ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1353009c053eSQuinn Dawkins   for (auto i : llvm::seq<unsigned>(0, destRank)) {
1354009c053eSQuinn Dawkins     if (dimAndTileMapping.count(i) || destShape[i] != 1)
13556596b0ddSMatthias Springer       tileSizes.push_back(
13566596b0ddSMatthias Springer           tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1357009c053eSQuinn Dawkins   }
13583ebc6beeSHanhan Wang 
13593ebc6beeSHanhan Wang   auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
13603ebc6beeSHanhan Wang       loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
13613ebc6beeSHanhan Wang 
13623ebc6beeSHanhan Wang   // 4. Insert the result into the destination tensor.
13633ebc6beeSHanhan Wang   SmallVector<OpFoldResult> writeSizes;
13643ebc6beeSHanhan Wang   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
13653ebc6beeSHanhan Wang   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
13663ebc6beeSHanhan Wang   for (int i = 0, idx = 0; i < destRank; ++i) {
1367009c053eSQuinn Dawkins     if (dimAndTileMapping.count(i) || destShape[i] != 1)
13683ebc6beeSHanhan Wang       writeSizes.push_back(tileSizes[idx++]);
13693ebc6beeSHanhan Wang     else
13703ebc6beeSHanhan Wang       writeSizes.push_back(oneIdxAttr);
13713ebc6beeSHanhan Wang   }
13723ebc6beeSHanhan Wang   auto insert = rewriter.create<tensor::InsertSliceOp>(
13733ebc6beeSHanhan Wang       loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
13743ebc6beeSHanhan Wang       writeStrides);
13753ebc6beeSHanhan Wang   rewriter.replaceOp(unpackOp, insert.getResult());
13763ebc6beeSHanhan Wang 
13773ebc6beeSHanhan Wang   return success();
13783ebc6beeSHanhan Wang }
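// For illustration (shapes assumed), the decomposition above turns
//   %u = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 8]
//       into %dest : tensor<1x1x16x8xf32> -> tensor<8x16xf32>
// into roughly
//   %s = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 16, 8] [1, 1, 1, 1]
//          : tensor<1x1x16x8xf32> to tensor<16x8xf32>
//   %e = tensor.empty() : tensor<8x16xf32>
//   %t = linalg.transpose ins(%s : tensor<16x8xf32>)
//          outs(%e : tensor<8x16xf32>) permutation = [1, 0]
// followed by an extract_slice/insert_slice pair that drops any
// incomplete-tile data and writes the result into %dest.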
13793ebc6beeSHanhan Wang 
13807b615a87SLei Zhang // The following are patterns for downscaling convolution ops with size-1
13817b615a87SLei Zhang // window dimensions.
13827b615a87SLei Zhang //
13837b615a87SLei Zhang // Note that we'd eventually want to write such transformations in a generic
13847b615a87SLei Zhang // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
13857b615a87SLei Zhang // and then turning back to named ops. But for now it's fine to have a few
13867b615a87SLei Zhang // patterns matching special ops to get started.
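//
// For illustration (shapes assumed), a height-1 window such as
//   linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
//                             strides = dense<1> : tensor<2xi64>}
//     ins(%in, %ker : tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>)
//     outs(%out : tensor<1x1x14x8xf32>) -> tensor<1x1x14x8xf32>
// downscales to the rank-reduced
//   linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
//                           strides = dense<1> : tensor<1xi64>}
//     ins(%in1, %ker1 : tensor<1x16x4xf32>, tensor<3x4x8xf32>)
//     outs(%out1 : tensor<1x14x8xf32>) -> tensor<1x14x8xf32>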
13877b615a87SLei Zhang 
13888e484b52SStanley Winata template <typename Conv2DOp, typename Conv1DOp>
13898e484b52SStanley Winata FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
13908e484b52SStanley Winata     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
13910a8e3dd4SMatthias Springer   if (convOp.hasPureBufferSemantics())
1392ce2e198bSAlex Zinenko     return failure(); // To be implemented.
13937b615a87SLei Zhang 
1394d3b3f765SJacques Pienaar   Value input = convOp.getInputs().front();
1395d3b3f765SJacques Pienaar   Value kernel = convOp.getInputs().back();
1396d3b3f765SJacques Pienaar   Value output = convOp.getOutputs().front();
13977b615a87SLei Zhang 
13985550c821STres Popp   auto inputType = dyn_cast<RankedTensorType>(input.getType());
13995550c821STres Popp   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
14005550c821STres Popp   auto outputType = dyn_cast<RankedTensorType>(output.getType());
14017b615a87SLei Zhang 
140298dbcff1Sgysit   auto kernelShape = kernelType.getShape();
14037b615a87SLei Zhang   auto outputShape = outputType.getShape();
14047b615a87SLei Zhang 
14058e484b52SStanley Winata   // Get domain indices based on conv2D layout.
14061a151fdcSMurali Vijayaraghavan   auto [khIndex, kwIndex, ohIndex, owIndex] =
140702371c5dSNicolas Vasilache       TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
140802371c5dSNicolas Vasilache           convOp)
14098e484b52SStanley Winata           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
14101a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14118e484b52SStanley Winata           })
14128e484b52SStanley Winata           .Case([&](linalg::Conv2DNchwFchwOp op) {
14131a151fdcSMurali Vijayaraghavan             return std::make_tuple(2, 3, 2, 3);
14141a151fdcSMurali Vijayaraghavan           })
14151a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNhwcSumOp op) {
14161a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14171a151fdcSMurali Vijayaraghavan           })
14181a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNchwSumOp op) {
14191a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 2, 3);
14201a151fdcSMurali Vijayaraghavan           })
14211a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNhwcMaxOp op) {
14221a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14231a151fdcSMurali Vijayaraghavan           })
14241a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
14251a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14261a151fdcSMurali Vijayaraghavan           })
14271a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNhwcMinOp op) {
14281a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14291a151fdcSMurali Vijayaraghavan           })
14301a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
14311a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 1, 2);
14321a151fdcSMurali Vijayaraghavan           })
14331a151fdcSMurali Vijayaraghavan           .Case([&](linalg::PoolingNchwMaxOp op) {
14341a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 1, 2, 3);
14358e484b52SStanley Winata           })
14368e484b52SStanley Winata           .Default([&](Operation *op) {
14371a151fdcSMurali Vijayaraghavan             llvm_unreachable("unexpected conv2d/pool2d operation.");
14381a151fdcSMurali Vijayaraghavan             return std::make_tuple(0, 0, 0, 0);
14398e484b52SStanley Winata           });
14408e484b52SStanley Winata 
14417b615a87SLei Zhang   // Only handle the case where at least one of the window dimensions is
14427b615a87SLei Zhang   // of size 1. Other cases can rely on tiling to reduce to such cases.
14438e484b52SStanley Winata   int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
14448e484b52SStanley Winata   int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
144598dbcff1Sgysit   bool removeH = (khSize == 1 && ohSize == 1);
144698dbcff1Sgysit   bool removeW = (kwSize == 1 && owSize == 1);
1447aa373180SNicolas Vasilache   if (!removeH && !removeW)
14487b615a87SLei Zhang     return failure();
14497b615a87SLei Zhang 
14507b615a87SLei Zhang   // Get new shapes and types for all operands by removing the size-1
14517b615a87SLei Zhang   // dimension.
1452aa373180SNicolas Vasilache   using RTTBuilder = RankedTensorType::Builder;
1453789c88e8SNicolas Vasilache   RankedTensorType newInputType =
14548e484b52SStanley Winata       RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
145598dbcff1Sgysit   RankedTensorType newKernelType =
14568e484b52SStanley Winata       RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1457789c88e8SNicolas Vasilache   RankedTensorType newOutputType =
14588e484b52SStanley Winata       RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
14597b615a87SLei Zhang 
1460aa373180SNicolas Vasilache   // Rank-reduce operands.
14617b615a87SLei Zhang   Location loc = convOp.getLoc();
1462aa373180SNicolas Vasilache   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1463aa373180SNicolas Vasilache       rewriter, loc, input, newInputType);
146498dbcff1Sgysit   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
146598dbcff1Sgysit       rewriter, loc, kernel, newKernelType);
1466aa373180SNicolas Vasilache   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1467aa373180SNicolas Vasilache       rewriter, loc, output, newOutputType);
14687b615a87SLei Zhang 
1469aa373180SNicolas Vasilache   // Rank-reduce strides and dilations too.
1470aa373180SNicolas Vasilache   // TODO: dropDim 1-liner helper.
14718e484b52SStanley Winata   auto strides =
14728e484b52SStanley Winata       llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1473aa373180SNicolas Vasilache   strides.erase(strides.begin() + (removeH ? 0 : 1));
1474aa373180SNicolas Vasilache   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1475aa373180SNicolas Vasilache 
1476d3b3f765SJacques Pienaar   auto dilations =
14778e484b52SStanley Winata       llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1478aa373180SNicolas Vasilache   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1479aa373180SNicolas Vasilache   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
14807b615a87SLei Zhang 
14818e484b52SStanley Winata   auto conv1DOp = rewriter.create<Conv1DOp>(
148298dbcff1Sgysit       loc, newOutputType, ValueRange{newInput, newKernel},
14837b615a87SLei Zhang       ValueRange{newOutput}, stridesAttr, dilationsAttr);
14847b615a87SLei Zhang 
1485aa373180SNicolas Vasilache   // Insert back.
1486aa373180SNicolas Vasilache   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1487aa373180SNicolas Vasilache       rewriter, loc, conv1DOp.getResult(0), output);
1488aa373180SNicolas Vasilache   rewriter.replaceOp(convOp, inserted);
1489aa373180SNicolas Vasilache 
1490ce2e198bSAlex Zinenko   return conv1DOp;
1491ce2e198bSAlex Zinenko }
149298dbcff1Sgysit 
14932b882f84SBenjamin Kramer template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
14942b882f84SBenjamin Kramer                                                               Conv1DNwcWcfOp>;
14952b882f84SBenjamin Kramer template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
14962b882f84SBenjamin Kramer                                                               Conv1DNcwFcwOp>;
14971a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
14981a151fdcSMurali Vijayaraghavan                                                               PoolingNwcSumOp>;
14991a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
15001a151fdcSMurali Vijayaraghavan                                                               PoolingNcwSumOp>;
15011a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
15021a151fdcSMurali Vijayaraghavan                                                               PoolingNwcMaxOp>;
15031a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<
15041a151fdcSMurali Vijayaraghavan     PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
15051a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
15061a151fdcSMurali Vijayaraghavan                                                               PoolingNwcMinOp>;
15071a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<
15081a151fdcSMurali Vijayaraghavan     PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
15091a151fdcSMurali Vijayaraghavan template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
15101a151fdcSMurali Vijayaraghavan                                                               PoolingNcwMaxOp>;
15112b882f84SBenjamin Kramer 
1512ce2e198bSAlex Zinenko FailureOr<DepthwiseConv1DNwcWcOp>
1513ce2e198bSAlex Zinenko DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1514ce2e198bSAlex Zinenko     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
15150a8e3dd4SMatthias Springer   if (convOp.hasPureBufferSemantics())
1516ce2e198bSAlex Zinenko     return failure(); // To be implemented.
1517b828506eSNicolas Vasilache 
1518d3b3f765SJacques Pienaar   Value input = convOp.getInputs().front();
1519d3b3f765SJacques Pienaar   Value kernel = convOp.getInputs().back();
1520d3b3f765SJacques Pienaar   Value output = convOp.getOutputs().front();
1521b828506eSNicolas Vasilache 
15225550c821STres Popp   auto inputType = dyn_cast<RankedTensorType>(input.getType());
15235550c821STres Popp   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
15245550c821STres Popp   auto outputType = dyn_cast<RankedTensorType>(output.getType());
1525b828506eSNicolas Vasilache 
1526b828506eSNicolas Vasilache   auto kernelShape = kernelType.getShape();
1527b828506eSNicolas Vasilache   auto outputShape = outputType.getShape();
1528b828506eSNicolas Vasilache 
1529b828506eSNicolas Vasilache   // Only handle the case where at least one of the window dimensions is
1530b828506eSNicolas Vasilache   // of size 1. Other cases can rely on tiling to reduce to such cases.
1531b828506eSNicolas Vasilache   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1532b828506eSNicolas Vasilache   int64_t ohSize = outputShape[1], owSize = outputShape[2];
1533b828506eSNicolas Vasilache   bool removeH = (khSize == 1 && ohSize == 1);
1534b828506eSNicolas Vasilache   bool removeW = (kwSize == 1 && owSize == 1);
1535b828506eSNicolas Vasilache   if (!removeH && !removeW)
1536b828506eSNicolas Vasilache     return failure();
1537b828506eSNicolas Vasilache 
1538b828506eSNicolas Vasilache   // Get new shapes and types for all operands by removing the size-1
1539b828506eSNicolas Vasilache   // dimension.
1540b828506eSNicolas Vasilache   using RTTBuilder = RankedTensorType::Builder;
1541789c88e8SNicolas Vasilache   RankedTensorType newInputType =
1542789c88e8SNicolas Vasilache       RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1543789c88e8SNicolas Vasilache   RankedTensorType newKernelType =
1544789c88e8SNicolas Vasilache       RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1545789c88e8SNicolas Vasilache   RankedTensorType newOutputType =
1546789c88e8SNicolas Vasilache       RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1547b828506eSNicolas Vasilache 
1548b828506eSNicolas Vasilache   // Rank-reduce operands.
1549b828506eSNicolas Vasilache   Location loc = convOp.getLoc();
1550b828506eSNicolas Vasilache   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1551b828506eSNicolas Vasilache       rewriter, loc, input, newInputType);
1552b828506eSNicolas Vasilache   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1553b828506eSNicolas Vasilache       rewriter, loc, kernel, newKernelType);
1554b828506eSNicolas Vasilache   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1555b828506eSNicolas Vasilache       rewriter, loc, output, newOutputType);
1556b828506eSNicolas Vasilache 
1557b828506eSNicolas Vasilache   // Rank-reduce strides and dilations too.
1558b828506eSNicolas Vasilache   // TODO: dropDim 1-liner helper.
1559d3b3f765SJacques Pienaar   auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1560b828506eSNicolas Vasilache   strides.erase(strides.begin() + (removeH ? 0 : 1));
1561b828506eSNicolas Vasilache   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1562b828506eSNicolas Vasilache 
1563d3b3f765SJacques Pienaar   auto dilations =
1564d3b3f765SJacques Pienaar       llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1565b828506eSNicolas Vasilache   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1566b828506eSNicolas Vasilache   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1567b828506eSNicolas Vasilache 
1568b828506eSNicolas Vasilache   auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1569b828506eSNicolas Vasilache       loc, newOutputType, ValueRange{newInput, newKernel},
1570b828506eSNicolas Vasilache       ValueRange{newOutput}, stridesAttr, dilationsAttr);
1571b828506eSNicolas Vasilache 
1572b828506eSNicolas Vasilache   // Insert back.
1573b828506eSNicolas Vasilache   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1574b828506eSNicolas Vasilache       rewriter, loc, conv1DOp.getResult(0), output);
1575b828506eSNicolas Vasilache   rewriter.replaceOp(convOp, inserted);
1576b828506eSNicolas Vasilache 
1577ce2e198bSAlex Zinenko   return conv1DOp;
1578ce2e198bSAlex Zinenko }
15797b615a87SLei Zhang 
1580991945f4SDevajith Valaparambil Sreeramaswamy FailureOr<Conv1DOp>
1581991945f4SDevajith Valaparambil Sreeramaswamy DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1582991945f4SDevajith Valaparambil Sreeramaswamy                                             PatternRewriter &rewriter) const {
15830a8e3dd4SMatthias Springer   if (convOp.hasPureBufferSemantics())
1584991945f4SDevajith Valaparambil Sreeramaswamy     return failure(); // To be implemented.
1585991945f4SDevajith Valaparambil Sreeramaswamy 
1586991945f4SDevajith Valaparambil Sreeramaswamy   Value input = convOp.getInputs().front();
1587991945f4SDevajith Valaparambil Sreeramaswamy   Value kernel = convOp.getInputs().back();
1588991945f4SDevajith Valaparambil Sreeramaswamy   Value output = convOp.getOutputs().front();
1589991945f4SDevajith Valaparambil Sreeramaswamy 
15905550c821STres Popp   auto inputType = dyn_cast<RankedTensorType>(input.getType());
15915550c821STres Popp   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
15925550c821STres Popp   auto outputType = dyn_cast<RankedTensorType>(output.getType());
1593991945f4SDevajith Valaparambil Sreeramaswamy 
1594991945f4SDevajith Valaparambil Sreeramaswamy   auto kernelShape = kernelType.getShape();
1595991945f4SDevajith Valaparambil Sreeramaswamy   auto outputShape = outputType.getShape();
1596991945f4SDevajith Valaparambil Sreeramaswamy 
1597991945f4SDevajith Valaparambil Sreeramaswamy   // Only handle the case where at least one of the window dimensions is
1598991945f4SDevajith Valaparambil Sreeramaswamy   // of size 1. Other cases can rely on tiling to reduce to such cases.
1599991945f4SDevajith Valaparambil Sreeramaswamy   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1600991945f4SDevajith Valaparambil Sreeramaswamy   int64_t ohSize = outputShape[0], owSize = outputShape[1];
1601991945f4SDevajith Valaparambil Sreeramaswamy   bool removeH = (khSize == 1 && ohSize == 1);
1602991945f4SDevajith Valaparambil Sreeramaswamy   bool removeW = (kwSize == 1 && owSize == 1);
1603991945f4SDevajith Valaparambil Sreeramaswamy   if (!removeH && !removeW)
1604991945f4SDevajith Valaparambil Sreeramaswamy     return failure();
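  // For example, a linalg.conv_2d with a tensor<1x3xf32> kernel and a
  // tensor<1x126xf32> output satisfies removeH and can be rewritten over W
  // alone, whereas a 3x3 kernel matches neither case and is left for tiling
  // to first reduce one window dimension to size 1.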
1605991945f4SDevajith Valaparambil Sreeramaswamy 
1606991945f4SDevajith Valaparambil Sreeramaswamy   // Get new shapes and types for all operands by removing the size-1
1607991945f4SDevajith Valaparambil Sreeramaswamy   // dimension.
1608991945f4SDevajith Valaparambil Sreeramaswamy   using RTTBuilder = RankedTensorType::Builder;
1609991945f4SDevajith Valaparambil Sreeramaswamy   RankedTensorType newInputType =
1610991945f4SDevajith Valaparambil Sreeramaswamy       RTTBuilder(inputType).dropDim(removeH ? 0 : 1);
1611991945f4SDevajith Valaparambil Sreeramaswamy   RankedTensorType newKernelType =
1612991945f4SDevajith Valaparambil Sreeramaswamy       RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
1613991945f4SDevajith Valaparambil Sreeramaswamy   RankedTensorType newOutputType =
1614991945f4SDevajith Valaparambil Sreeramaswamy       RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
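  // With the illustrative removeH case above, the types shrink as follows:
  //   input  tensor<1x128xf32> -> tensor<128xf32>
  //   kernel tensor<1x3xf32>   -> tensor<3xf32>
  //   output tensor<1x126xf32> -> tensor<126xf32>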
1615991945f4SDevajith Valaparambil Sreeramaswamy 
1616991945f4SDevajith Valaparambil Sreeramaswamy   // Rank-reduce operands.
1617991945f4SDevajith Valaparambil Sreeramaswamy   Location loc = convOp.getLoc();
1618991945f4SDevajith Valaparambil Sreeramaswamy   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1619991945f4SDevajith Valaparambil Sreeramaswamy       rewriter, loc, input, newInputType);
1620991945f4SDevajith Valaparambil Sreeramaswamy   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1621991945f4SDevajith Valaparambil Sreeramaswamy       rewriter, loc, kernel, newKernelType);
1622991945f4SDevajith Valaparambil Sreeramaswamy   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1623991945f4SDevajith Valaparambil Sreeramaswamy       rewriter, loc, output, newOutputType);
1624991945f4SDevajith Valaparambil Sreeramaswamy 
1625991945f4SDevajith Valaparambil Sreeramaswamy   auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1626991945f4SDevajith Valaparambil Sreeramaswamy                                             ValueRange{newInput, newKernel},
1627991945f4SDevajith Valaparambil Sreeramaswamy                                             ValueRange{newOutput});
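  // With the illustrative types above, this builds:
  //   %0 = linalg.conv_1d ins(%in1d, %k1d : tensor<128xf32>, tensor<3xf32>)
  //            outs(%out1d : tensor<126xf32>) -> tensor<126xf32>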
1628991945f4SDevajith Valaparambil Sreeramaswamy 
1629991945f4SDevajith Valaparambil Sreeramaswamy   // Insert the 1-D convolution result back into the original output tensor.
1630991945f4SDevajith Valaparambil Sreeramaswamy   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1631991945f4SDevajith Valaparambil Sreeramaswamy       rewriter, loc, conv1DOp.getResult(0), output);
1632991945f4SDevajith Valaparambil Sreeramaswamy   rewriter.replaceOp(convOp, inserted);
1633991945f4SDevajith Valaparambil Sreeramaswamy 
1634991945f4SDevajith Valaparambil Sreeramaswamy   return conv1DOp;
1635991945f4SDevajith Valaparambil Sreeramaswamy }
1636991945f4SDevajith Valaparambil Sreeramaswamy 
1637ad1efb51SNicolas Vasilache void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16387b615a87SLei Zhang                                                   PatternBenefit benefit) {
16398e484b52SStanley Winata   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
16408e484b52SStanley Winata                                                      Conv1DNwcWcfOp>,
16418e484b52SStanley Winata                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
16428e484b52SStanley Winata                                                      Conv1DNcwFcwOp>,
1643991945f4SDevajith Valaparambil Sreeramaswamy                DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1644991945f4SDevajith Valaparambil Sreeramaswamy       patterns.getContext(), benefit);
16451a151fdcSMurali Vijayaraghavan   patterns.add<
16461a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
16471a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
16481a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
16491a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
16501a151fdcSMurali Vijayaraghavan                                             PoolingNwcMaxUnsignedOp>,
16511a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
16521a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
16531a151fdcSMurali Vijayaraghavan                                             PoolingNwcMinUnsignedOp>,
16541a151fdcSMurali Vijayaraghavan       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
16551a151fdcSMurali Vijayaraghavan       patterns.getContext(), benefit);
16567b615a87SLei Zhang }
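// A minimal usage sketch for driving these patterns; `funcOp` and the
// enclosing pass are assumed context, not part of this file. The same idiom
// applies to the other populate entry points below.
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();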
165763b926afSAndrzej Warzyński 
165807750882SAndrzej Warzyński void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
165907750882SAndrzej Warzyński   patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
1660*25825682SAndrzej Warzyński   patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
166163b926afSAndrzej Warzyński }
16621b2c8f10SAndrzej Warzyński 
16631b2c8f10SAndrzej Warzyński void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
16641b2c8f10SAndrzej Warzyński   patterns.add<DecomposePadOpPattern>(patterns.getContext());
16651b2c8f10SAndrzej Warzyński }
1666