1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
23 #include "mlir/Dialect/Tensor/Utils/Utils.h"
24 #include "mlir/Dialect/Utils/IndexingUtils.h"
25 #include "mlir/Dialect/Utils/StaticValueUtils.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Matchers.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Support/LLVM.h"
32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
38 #include <utility>
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 using namespace mlir;
43 using namespace mlir::linalg;
44 
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 #define DBGSNL() (llvm::dbgs() << "\n")
47 
48 //===----------------------------------------------------------------------===//
49 // Transformations exposed as functional-style API calls.
50 //===----------------------------------------------------------------------===//
51 
52 //===----------------------------------------------------------------------===//
53 // peelLoop transformation.
54 //===----------------------------------------------------------------------===//
55 
56 /// Try to peel and canonicalize loop `op` and return the new results.
57 /// Also applies affine_min/max bounds simplification on the fly where relevant.
58 // TODO: Add support for scf.parallel and affine.for loops.
59 SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
60                                           Operation *op) {
61   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
62       .Case<scf::ForOp>([&](scf::ForOp forOp) {
63         scf::ForOp partialIteration;
64         if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
65                                                         partialIteration)))
66           return partialIteration->getResults();
67         assert(!partialIteration && "expected that loop was not peeled");
68         return forOp->getResults();
69       })
70       .Default([&](Operation *op) { return op->getResults(); });
71 }
72 
73 /// Peel `loops` and apply affine_min/max bounds simplification on the fly
74 /// where relevant.
75 void mlir::linalg::peelLoops(RewriterBase &rewriter,
76                              ArrayRef<scf::ForOp> loops) {
77   for (auto loopOp : loops)
78     peelLoop(rewriter, loopOp);
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // pack transformation.
83 //===----------------------------------------------------------------------===//
84 
85 #ifndef NDEBUG
86 /// Return true if at most one result of `map` is a function of AffineDimExpr(dim).
87 static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
88   bool found = false;
89   for (AffineExpr e : map.getResults()) {
90     if (!e.isFunctionOfDim(dim))
91       continue;
92     if (found)
93       return false;
94     found = true;
95   }
96   return true;
97 }
98 #endif // NDEBUG
99 
100 /// Return the index of the first result of `map` that is a function of
101 /// AffineDimExpr(dim), or std::nullopt if there is none.
102 static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
103                                                             int64_t dim) {
104   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
105     AffineExpr expr = map.getResult(i);
106     if (!expr.isFunctionOfDim(dim))
107       continue;
108     return i;
109   }
110   return std::nullopt;
111 }
112 
113 /// Perform one step of packing of a LinalgOp's metadata along `dim` into the
114 /// `newDim` at `iteratorTypes.size()` by:
115 ///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
116 ///   2. Appending a `newDim` to the domain of every indexing map.
117 ///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
118 ///      by potentially adding a `newDim` result to `map`.
119 /// The preserved invariant is that `iteratorTypes.size()` is always equal to
120 /// `map.getNumDims()` for every map in `indexingMaps`.
121 ///
122 /// Update `indexingMaps` and `iteratorTypes` in place as one step of the update.
123 /// Return a vector that records the optional packing for each operand.
124 /// Return failure if the packed indexing cannot be represented with a LinalgOp.
125 ///
126 /// Further details:
127 /// ================
128 /// The current implementation of packing (i.e. data tiling) consists of
129 /// rewriting a linearized strip-mined form into a higher-dimensional access.
130 /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
131 /// `I` into `4 * i + ii`, where `0 <= ii < 4`.
132 /// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
133 ///
134 /// This rewrite into higher dimensional access is not possible for general
135 /// AffineExpr in Linalg atm; it is restricted to an AffineDimExpr:
136 /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
137 /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
138 /// The rewrite of the access would be a form not representable in Linalg:
139 ///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
140 /// Note however that as `J` and `ii` iterate, the accesses do not have a
141 /// particular alignment, so packing does not achieve alignment in this case.
142 ///
143 /// In the future, we may want to consider a mixed-form that allows some
144 /// alignment in the presence of multiple accesses:
145 ///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
146 /// and would rewrite the accesses as:
147 ///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
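///
/// Illustrative sketch (a hypothetical matmul, for exposition only): with
///   `indexingMaps  = [(d0, d1, d2) -> (d0, d2),
///                     (d0, d1, d2) -> (d2, d1),
///                     (d0, d1, d2) -> (d0, d1)]`
///   `iteratorTypes = [parallel, parallel, reduction]`
/// one step of packing along `dim = 0` appends a `parallel` iterator, adds a
/// trailing dim `d3` to the domain of every map, and appends a `d3` result to
/// every map that has an `AffineDimExpr` result in `d0`:
///   `indexingMaps  = [(d0, d1, d2, d3) -> (d0, d2, d3),
///                     (d0, d1, d2, d3) -> (d2, d1),
///                     (d0, d1, d2, d3) -> (d0, d1, d3)]`
/// The returned vector records, per operand, which result got packed:
/// `[0, std::nullopt, 0]` in this example.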
148 static FailureOr<SmallVector<std::optional<int64_t>>>
149 packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
150                        SmallVectorImpl<utils::IteratorType> &iteratorTypes,
151                        int64_t dim) {
152   int64_t newDim = iteratorTypes.size();
153   iteratorTypes.push_back(iteratorTypes[dim]);
154 
155   SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
156       indexingMaps.size(), std::nullopt);
157   SmallVector<AffineMap> newMaps;
158   for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
159        ++operandIdx) {
160     AffineMap map = indexingMaps[operandIdx];
161 
162     // Add the `newDim` to the map in all cases.
163     assert(map.getNumDims() == newDim && "num dims invariant violation");
164     map = map.shiftDims(1, newDim);
165 
166     // Get the index of the at-most-one result that is a function of `dim`.
167     // If we can find one, we insert `AffineDimExpr(newDim)` into the map, which
168     // logically chunks dimension `dim` into `K * dim + newDim`, where the
169     // packing factor `K` is specified separately.
170     assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
171            "num results invariant violation");
172     auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
173     if (!maybeOperandDimensionToPack.has_value()) {
174       newMaps.push_back(map);
175       continue;
176     }
177 
178     // We can only pack AffineDimExpr atm.
179     if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
180       return failure();
181 
182     // Add `newDim` to the results of the map.
183     map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
184                            map.getNumResults());
185     newMaps.push_back(map);
186 
187     // Record that `operandIdx` is packed.
188     packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
189   }
190   indexingMaps = newMaps;
191 
192   return packedDimPerIndexingMap;
193 }
194 
195 namespace {
196 
197 /// Helper struct to encode packing along one dimension of a LinalgOp.
198 struct PackedOperandsDim {
199   OpFoldResult packedSize;
200   SmallVector<std::optional<int64_t>> packedDimForEachOperand;
201 };
202 
203 /// Helper struct to encode packing along all dimensions of a LinalgOp.
204 struct PackedOperandsDimList {
205   void pushBack(PackedOperandsDim &&packedOperandsDims) {
206     spec.emplace_back(packedOperandsDims);
207   }
208   /// Return all the dims that have been packed for operand @ `operandPos`.
209   SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
210   /// Return all the pack sizes by which an operand @ `operandPos` is packed.
211   SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
212 
213 private:
214   SmallVector<PackedOperandsDim> spec;
215 };
216 
217 } // namespace
218 
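/// Note: in the general case this rewrites `packOp` into tensor.pad +
/// tensor.expand_shape + linalg.transpose (steps 4-7 below). When
/// `lowerPadLikeWithInsertSlice` is set and the pack merely pads, it is
/// instead rewritten as tensor.pad + tensor.insert_slice.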
219 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220                                              tensor::PackOp packOp,
221                                              bool lowerPadLikeWithInsertSlice) {
222   // 1. Filter out NYI cases.
223   auto packedTensorType =
224       cast<RankedTensorType>(packOp->getResultTypes().front());
225   if (llvm::any_of(packOp.getStaticInnerTiles(),
226                    [](int64_t size) { return ShapedType::isDynamic(size); })) {
227     return rewriter.notifyMatchFailure(
228         packOp,
229         "non-static shape NYI, needs a more powerful tensor.expand_shape op");
230   }
231 
232   Location loc = packOp->getLoc();
233   OpBuilder::InsertionGuard g(rewriter);
234   rewriter.setInsertionPoint(packOp);
235 
236   // 2. Compute the permutation vector to shuffle packed shape into the shape
237   // before any outer or inner permutations have been applied.
238   PackingMetadata packingMetadata = computePackingMetadata(
239       packedTensorType.getRank(), packOp.getInnerDimsPos());
240   SmallVector<int64_t> packedToStripMinedShapePerm =
241       tensor::getPackInverseDestPerm(packOp);
242 
243   // 3. Compute the stripMinedShape: this is the packed shape before any outer
244   // or inner permutations have been applied.
245   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
246   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
247 
248   // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
249   SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
250                                  rewriter.getIndexAttr(0));
251   SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
252                                   rewriter.getIndexAttr(0));
253   for (auto [pos, innerSize] :
254        llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
255     int outerPos =
256         packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
257     OpFoldResult origSize =
258         tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
259     OpFoldResult outerSize =
260         tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
261     AffineExpr s0, d0, d1;
262     bindDims(rewriter.getContext(), d0, d1);
263     bindSymbols(rewriter.getContext(), s0);
264     auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
265     highs[pos] = affine::makeComposedFoldedAffineApply(
266         rewriter, loc, map, {outerSize, origSize, innerSize});
267   }
268   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
269       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
270       packingMetadata.reassociations);
271   Value paddingValue = packOp.getPaddingValue();
272   if (!paddingValue) {
273     paddingValue = rewriter.create<arith::ConstantOp>(
274         loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
275   }
276   auto padOp =
277       rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
278                                      highs, paddingValue, /*nofold=*/false);
279 
280   LLVM_DEBUG(
281       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
282                                                 DBGS() << "insertPositions: ");
283       DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
284                                       DBGS() << "outerPositions: ");
285       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
286                                       DBGS() << "packedShape: ");
287       DBGSNL();
288       llvm::interleaveComma(packedToStripMinedShapePerm,
289                             DBGS() << "packedToStripMinedShapePerm: ");
290       DBGSNL(); llvm::interleaveComma(
291           packingMetadata.reassociations, DBGS() << "reassociations: ",
292           [&](ReassociationIndices ri) {
293             llvm::interleaveComma(ri, llvm::dbgs() << "|");
294           });
295       DBGSNL();
296       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
297       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
298 
299   if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
300     // Pack ops which operate as simple pads may not produce legal
301     // tensor.insert_slice operations when the packed type does not rank reduce
302     // to the padded type.
303     SliceVerificationResult rankReduces =
304         isRankReducedType(packedTensorType, padOp.getResultType());
305 
306     if (rankReduces == SliceVerificationResult::Success) {
307       // This pack is just a plain pad.
308       // Just insert the pad in the higher ranked tensor.
309       // Offsets.
310       SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
311                                       rewriter.getIndexAttr(0));
312       // Strides.
313       SmallVector<OpFoldResult> ones(packOp.getDestRank(),
314                                      rewriter.getIndexAttr(1));
315       SmallVector<OpFoldResult> sizes =
316           tensor::getMixedSizes(rewriter, loc, packOp.getDest());
317 
318       auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
319           loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
320           /*offsets=*/zeros, sizes, /*strides=*/ones);
321 
322       LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
323 
324       rewriter.replaceOp(packOp, insertSliceOp->getResults());
325 
326       return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
327                              /*transposeOp=*/nullptr};
328     }
329   }
330 
331   // 5. Expand from the padded result to the stripMinedShape.
332   auto expandShapeResultType =
333       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
334   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
335       loc, expandShapeResultType, padOp.getResult(),
336       packingMetadata.reassociations);
337 
338   // 6. Transpose stripMinedShape to packedShape.
339   SmallVector<int64_t> transpPerm =
340       invertPermutationVector(packedToStripMinedShapePerm);
341   auto transposeOp = rewriter.create<linalg::TransposeOp>(
342       loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
343 
344   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
345              DBGS() << "reshape op: " << reshapeOp; DBGSNL();
346              llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
347              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
348 
349   // 7. Replace packOp by transposeOp.
350   rewriter.replaceOp(packOp, transposeOp->getResults());
351 
352   return LowerPackResult{padOp, reshapeOp, transposeOp};
353 }
354 
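/// Note: in the general case this rewrites `unPackOp` into tensor.empty +
/// linalg.transpose + tensor.collapse_shape + tensor.extract_slice +
/// linalg.copy (steps 1-7 below). When `lowerUnpadLikeWithExtractSlice` is
/// set and the unpack merely unpads, a single tensor.extract_slice is used.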
355 FailureOr<LowerUnPackOpResult>
356 linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
357                     bool lowerUnpadLikeWithExtractSlice) {
358   Location loc = unPackOp->getLoc();
359   OpBuilder::InsertionGuard g(rewriter);
360   rewriter.setInsertionPoint(unPackOp);
361 
362   RankedTensorType packedTensorType = unPackOp.getSourceType();
363   int64_t packedRank = packedTensorType.getRank();
364 
365   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
366   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
367   if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
368     // This unpack is just a plain unpad.
369     // Just extract the slice from the higher ranked tensor.
370     ArrayRef<int64_t> destShape = destTensorType.getShape();
371     // The inner dimensions stay the same as the destination tensor, but the
372     // outer ones are additional 1s.
373     SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
374     sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
375 
376     auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
377         loc, destTensorType, unPackOp.getSource(),
378         SmallVector<OpFoldResult>(packedRank, zero), sizes,
379         SmallVector<OpFoldResult>(packedRank, one));
380 
381     rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
382 
383     return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
384                                /*reshapeOp=*/nullptr, extractSliceOp};
385   }
386 
387   // 1. Compute the permutation vector to shuffle packed shape into the shape
388   // before any outer or inner permutations have been applied.
389   PackingMetadata packingMetadata;
390   SmallVector<int64_t> packedToStripMinedShapePerm =
391       tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
392 
393   // 2. Compute the stripMinedShape: this is the packed shape without outer and
394   // inner permutations.
395   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
396   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
397 
398   // 3. Transpose packedShape to stripMinedShape.
399   RankedTensorType stripMinedTensorType =
400       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
401   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
402       stripMinedTensorType, packingMetadata.reassociations);
403 
404   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
405   // permutation.
406   SmallVector<OpFoldResult, 4> dims =
407       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
408   applyPermutationToVector(dims, packedToStripMinedShapePerm);
409   auto emptyOp = rewriter.create<tensor::EmptyOp>(
410       loc, dims, stripMinedTensorType.getElementType());
411   auto transposeOp = rewriter.create<linalg::TransposeOp>(
412       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
413 
414   LLVM_DEBUG(
415       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
416                                                 DBGS() << "insertPositions: ");
417       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
418                                       DBGS() << "packedShape: ");
419       DBGSNL();
420       llvm::interleaveComma(packedToStripMinedShapePerm,
421                             DBGS() << "packedToStripMinedShapePerm: ");
422       DBGSNL(); llvm::interleaveComma(
423           packingMetadata.reassociations, DBGS() << "reassociations: ",
424           [&](ReassociationIndices ri) {
425             llvm::interleaveComma(ri, llvm::dbgs() << "|");
426           });
427       DBGSNL();
428       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
429       DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
430 
431   // 4. Collapse from the stripMinedShape to the padded result.
432   auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
433       loc, collapsedType, transposeOp->getResult(0),
434       packingMetadata.reassociations);
435 
436   // 5. ExtractSlice.
437   int64_t destRank = destTensorType.getRank();
438   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
439       loc, destTensorType, reshapeOp->getResult(0),
440       SmallVector<OpFoldResult>(destRank, zero),
441       tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
442       SmallVector<OpFoldResult>(destRank, one));
443 
444   // 6. Inject a copy to preserve DPS.
445   auto copyOp = rewriter.create<linalg::CopyOp>(
446       loc, extractSliceOp->getResult(0), unPackOp.getDest());
447 
448   // 7. Replace unPackOp by copyOp.
449   rewriter.replaceOp(unPackOp, copyOp->getResults());
450 
451   return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
452 }
453 
454 SmallVector<int64_t>
455 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
456   SmallVector<int64_t> res;
457   for (auto &i : spec) {
458     if (!i.packedDimForEachOperand[operandPos].has_value())
459       continue;
460     res.push_back(i.packedDimForEachOperand[operandPos].value());
461   }
462   return res;
463 }
464 
465 SmallVector<OpFoldResult>
466 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
467   SmallVector<OpFoldResult> res;
468   for (auto &i : spec) {
469     if (!i.packedDimForEachOperand[operandPos].has_value())
470       continue;
471     res.push_back(i.packedSize);
472   }
473   return res;
474 }
475 
476 /// Implement packing of a single LinalgOp by performing packing by
477 /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
478 /// Return the packed Linalg op on success, failure otherwise.
479 FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
480                                    linalg::LinalgOp linalgOp,
481                                    ArrayRef<OpFoldResult> packedSizes) {
482   if (packedSizes.size() != linalgOp.getNumLoops()) {
483     return rewriter.notifyMatchFailure(linalgOp,
484                                        "incorrect number of pack sizes");
485   }
486 
487   Location loc = linalgOp->getLoc();
488   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
489   SmallVector<utils::IteratorType> iteratorTypes =
490       linalgOp.getIteratorTypesArray();
491   LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
492              llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
493              llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
494              DBGSNL(););
495 
496   SmallVector<tensor::PackOp> packOps;
497   SmallVector<tensor::UnPackOp> unPackOps;
498   // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
499   PackedOperandsDimList listOfPackedOperandsDim;
500   for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
501     std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
502     // Skip tile sizes explicitly set to 0.
503     if (maybeConstant.has_value() && maybeConstant.value() == 0)
504       continue;
505 
506     PackedOperandsDim packedOperandsDims;
507     packedOperandsDims.packedSize = packedSizes[i];
508     FailureOr<SmallVector<std::optional<int64_t>>>
509         maybePackedDimForEachOperand =
510             packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
511     if (failed(maybePackedDimForEachOperand))
512       return failure();
513     packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
514     listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
515 
516     LLVM_DEBUG(
517         DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
518                << "\n";
519         llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
520         llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
521         llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
522                               DBGS() << "packedDimForEachOperand: ");
523         DBGSNL(););
524   }
525 
526   // Step 2. Propagate packing to all LinalgOp operands.
527   SmallVector<Value> inputsAndInits, results;
528   SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
529       linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
530   SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
531   for (const auto &operandsList : {inputOperands, initOperands}) {
532     for (OpOperand *opOperand : operandsList) {
533       int64_t pos = opOperand->getOperandNumber();
534       Value operand = opOperand->get();
535       SmallVector<int64_t> innerPos =
536           listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
537       SmallVector<OpFoldResult> innerPackSizes =
538           listOfPackedOperandsDim.extractPackSizesForOperand(pos);
539       LLVM_DEBUG(
540           DBGS() << "operand: " << operand << "\n";
541           llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
542           llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
543           DBGSNL(););
544       if (innerPackSizes.empty()) {
545         inputsAndInits.push_back(operand);
546         continue;
547       }
548       Value dest = tensor::PackOp::createDestinationTensor(
549           rewriter, loc, operand, innerPackSizes, innerPos,
550           /*outerDimsPerm=*/{});
551       ShapedType operandType = cast<ShapedType>(operand.getType());
552       bool areConstantTiles =
553           llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
554             return getConstantIntValue(tile).has_value();
555           });
556       if (areConstantTiles && operandType.hasStaticShape() &&
557           !tensor::PackOp::requirePaddingValue(
558               operandType.getShape(), innerPos,
559               cast<ShapedType>(dest.getType()).getShape(), {},
560               innerPackSizes)) {
561         packOps.push_back(rewriter.create<tensor::PackOp>(
562             loc, operand, dest, innerPos, innerPackSizes));
563       } else {
564         // TODO: value of the padding attribute should be determined by
565         // consumers.
566         auto zeroAttr =
567             rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
568         Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
569         packOps.push_back(rewriter.create<tensor::PackOp>(
570             loc, operand, dest, innerPos, innerPackSizes, zero));
571       }
572       inputsAndInits.push_back(packOps.back());
573     }
574   }
575 
576   // Step 3. Build the packed op, using the types of `inits` as result types.
577   ValueRange inputs =
578       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
579   ValueRange inits =
580       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
581   auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
582       linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
583       iteratorTypes);
584   packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
585 
586   // Step 4. Propagate packing to all the op results.
587   for (OpResult result : packedLinalgOp->getResults()) {
588     int64_t resultNum = result.getResultNumber();
589     tensor::PackOp maybePackedInit =
590         inits[resultNum].getDefiningOp<tensor::PackOp>();
591     if (!maybePackedInit) {
592       results.push_back(result);
593       continue;
594     }
595     // Build the symmetrical UnPackOp to the existing PackOp.
596     unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
597         packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
598         maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
599     results.push_back(unPackOps.back());
600   }
601 
602   // Step 5. Replace `linalgOp`.
603   rewriter.replaceOp(linalgOp, results);
604 
605   // Return packedLinalgOp.
606   return PackResult{packOps,
607                     cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
608                     unPackOps};
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // packTranspose transformation.
613 //===----------------------------------------------------------------------===//
614 
615 /// Return a copy of `tensorType` after permutation by `permutationVector`.
616 // Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
617 // but this would introduce a dependence on Dialect in IR.
618 // TODO: Restructure.
619 static RankedTensorType permuteShape(RankedTensorType tensorType,
620                                      ArrayRef<int64_t> permutationVector) {
621   SmallVector<int64_t> shape(tensorType.getShape());
622   applyPermutationToVector(shape, permutationVector);
623   return RankedTensorType::Builder(tensorType).setShape(shape);
624 }
625 
626 /// Return a new GenericOp obtained by transposing opOperand by the permutation
627 /// vector:
628 ///   - the corresponding indexing map is transposed by `permutation`
629 ///   - the corresponding operand value is replaced by `transposedValue`
630 /// `linalgOp` is replaced by the return op in the process.
631 /// Asserts that `transposedValue` is of the proper transposed ShapedType.
632 static LinalgOp transposeOneLinalgOperandAndReplace(
633     RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
634     ArrayRef<int64_t> permutation, Value transposedValue) {
635   // Sanity check the operand.
636   assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
637 
638   // Sanity check of the expected transposed tensor type.
639   auto tensorType = permuteShape(
640       cast<RankedTensorType>(opOperand.get().getType()), permutation);
641   (void)tensorType;
642   assert(tensorType == transposedValue.getType() &&
643          "unexpected tensor type mismatch");
644 
645   // Compute the transposed indexing map.
646   // Sigh unsigned pollution.
647   SmallVector<unsigned> tmpTransposition = llvm::to_vector(
648       llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
649   AffineMap permutationMap =
650       AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
651   AffineMap transposedMap =
652       permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
653 
654   // Set the transposed indexing map in the proper position.
655   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
656   indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
657   // Set the transposedValue in the proper operand position.
658   SmallVector<Value> operands = linalgOp->getOperands();
659   operands[opOperand.getOperandNumber()] = transposedValue;
660 
661   ValueRange operandsRef(operands);
662   auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
663       /*location=*/linalgOp->getLoc(),
664       /*resultTensorTypes=*/
665       operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
666       /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
667       /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
668       /*indexingMaps=*/indexingMaps,
669       /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
670   transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
671   rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
672 
673   return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
674 }
675 
676 FailureOr<PackTransposeResult>
677 linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
678                       linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
679                       ArrayRef<int64_t> outerPerm,
680                       ArrayRef<int64_t> innerPerm) {
681   Location loc = linalgOp.getLoc();
682 
683   // Step 1. Transpose packOp.
684   rewriter.setInsertionPoint(packOp);
685   tensor::PackOp transposedPackOp =
686       packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
687 
688   if (!packOp.getResult().hasOneUse())
689     return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
690 
691   OpOperand &packUse = *packOp->getUses().begin();
692   if (packUse.getOwner() != linalgOp) {
693     return rewriter.notifyMatchFailure(
694         linalgOp, "not a single use by the LinalgOp target");
695   }
696   if (maybeUnPackOp &&
697       (!linalgOp.isDpsInit(&packUse) ||
698        maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
699     return rewriter.notifyMatchFailure(linalgOp,
700                                        "not produced by the LinalgOp target");
701   }
702 
703   // Step 2. Transpose linalgOp.
704   // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
705   // identity. Don't rely on it.
706   int64_t numLeadingDims = packOp.getSourceRank();
707   int64_t numTrailingDims = packOp.getInnerDimsPos().size();
708   // Step 2.a. Compute the permutation on the whole operand.
709   // Leading part just reuses the outerPerm.
710   SmallVector<int64_t> permutation(outerPerm);
711   if (permutation.empty())
712     llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
713   // Trailing part needs to reindex positions by `numLeadingDims`.
714   if (innerPerm.empty()) {
715     llvm::append_range(
716         permutation,
717         llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
718   } else {
719     llvm::append_range(permutation,
720                        llvm::map_range(innerPerm, [&](int64_t pos) {
721                          return numLeadingDims + pos;
722                        }));
723   }
724   if (!isPermutationVector(permutation))
725     return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
726 
727   // Step 2.b. Save the transposedPackUse operand number in case we need to
728   // get the tied OpResult after `linalgOp` has been replaced.
729   int64_t packUseOperandNumber = packUse.getOperandNumber();
730   // Step 2.c. Actually perform the transposition.
731   rewriter.setInsertionPoint(linalgOp);
732   linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
733       rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
734 
735   // Step 3. Maybe transpose unPackOp.
736   tensor::UnPackOp transposedUnPackOp;
737   if (maybeUnPackOp) {
738     OpOperand &opOperand =
739         transposedLinalgOp->getOpOperand(packUseOperandNumber);
740     OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
741     rewriter.setInsertionPoint(maybeUnPackOp);
742     transposedUnPackOp = maybeUnPackOp.createTransposedClone(
743         rewriter, loc, transposedResult, innerPerm, outerPerm);
744 
745     rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
746   }
747 
748   // Step 4. Finally, replace packOp now that we don't need it anymore.
749   rewriter.replaceOp(packOp, transposedPackOp->getResults());
750 
751   return PackTransposeResult{transposedPackOp, transposedLinalgOp,
752                              transposedUnPackOp};
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // packMatmulGreedily transformation.
757 //===----------------------------------------------------------------------===//
758 
759 /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
760 /// and n are proper parallel dimensions and k is a proper reduction
761 /// dimension. Packing occurs by rewriting the op as a linalg.generic and
762 /// calling linalg::pack by `mnkPackedSizes`. The order of the packed
763 /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
764 /// to reorder {m, n, k} into one of the 8 possible forms. The outer
765 /// dimensions of the operands are not permuted at this time, this is left for
766 /// future work.
767 FailureOr<PackResult>
768 linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
769                            ArrayRef<OpFoldResult> mnkPackedSizes,
770                            ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
771                            ArrayRef<int64_t> mnkOrder) {
772   assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
773   assert((mnkPaddedSizesNextMultipleOf.empty() ||
774           mnkPaddedSizesNextMultipleOf.size() == 3) &&
775          "num of packing sizes next multiple should be empty or of size 3");
776   assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
777   assert(isPermutationVector(mnkOrder) && "expected a permutation");
778 
779   int64_t numLoops = linalgOp.getNumLoops();
780   if (numLoops <= 2) {
781     LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
782                       << numLoops << "\nin: " << linalgOp << "\n");
783     return rewriter.notifyMatchFailure(
784         linalgOp, "need 3+ loops to find a matmul to pack");
785   }
786 
787   // Locally adjust the desired iterator position of mnk and packing sizes.
788   int64_t numPackedDims = mnkPackedSizes.size();
789   SmallVector<int64_t> mmnnkkPos(numPackedDims);
790   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
791     mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
792   SmallVector<OpFoldResult> packedSizes(numPackedDims);
793   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
794     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
795   SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
796   for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
797     paddedSizesNextMultipleOf[mnkOrder[i]] =
798         mnkPaddedSizesNextMultipleOf.empty() ? 0
799                                              : mnkPaddedSizesNextMultipleOf[i];
800   }
801 
802   // 1. Infer dims that are important for matmul.
803   FailureOr<ContractionDimensions> maybeDimensions =
804       inferContractionDims(linalgOp);
805   if (failed(maybeDimensions)) {
806     LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
807                       << "\n");
808     return rewriter.notifyMatchFailure(linalgOp,
809                                        "couldn't infer matmul iterators");
810   }
811 
812   // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
813   // minor iterators. In cases with multiple options for m, n, k, bias towards
814   // the most minor embedding.
815   // If we wanted a different normalization order, this is where we would have
816   // to plug in a heuristic.
817   int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
818           kPos = maybeDimensions->k.back();
819   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
820              DBGS() << "Start packing generic op greedily with (m@" << mPos
821                     << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
822                     << "\n";);
823 
824   // 2.a. Rewrite as a generic.
825   auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
826   if (!genericOp) {
827     FailureOr<GenericOp> generalizeResult =
828         generalizeNamedOp(rewriter, linalgOp);
829     assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
830     genericOp = *generalizeResult;
831   }
832 
833   // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
834   // iterators. Note that this only normalizes the iteration order and does
835   // not change the indexings of any operand.
836   SmallVector<int64_t> permutation =
837       computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
838   LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
839   // Sign .. unsigned pollution.
840   SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
841   FailureOr<GenericOp> interchangeResult =
842       interchangeGenericOp(rewriter, genericOp, unsignedPerm);
843   assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
844   genericOp = *interchangeResult;
845   LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
846 
847   // At this point, the op iterators are normalized to {leading, k, m, n}.
848   // The layouts induced by packing will always be:
849   //   - LHS{leading_lhs, kk, mm}
850   //   - RHS{leading_rhs, kk, nn}
851   //   - RES{leading_res, mm, nn}
852   // If we wanted to change the packed order, we would reorder (k, m, n) to
853   // something else above.
854   //
855   // Additional permutations of the outer dims of the operands (i.e.
856   // leading_lhs, leading_rhs and leading_res) could follow by computing the
857   // desired outerPerm for each operand.
858   // This is left for future work.
859 
860   // TODO: this creates too much IR, go use reifyResultShapes.
861   SmallVector<Range, 4> loopRanges =
862       cast<LinalgOp>(genericOp.getOperation())
863           .createLoopRanges(rewriter, genericOp.getLoc());
864 
865   // Add leading zeros to match numLoops; we only pack the last 3 dimensions
866   // post interchange.
867   LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
868                                    DBGS() << "paddedSizesNextMultipleOf: ");
869              DBGSNL(););
870   LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
871                                    [](Range r) { llvm::dbgs() << r.size; });
872              DBGSNL(););
873   SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
874                                                 rewriter.getIndexAttr(0));
875   for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
876     if (paddedSizesNextMultipleOf[i] == 0) {
877       adjustedPackedSizes.push_back(packedSizes[i]);
878       continue;
879     }
880     AffineExpr d0, s0;
881     bindDims(rewriter.getContext(), d0);
882     bindSymbols(rewriter.getContext(), s0);
883     adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
884         rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
885         {loopRanges[adjustedPackedSizes.size()].size,
886          rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
887   }
888   LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
889                                    DBGS() << "adjustedPackedSizes: ");
890              DBGSNL(););
891 
892   // TODO: If we wanted to give the genericOp a name after packing, after
893   // calling `pack` would be a good time. One would still need to check that
894   // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
895   // also allow degenerate matmul cases (i.e. matvec, dot).
896   return pack(rewriter, genericOp, adjustedPackedSizes);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // Transformations exposed as rewrite patterns.
901 //===----------------------------------------------------------------------===//
902 
903 LinalgTilingOptions &
904 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
905   assert(!tileSizeComputationFunction && "tile sizes already set");
906   SmallVector<int64_t, 4> tileSizes(ts);
907   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
908     OpBuilder::InsertionGuard guard(b);
909     b.setInsertionPointToStart(
910         &op->getParentOfType<func::FuncOp>().getBody().front());
911     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
912       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
913       return v;
914     }));
915   };
916   return *this;
917 }
918 
919 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
920     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
921   return vectorizeCopy(rewriter, copyOp);
922 }
923 
924 /// Fill `dest` using a FillOp with the constant padding value, if possible.
925 /// Otherwise, generate a tensor::GenerateOp.
926 Value DecomposePadOpPattern::createFillOrGenerateOp(
927     RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
928     const SmallVector<Value> &dynSizes) const {
929   auto padValue = padOp.getConstantPaddingValue();
930   if (padValue)
931     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
932 
933   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
934   auto generateOp = rewriter.create<tensor::GenerateOp>(
935       padOp.getLoc(), padOp.getResultType(), dynSizes);
936   // Copy region to new op.
937   IRMapping bvm;
938   padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
939   return generateOp;
940 }
941 
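/// Lower tensor.pad to tensor.empty + (linalg.fill or tensor.generate) +
/// tensor.insert_slice: the destination is materialized and filled with the
/// padding value, then the original source is inserted at the low-pad offsets.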
942 LogicalResult
943 DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
944                                        PatternRewriter &rewriter) const {
945   // Given an OpFoldResult, return an index-typed value.
946   auto getIdxValue = [&](OpFoldResult ofr) {
947     if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
948       return val;
949     return rewriter
950         .create<arith::ConstantIndexOp>(
951             padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
952         .getResult();
953   };
954 
955   auto resultType = padOp.getResultType();
956   // Compute size of EmptyOp. Any combination of static/dynamic is supported.
957   SmallVector<Value> dynSizes;
958   SmallVector<int64_t> staticSizes;
959   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
960     if (resultType.isDynamicDim(dim)) {
961       auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
962                                                       padOp.getSource(), dim));
963       // Add low and high padding value.
964       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
965           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
966       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
967           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
968       dynSizes.push_back(plusHigh);
969     }
970     staticSizes.push_back(resultType.getDimSize(dim));
971   }
972 
973   // Init tensor and fill it with padding.
974   Value emptyTensor = rewriter.create<tensor::EmptyOp>(
975       padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
976   Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
977 
978   // Generate an InsertSliceOp for copying the PadOp source.
979   auto sourceType = padOp.getSourceType();
980   // Compute size of source of tensor::PadOp.
981   SmallVector<OpFoldResult> srcSizes =
982       tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
983   // Strides of InsertSliceOp are all 1.
984   SmallVector<OpFoldResult> strides(sourceType.getRank(),
985                                     rewriter.getIndexAttr(1));
986   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
987       padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
988       strides);
989 
990   return success();
991 }
992 
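/// Swap a tensor.extract_slice of a tensor.pad so that the slice is applied
/// to the pad's source instead, via tensor::bubbleUpPadSlice; the result is
/// pad(extract_slice(x)) when the source data is actually used.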
993 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
994     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
995   if (!sliceOp.hasUnitStride())
996     return failure();
997 
998   auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
999   if (!padOp)
1000     return failure();
1001 
1002   bool zeroSliceGuard = true;
1003   if (controlFn) {
1004     if (std::optional<bool> control = controlFn(sliceOp))
1005       zeroSliceGuard = *control;
1006     else
1007       return failure();
1008   }
1009 
1010   FailureOr<TilingResult> tilingResult =
1011       tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
1012                                sliceOp.getMixedSizes(), zeroSliceGuard);
1013   if (failed(tilingResult))
1014     return failure();
1015   // All shapes are static and the data source is actually used. Rewrite into
1016   // pad(extract_slice(x)).
1017   rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1018   return success();
1019 }
1020 
1021 /// If padding value is set, returns a tensor.pad Op for the source tensor,
1022 /// with the output shape matching the output of `packOp`. Otherwise, returns
1023 /// the source directly.
1024 ///
1025 /// This method assumes that all outer dims for this pack Op are 1.
1026 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1027                                            tensor::PackOp packOp) {
1028   Value input = packOp.getSource();
1029   if (!packOp.getPaddingValue()) {
1030     return input;
1031   }
1032 
1033   assert(llvm::all_of(packOp.getAllOuterDims(),
1034                       [](int64_t val) { return val == 1; }) &&
1035          "some outer dims are != 1");
1036 
1037   Location loc = packOp.getLoc();
1038   ShapedType inputType = packOp.getSourceType();
1039   int64_t inputRank = inputType.getRank();
1040 
1041   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1042       packOp.getDimAndTileMapping();
1043 
1044   // The sizes of dynamic tiles.
1045   SmallVector<Value> dynamicTileSizes;
1046 
1047   // Collect dims for the padded shape.
1048   SmallVector<int64_t> paddedShape;
1049   for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1050     // 1. Non-tiled outer dims.
1051     // These dims should be 1 and we simply preserve them.
1052     if (!tileAndPosMapping.count(dimIdx)) {
1053       int64_t inputDimSize = inputType.getDimSize(dimIdx);
1054       assert(inputDimSize == 1 &&
1055              "with all outer dims == 1, this non-tiled input dim should be 1!");
1056       paddedShape.push_back(inputDimSize);
1057       continue;
1058     }
1059 
1060     // 2. Tiled outer dims
1061     // As all outer dims == 1, it is safe to use the tile size for the padded
1062     // shape.
1063     OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1064 
1065     // 2.1 Static tile sizes
1066     std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1067     if (cstTileSize.has_value()) {
1068       paddedShape.push_back(cstTileSize.value());
1069       continue;
1070     }
1071 
1072     // 2.2 Dynamic tile sizes
1073     paddedShape.push_back(ShapedType::kDynamic);
1074 
1075     // Get the value that holds the dynamic size.
1076     dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1077   }
1078   auto resultType =
1079       RankedTensorType::get(paddedShape, inputType.getElementType());
1080   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1081                                  /*nofold=*/false, loc, builder,
1082                                  dynamicTileSizes);
1083 }
1084 
1085 // Normalizes a permutation on a higher rank space to its actual size, e.g.
1086 //   perm = [1, 4, 2]
1087 // becomes
1088 //   norm = [0, 2, 1]
1089 static SmallVector<int64_t>
1090 getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1091   constexpr int64_t kNonTiledMarker = -1;
1092   SmallVector<int64_t> vec(rank, kNonTiledMarker);
1093   for (auto [index, value] : llvm::enumerate(perm))
1094     vec[value] = index;
1095   SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1096       vec, [&](int64_t v) { return v != kNonTiledMarker; });
1097   // This inverts the permutation in addition to normalizing, so invert back.
1098   return invertPermutationVector(normalizedPerm);
1099 }
1100 
1101 // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1102 // assuming rank reduction of unit outer dims.
1103 static SmallVector<int64_t>
1104 getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1105                              ArrayRef<int64_t> innerDimsPos,
1106                              ArrayRef<int64_t> outerDimsPerm) {
1107   SmallVector<int64_t> rankReducedOuterDimsPerm;
1108   SmallVector<int64_t> outerDims;
1109   SmallVector<int64_t> innerDims;
1110   int64_t dim = 0;
1111   int64_t unpackedRank = shape.size();
1112   for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1113     if (llvm::is_contained(innerDimsPos, i)) {
1114       innerDims.push_back(dim++);
1115       continue;
1116     }
1117     if (shape[i] == 1)
1118       continue;
1119     outerDims.push_back(dim++);
1120     if (!outerDimsPerm.empty())
1121       rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1122   }
1123 
1124   // Get the position of the inner dims after permutation.
1125   SmallVector<int64_t> innerPerm =
1126       getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1127   applyPermutationToVector<int64_t>(innerDims, innerPerm);
1128 
1129   // Ditto for the outer dims.
1130   SmallVector<int64_t> perm = outerDims;
1131 
1132   rankReducedOuterDimsPerm =
1133       getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1134   if (!rankReducedOuterDimsPerm.empty())
1135     applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1136 
1137   // The tile always ends up as the inner most dims after packing.
1138   perm.append(innerDims);
1139 
1140   return perm;
1141 }
1142 
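/// Decompose a tensor.pack whose outer dims are all 1s into (an optional
/// tensor.pad of the source) + linalg.transpose + tensor.insert_slice into
/// the destination.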
1143 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1144     tensor::PackOp packOp, PatternRewriter &rewriter) const {
1145   // TODO: support the case that outer dimensions are not all 1s. A
1146   // tensor.expand_shape will be generated in this case.
1147   if (llvm::any_of(packOp.getAllOuterDims(),
1148                    [](int64_t dim) { return dim != 1; })) {
1149     return rewriter.notifyMatchFailure(
1150         packOp, "not all outer dimensions of the result are 1s");
1151   }
1152 
1153   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1154   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1155   Location loc = packOp.getLoc();
1156 
1157   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1158   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1159       packOp.getDimAndTileMapping();
1160   int64_t srcRank = packOp.getSourceRank();
1161   int64_t destRank = packOp.getDestRank();
1162   int64_t numTiles = destRank - srcRank;
1163 
1164   if (!llvm::all_of(packOp.getInnerDimsPos(),
1165                     [&srcRank, &numTiles](int64_t dimPos) {
1166                       return dimPos >= (srcRank - numTiles - 1);
1167                     }))
1168     return rewriter.notifyMatchFailure(
1169         packOp, "Attempting to tile non-trailing source dims!");
1170 
1171   // 1. Extract the inner tile sizes.
1172   // Where possible, values are replaced with constant attributes (to match the
1173   // behaviour of `getPackOpSourceOrPaddedSource`).
1174   SmallVector<OpFoldResult> tileSizes;
1175   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1176     if (dimAndTileMapping.count(i)) {
1177       // Rather than taking the tile size as is, extract the actual constant
1178       // value Attribute where possible, e.g.:
1179       //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1180       auto [_, tileSize] =
1181           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1182       tileSizes.push_back(tileSize);
1183     }
1184   }
1185 
1186   // 2. Transpose the input to match the inner tile order:
1187   //    %init = tensor.empty()
1188   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1189   //                                        outs(%init)
1190   // Two assumptions are made:
1191   //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
1192   //  2. Inner dims position correspond to the trailing `numTiles` dims.
1193   SmallVector<int64_t> tilesPermNormalized =
1194       getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
1195   SmallVector<int64_t> srcPermForTranspose;
1196   for (int64_t i = 0; i < (srcRank - numTiles); i++)
1197     srcPermForTranspose.push_back(i);
1198 
1199   srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1200 
1201   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
1202              llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
1203              DBGSNL(););
1204 
1205   // 2.1 Create tensor.empty (init value for TransposeOp)
1206   SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1207                                                  oneIdxAttr);
1208   transShapeForEmptyOp.append(tileSizes);
1209 
1210   applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1211                                          srcPermForTranspose);
1212   Value empty = rewriter.create<tensor::EmptyOp>(
1213       loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1214 
1215   // 2.2 Create linalg.transpose
1216   auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
1217                                                            srcPermForTranspose);
1218 
1219   // 3. Insert the inner tile into the destination:
1220   //  %inserted_tile = tensor.insert_slice(%transposed_tile)
1221   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1222   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1223   // Outer dims are all 1s!
1224   SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1225                                        oneIdxAttr);
1226   SmallVector<int64_t> writeShape;
1227 
1228   for (auto tileSize : packOp.getMixedTiles()) {
1229     auto [tileSizeStatic, tileSizeOfr] =
1230         getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1231     writeSizes.push_back(tileSizeOfr);
1232     writeShape.push_back(tileSizeStatic);
1233   }
1234 
1235   // 4. Replace tensor.pack with the tensor.insert_slice created above.
1236   auto insert = rewriter.create<tensor::InsertSliceOp>(
1237       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1238       writeSizes, writeStrides);
1239   rewriter.replaceOp(packOp, insert.getResult());
1240 
1241   return success();
1242 }
1243 
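/// Rough before/after sketch of the decomposition implemented below. The IR is
/// illustrative only (hand-written shapes and SSA names, not a verbatim test
/// case). A tensor.unpack whose tiled outer dims are all unit, e.g.:
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 2]
///       into %dest : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
///
/// is rewritten into (roughly):
///
///   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
///       : tensor<1x1x8x2xf32> to tensor<8x2xf32>
///   %empty = tensor.empty() : tensor<8x2xf32>
///   %tr = linalg.transpose ins(%tile : tensor<8x2xf32>)
///                          outs(%empty : tensor<8x2xf32>) permutation = [0, 1]
///   %partial = tensor.extract_slice %tr[0, 0] [5, 1] [1, 1]
///       : tensor<8x2xf32> to tensor<5x1xf32>
///   %res = tensor.insert_slice %partial into %dest[0, 0] [5, 1] [1, 1]
///       : tensor<5x1xf32> into tensor<5x1xf32>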
1244 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1245     tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1246   int64_t srcRank = unpackOp.getSourceRank();
1247   int64_t destRank = unpackOp.getDestRank();
1248   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1249   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1250   if (llvm::any_of(unpackOp.getTiledOuterDims(),
1251                    [](int64_t dim) { return dim != 1; })) {
1252     return rewriter.notifyMatchFailure(
1253         unpackOp,
1254         "require that the tiled outer dimensions of the result are all 1s");
1255   }
1256 
1257   // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1258   //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
1259   Location loc = unpackOp.getLoc();
1260   Value source = unpackOp.getSource();
1261   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1262       unpackOp.getDimAndTileMapping();
1263   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1264   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1265 
1266   // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1267   // dims:
1268   //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1269   SmallVector<int64_t> readShapeForExtractSlice;
1270   // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1271   // outer-tiled-dims being all 1), this will be
1272   //    [ outer-untiled-dims, tile-sizes ]
1273   SmallVector<OpFoldResult> extractSliceSizes;
1274   // The offset and strides attributes for ExtractSliceOp.
1275   SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1276   SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1277 
1278   // Shape for EmptyOp that's used as the init value for TransposeOp below.
1279   // This should be:
1280   //    [ outer-untiled-dims, tile-sizes ]
1281   // However, skip unit dims - TransposeOp (below) applies a rank-reduced
1282   // permutation.
1283   SmallVector<OpFoldResult> shapeForEmptyOp;
1284 
1285   for (auto i : llvm::seq<unsigned>(0, destRank)) {
1286     // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1287     //
1288     // As all outer tiled dims are 1, the corresponding slice size to read
1289     // will also be 1. As this will be a rank-reducing "extract slice"
1290     // (i.e. the unit dims will be "collapsed"), there's no need to
1291     // update:
1292     //  * the output shape for ExtractSliceOp, nor
1293     //  * the shape for EmptyOp.
1294     if (dimAndTileMapping.count(i)) {
1295       extractSliceSizes.push_back(oneIdxAttr);
1296       continue;
1297     }
1298 
1299     // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1300     // outer-untiled-dims
1301     if (ShapedType::isDynamic(srcShape[i])) {
1302       OpFoldResult dynamicDim =
1303           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1304       extractSliceSizes.push_back(dynamicDim);
1305       shapeForEmptyOp.push_back(dynamicDim);
1306     } else {
1307       extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1308       if (srcShape[i] != 1)
1309         shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1310     }
1311     // Compute the output shape for ExtractSliceOp - outer-untiled-dims
1312     // (taking rank reduction into account).
1313     if (srcShape[i] != 1) {
1314       readShapeForExtractSlice.push_back(srcShape[i]);
1315     }
1316   }
1317   // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1318   // shape for EmptyOp.
1319   auto mixedTiles = unpackOp.getMixedTiles();
1320   extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1321   shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1322 
1323   // Explicitly create the type for extract_slice op because the inner tile
1324   // size could be 1. We want to represent the whole inner tile in this case.
1325   auto tileShape = srcShape.drop_front(destRank);
1326   // Append the inner tile shape to the permuted and rank-reduced outer shape.
1327   readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1328   Type elemType = unpackOp.getSourceType().getElementType();
1329   auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1330   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1331       loc, readType, unpackOp.getSource(), extractSliceOffsets,
1332       extractSliceSizes, extractSliceStrides);
1333 
1334   // 2. Transpose the tile to match the order of the corresponding outer dims.
1335   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1336       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1337   // Unpack is a transition out of packed space so we invert the permutation.
1338   perm = invertPermutationVector(perm);
1339   applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1340 
1341   Value empty =
1342       rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1343   auto transposedOp =
1344       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1345 
1346   // 3. Handle incomplete tiles if needed. This truncates trailing data
1347   // from the transposed tile.
1348   int numLoops = shapeForEmptyOp.size();
1349   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1350   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1351   SmallVector<OpFoldResult> tileSizes;
1352   ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1353   for (auto i : llvm::seq<unsigned>(0, destRank)) {
1354     if (dimAndTileMapping.count(i) || destShape[i] != 1)
1355       tileSizes.push_back(
1356           tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1357   }
1358 
1359   auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1360       loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1361 
1362   // 4. Insert the result to the destination tensor.
1363   SmallVector<OpFoldResult> writeSizes;
1364   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1365   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1366   for (int i = 0, idx = 0; i < destRank; ++i) {
1367     if (dimAndTileMapping.count(i) || destShape[i] != 1)
1368       writeSizes.push_back(tileSizes[idx++]);
1369     else
1370       writeSizes.push_back(oneIdxAttr);
1371   }
1372   auto insert = rewriter.create<tensor::InsertSliceOp>(
1373       loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1374       writeStrides);
1375   rewriter.replaceOp(unpackOp, insert.getResult());
1376 
1377   return success();
1378 }
1379 
1380 // The following are patterns for downscaling convolution ops with size-1
1381 // window dimensions.
1382 //
1383 // Note that we'd eventually want to write such transformations in a generic
1384 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1385 // and then turning back to named ops. But for now it's fine to have a few
1386 // patterns matching special ops to get started.
1387 
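// For example (sketch only; shapes, SSA names and the elided strides/dilations
// are illustrative, not a verbatim test case), a convolution with a unit H
// window such as:
//
//   %0 = linalg.conv_2d_nhwc_hwcf
//          ins(%input, %filter : tensor<1x1x10x3xf32>, tensor<1x3x3x8xf32>)
//          outs(%init : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32>
//
// is rewritten by DownscaleSizeOneWindowed2DConvolution into (roughly):
//
//   %in  = tensor.extract_slice %input  ... to tensor<1x10x3xf32>
//   %fil = tensor.extract_slice %filter ... to tensor<3x3x8xf32>
//   %out = tensor.extract_slice %init   ... to tensor<1x8x8xf32>
//   %1 = linalg.conv_1d_nwc_wcf
//          ins(%in, %fil : tensor<1x10x3xf32>, tensor<3x3x8xf32>)
//          outs(%out : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
//   %2 = tensor.insert_slice %1 into %init ... into tensor<1x1x8x8xf32>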
1388 template <typename Conv2DOp, typename Conv1DOp>
1389 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1390     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1391   if (convOp.hasPureBufferSemantics())
1392     return failure(); // To be implemented.
1393 
1394   Value input = convOp.getInputs().front();
1395   Value kernel = convOp.getInputs().back();
1396   Value output = convOp.getOutputs().front();
1397 
1398   auto inputType = dyn_cast<RankedTensorType>(input.getType());
1399   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1400   auto outputType = dyn_cast<RankedTensorType>(output.getType());
1401 
1402   auto kernelShape = kernelType.getShape();
1403   auto outputShape = outputType.getShape();
1404 
1405   // Get domain indices based on conv2D layout.
1406   auto [khIndex, kwIndex, ohIndex, owIndex] =
1407       TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1408           convOp)
1409           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1410             return std::make_tuple(0, 1, 1, 2);
1411           })
1412           .Case([&](linalg::Conv2DNchwFchwOp op) {
1413             return std::make_tuple(2, 3, 2, 3);
1414           })
1415           .Case([&](linalg::PoolingNhwcSumOp op) {
1416             return std::make_tuple(0, 1, 1, 2);
1417           })
1418           .Case([&](linalg::PoolingNchwSumOp op) {
1419             return std::make_tuple(0, 1, 2, 3);
1420           })
1421           .Case([&](linalg::PoolingNhwcMaxOp op) {
1422             return std::make_tuple(0, 1, 1, 2);
1423           })
1424           .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1425             return std::make_tuple(0, 1, 1, 2);
1426           })
1427           .Case([&](linalg::PoolingNhwcMinOp op) {
1428             return std::make_tuple(0, 1, 1, 2);
1429           })
1430           .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1431             return std::make_tuple(0, 1, 1, 2);
1432           })
1433           .Case([&](linalg::PoolingNchwMaxOp op) {
1434             return std::make_tuple(0, 1, 2, 3);
1435           })
1436           .Default([&](Operation *op) {
1437             llvm_unreachable("unexpected conv2d/pool2d operation.");
1438             return std::make_tuple(0, 0, 0, 0);
1439           });
1440 
1441   // Only handle the case where at least one of the window dimensions is
1442   // of size 1. Other cases can rely on tiling to reduce to such cases.
1443   int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1444   int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1445   bool removeH = (khSize == 1 && ohSize == 1);
1446   bool removeW = (kwSize == 1 && owSize == 1);
1447   if (!removeH && !removeW)
1448     return failure();
1449 
1450   // Get new shapes and types for all operands by removing the size-1
1451   // dimension.
1452   using RTTBuilder = RankedTensorType::Builder;
1453   RankedTensorType newInputType =
1454       RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1455   RankedTensorType newKernelType =
1456       RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1457   RankedTensorType newOutputType =
1458       RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1459 
1460   // Rank-reduce operands.
1461   Location loc = convOp.getLoc();
1462   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1463       rewriter, loc, input, newInputType);
1464   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1465       rewriter, loc, kernel, newKernelType);
1466   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1467       rewriter, loc, output, newOutputType);
1468 
1469   // Rank-reduce strides and dilations too.
1470   // TODO: dropDim 1-liner helper.
1471   auto strides =
1472       llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1473   strides.erase(strides.begin() + (removeH ? 0 : 1));
1474   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1475 
1476   auto dilations =
1477       llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1478   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1479   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1480 
1481   auto conv1DOp = rewriter.create<Conv1DOp>(
1482       loc, newOutputType, ValueRange{newInput, newKernel},
1483       ValueRange{newOutput}, stridesAttr, dilationsAttr);
1484 
1485   // Insert back.
1486   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1487       rewriter, loc, conv1DOp.getResult(0), output);
1488   rewriter.replaceOp(convOp, inserted);
1489 
1490   return conv1DOp;
1491 }
1492 
1493 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1494                                                               Conv1DNwcWcfOp>;
1495 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1496                                                               Conv1DNcwFcwOp>;
1497 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1498                                                               PoolingNwcSumOp>;
1499 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1500                                                               PoolingNcwSumOp>;
1501 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1502                                                               PoolingNwcMaxOp>;
1503 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1504     PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1505 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1506                                                               PoolingNwcMinOp>;
1507 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1508     PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1509 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1510                                                               PoolingNcwMaxOp>;
1511 
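/// For example (sketch only; shapes and SSA names are illustrative, strides
/// and dilations elided), a depthwise convolution with a unit H window:
///
///   %0 = linalg.depthwise_conv_2d_nhwc_hwc
///          ins(%input, %filter : tensor<1x1x10x3xf32>, tensor<1x3x3xf32>)
///          outs(%init : tensor<1x1x8x3xf32>) -> tensor<1x1x8x3xf32>
///
/// is rewritten into a linalg.depthwise_conv_1d_nwc_wc on the rank-reduced
/// operands tensor<1x10x3xf32>, tensor<3x3xf32> and tensor<1x8x3xf32>, wrapped
/// in the corresponding tensor.extract_slice / tensor.insert_slice ops.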
1512 FailureOr<DepthwiseConv1DNwcWcOp>
1513 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1514     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1515   if (convOp.hasPureBufferSemantics())
1516     return failure(); // To be implemented.
1517 
1518   Value input = convOp.getInputs().front();
1519   Value kernel = convOp.getInputs().back();
1520   Value output = convOp.getOutputs().front();
1521 
1522   auto inputType = dyn_cast<RankedTensorType>(input.getType());
1523   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1524   auto outputType = dyn_cast<RankedTensorType>(output.getType());
1525 
1526   auto kernelShape = kernelType.getShape();
1527   auto outputShape = outputType.getShape();
1528 
1529   // Only handle the case where at least one of the window dimensions is
1530   // of size 1. Other cases can rely on tiling to reduce to such cases.
1531   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1532   int64_t ohSize = outputShape[1], owSize = outputShape[2];
1533   bool removeH = (khSize == 1 && ohSize == 1);
1534   bool removeW = (kwSize == 1 && owSize == 1);
1535   if (!removeH && !removeW)
1536     return failure();
1537 
1538   // Get new shapes and types for all operands by removing the size-1
1539   // dimension.
1540   using RTTBuilder = RankedTensorType::Builder;
1541   RankedTensorType newInputType =
1542       RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1543   RankedTensorType newKernelType =
1544       RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1545   RankedTensorType newOutputType =
1546       RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1547 
1548   // Rank-reduce operands.
1549   Location loc = convOp.getLoc();
1550   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1551       rewriter, loc, input, newInputType);
1552   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1553       rewriter, loc, kernel, newKernelType);
1554   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1555       rewriter, loc, output, newOutputType);
1556 
1557   // Rank-reduce strides and dilations too.
1558   // TODO: dropDim 1-liner helper.
1559   auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1560   strides.erase(strides.begin() + (removeH ? 0 : 1));
1561   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1562 
1563   auto dilations =
1564       llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1565   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1566   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1567 
1568   auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1569       loc, newOutputType, ValueRange{newInput, newKernel},
1570       ValueRange{newOutput}, stridesAttr, dilationsAttr);
1571 
1572   // Insert back.
1573   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1574       rewriter, loc, conv1DOp.getResult(0), output);
1575   rewriter.replaceOp(convOp, inserted);
1576 
1577   return conv1DOp;
1578 }
1579 
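/// For example (sketch only; shapes and SSA names are illustrative), a
/// linalg.conv_2d with a unit H window:
///
///   %0 = linalg.conv_2d
///          ins(%input, %filter : tensor<1x10xf32>, tensor<1x3xf32>)
///          outs(%init : tensor<1x8xf32>) -> tensor<1x8xf32>
///
/// is rewritten into a linalg.conv_1d on tensor<10xf32>, tensor<3xf32> and
/// tensor<8xf32>, wrapped in rank-reducing tensor.extract_slice /
/// tensor.insert_slice ops.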
1580 FailureOr<Conv1DOp>
1581 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1582                                             PatternRewriter &rewriter) const {
1583   if (convOp.hasPureBufferSemantics())
1584     return failure(); // To be implemented.
1585 
1586   Value input = convOp.getInputs().front();
1587   Value kernel = convOp.getInputs().back();
1588   Value output = convOp.getOutputs().front();
1589 
1590   auto inputType = dyn_cast<RankedTensorType>(input.getType());
1591   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1592   auto outputType = dyn_cast<RankedTensorType>(output.getType());
1593 
1594   auto kernelShape = kernelType.getShape();
1595   auto outputShape = outputType.getShape();
1596 
1597   // Only handle the case where at least one of the window dimensions is
1598   // of size 1. Other cases can rely on tiling to reduce to such cases.
1599   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1600   int64_t ohSize = outputShape[0], owSize = outputShape[1];
1601   bool removeH = (khSize == 1 && ohSize == 1);
1602   bool removeW = (kwSize == 1 && owSize == 1);
1603   if (!removeH && !removeW)
1604     return failure();
1605 
1606   // Get new shapes and types for all operands by removing the size-1
1607   // dimension.
1608   using RTTBuilder = RankedTensorType::Builder;
1609   RankedTensorType newInputType =
1610       RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1611   RankedTensorType newKernelType =
1612       RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1613   RankedTensorType newOutputType =
1614       RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1615 
1616   // Rank-reduce operands.
1617   Location loc = convOp.getLoc();
1618   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1619       rewriter, loc, input, newInputType);
1620   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1621       rewriter, loc, kernel, newKernelType);
1622   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1623       rewriter, loc, output, newOutputType);
1624 
1625   auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1626                                             ValueRange{newInput, newKernel},
1627                                             ValueRange{newOutput});
1628 
1629   // Insert back.
1630   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1631       rewriter, loc, conv1DOp.getResult(0), output);
1632   rewriter.replaceOp(convOp, inserted);
1633 
1634   return conv1DOp;
1635 }
1636 
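// Typical way to drive the populate* entry points below (a minimal sketch;
// the surrounding pass boilerplate and the `ctx`/`op` variables are
// hypothetical and not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
//   linalg::populateDecomposePackUnpackPatterns(patterns);
//   linalg::populateDecomposePadPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));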
1637 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1638                                                   PatternBenefit benefit) {
1639   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1640                                                      Conv1DNwcWcfOp>,
1641                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1642                                                      Conv1DNcwFcwOp>,
1643                DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1644       patterns.getContext(), benefit);
1645   patterns.add<
1646       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1647       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1648       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1649       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1650                                             PoolingNwcMaxUnsignedOp>,
1651       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1652       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1653                                             PoolingNwcMinUnsignedOp>,
1654       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1655       patterns.getContext(), benefit);
1656 }
1657 
1658 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1659   patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
1660   patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
1661 }
1662 
1663 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1664   patterns.add<DecomposePadOpPattern>(patterns.getContext());
1665 }
1666