//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map packing metadata onto
// the iteration domain of Linalg ops.
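// For example (an illustration derived from the doc example on
// getOrCreatePackedViewOfOperand below): a pack with inner_dims_pos = [0, 1]
// and inner_tiles = [8, 2] applied to a result indexed by (d0, d1) -> (d0, d1)
// of a 2-loop generic yields tiledDimsPos = [0, 1],
// domainDimAndTileMapping = {0 -> 8, 1 -> 2}, and
// tileToPointMapping = {0 -> 2, 1 -> 3}.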
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // Positions of the tiled dims in the iteration domain, following the order
  // of inner_dims_pos on the pack/unpack op.
  SmallVector<int64_t> tiledDimsPos;
  // Mapping from a tiled dim of the iteration domain to its tile size.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // Mapping from a dimension of the iteration domain to the corresponding
  // inner tiling (point) dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in a
  // numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
  // perm               : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  // currentPositionTileLoops: dim | pos
  //                           d2  | 0
  //                           d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm`
  // to be used, here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, with
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or none of the packing dimensions are relevant
/// to the operand, the original operand is returned together with the updated
/// indexing map. Otherwise, the packed operand and the updated indexing map
/// are returned. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///      outs(%init : tensor<?x?xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///      %4 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %4 : f32
///  } -> tensor<?x?xf32>
///  %1 = tensor.pack %0
///    inner_dims_pos = [0, 1]
///    inner_tiles = [8, 2]
///    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
///  Taking the first input operand as an example, the inner tile size of d0 is
///  8. Thus, the operation below and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
///  will be returned.
///
///  %pack = tensor.pack %arg0
///    inner_dims_pos = [0]
///    inner_tiles = [8]
///    into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it. The new generic op reuses the original
/// body; packed views are created for the input operands and `dest` is used
/// as the packed init operand.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) into generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                          iterator_types = ["parallel", "parallel"]}
///         ins(%arg0 : tensor<?x?xf32>)
///         outs(%2 : tensor<?x?xf32>) {
///       ^bb0(%arg3: f32, %arg4: f32):
///         %4 = arith.addf %arg3, %arg3 : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?xf32>
///     %4 = tensor.pack %3
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %0 = affine.apply #map()[%dim]
///     %1 = affine.apply #map1()[%dim_0]
///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///     %pack = tensor.pack %arg0
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///     } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  // create a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  // if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// add as many zero padding dimensions to `high` and `low` as there are point
/// loops.
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(padOp))
      return failure();

    if (!padOp.getResult().hasOneUse())
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
        outerDimsPerm);
    Value packedSource = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have `outer_dims_perm` we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
        padOp.getNofold());
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]. For pos 0, the inner-most
/// projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit projected
/// dim in [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
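/// For example (illustrative values, not from a specific caller): with
/// dimsPos [1, 2], shape [?, 16, 4], and tileSizes [8, 1], this returns true
/// (16 % 8 == 0 and 4 % 1 == 0); any dynamic dim at a position in dimsPos
/// makes it return false.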
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in sequential order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
///     : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can be
  // divided by the inner tile sizes. This is correct because for [..., x, 1],
  // packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supported when the pack op is the only user.
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only static inner tile sizes are supported.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(srcOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///     : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                   tensor::ExpandShapeOp expandOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
  // on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can be
  // divided by the inner tile sizes. This is correct because for [..., x, 1],
  // unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for the
  // new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only unpack ops with a single user are supported.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only static inner tile sizes are supported.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    // User controlled propagation function.
    if (!controlFn(consumerOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for the
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                          inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///      outs(%1 : tensor<12x56x56x64xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

/// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// add as many zero padding dimensions to `high` and `low` as there are point
/// loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm` we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
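
// A minimal usage sketch (an assumption for illustration, not part of this
// file): a pass or test that wants these patterns would populate a
// RewritePatternSet and apply it greedily. The always-true control function
// and the greedy-driver call shown here are assumptions, not code from this
// file.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/[](Operation *op) {
//         return true;  // Propagate through every candidate op.
//       });
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));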