xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (revision 165f45354ae51bd00fe9000afbdcc4405e360b02)
//===- DataLayoutPropagation.cpp -----------------------------------------===//
20f297cadSHanhan Wang //
30f297cadSHanhan Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40f297cadSHanhan Wang // See https://llvm.org/LICENSE.txt for license information.
50f297cadSHanhan Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60f297cadSHanhan Wang //
70f297cadSHanhan Wang //===----------------------------------------------------------------------===//
80f297cadSHanhan Wang 
90f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Passes.h"
100f297cadSHanhan Wang 
110f297cadSHanhan Wang #include "mlir/Dialect/Affine/IR/AffineOps.h"
120f297cadSHanhan Wang #include "mlir/Dialect/Linalg/IR/Linalg.h"
130f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
140f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Utils/Utils.h"
150f297cadSHanhan Wang #include "mlir/Dialect/Tensor/IR/Tensor.h"
160f297cadSHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h"
170f297cadSHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h"
181c228026SLorenzo Chelini #include "mlir/IR/Dominance.h"
190f297cadSHanhan Wang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20a945f55dSAdam Siemieniuk #include "llvm/ADT/SetOperations.h"
21a945f55dSAdam Siemieniuk #include "llvm/ADT/SetVector.h"
220c1c0d53SJerry Wu #include "llvm/ADT/TypeSwitch.h"
23d38d6065SHanhan Wang #include "llvm/Support/Debug.h"
24a1fe1f5fSKazu Hirata #include <optional>
250f297cadSHanhan Wang 
260f297cadSHanhan Wang namespace mlir {
270f297cadSHanhan Wang #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
280f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Passes.h.inc"
290f297cadSHanhan Wang } // namespace mlir
300f297cadSHanhan Wang 
310f297cadSHanhan Wang using namespace mlir;
320f297cadSHanhan Wang using namespace mlir::linalg;
330f297cadSHanhan Wang 
340f297cadSHanhan Wang #define DEBUG_TYPE "linalg-data-layout-propagation"
350f297cadSHanhan Wang 
360f297cadSHanhan Wang namespace {
370f297cadSHanhan Wang 
38b4563ee1SQuinn Dawkins static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39b4563ee1SQuinn Dawkins   for (Operation &op : genericOp.getBody()->getOperations())
40b4563ee1SQuinn Dawkins     if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41b4563ee1SQuinn Dawkins       return true;
42b4563ee1SQuinn Dawkins   return false;
43b4563ee1SQuinn Dawkins }
44b4563ee1SQuinn Dawkins 
// The struct contains the information about mapping packing information to
// the iteration domain of Linalg ops.
struct PackInfo {
  // Number of inner (point) loops appended to the iteration domain; one per
  // entry in `tileToPointMapping`.
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of iteration domain to the corresponding inner
  // tiling dimension on iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
59d38d6065SHanhan Wang 
/// Builds a PackInfo describing how the pack/unpack op attached to `opOperand`
/// maps onto the iteration domain of `genericOp`. Returns failure when the
/// packing cannot be expressed on the domain: a tiled result expression is not
/// a plain AffineDimExpr, a tiled loop is not parallel, a tiled dim appears in
/// some indexing map as a non-dim expression, or the outer permutation would
/// move a non-dim expression.
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  // For each inner tile, resolve the iteration-domain dim it tiles. The new
  // point loop for tile `index` is appended after the original loops, i.e. at
  // position `origNumDims + index`.
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    // Only parallel loops can be tiled by the pack this way.
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in a
  // numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    // Fill the identity positions for unaffected dims and splice in the
    // permuted domain dims, consuming `permutedOuterDims` in order.
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outer dimsDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}
161d38d6065SHanhan Wang 
162d7904a70SLorenzo Chelini static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163d7904a70SLorenzo Chelini                                              ArrayRef<AffineExpr> exprs) {
164d7904a70SLorenzo Chelini   // Compute `outer_dims_perm`. See example:
165d7904a70SLorenzo Chelini   // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
166d7904a70SLorenzo Chelini   // perm               : [0, 3, 1, 2]
167d7904a70SLorenzo Chelini   // First map d2, d3 with their position in the array as:
168d7904a70SLorenzo Chelini   // currentPositionTileLoops: dim | pos
169d7904a70SLorenzo Chelini   //                           d2  | 0
170d7904a70SLorenzo Chelini   //                           d3  | 1
171d7904a70SLorenzo Chelini   // then scan `perm` in order and get the `outer_dims_perm`
172d7904a70SLorenzo Chelini   // to be used, here it would be [1, 0].
173d7904a70SLorenzo Chelini   assert(!perm.empty() && "expect perm not to be empty");
174d7904a70SLorenzo Chelini   assert(!exprs.empty() && "expect exprs not to be empty");
175d7904a70SLorenzo Chelini   if (exprs.size() == 1)
176d7904a70SLorenzo Chelini     return {};
177d7904a70SLorenzo Chelini   SmallVector<int64_t> outerDimsPerm;
178d7904a70SLorenzo Chelini   DenseMap<int64_t, int64_t> currentPositionTileLoops;
179d7904a70SLorenzo Chelini   for (auto [pos, expr] : llvm::enumerate(exprs)) {
180b4563ee1SQuinn Dawkins     // Here we rely on the assumption that the outer dims permutation
181b4563ee1SQuinn Dawkins     // when propagating currently requires that non-affine dim expressions
182b4563ee1SQuinn Dawkins     // are not permuted, thus allowing the identity assignment below.
1831609f1c2Slong.chen     if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184b4563ee1SQuinn Dawkins       currentPositionTileLoops[dimExpr.getPosition()] = pos;
185b4563ee1SQuinn Dawkins     else
186b4563ee1SQuinn Dawkins       currentPositionTileLoops[pos] = pos;
187d7904a70SLorenzo Chelini   }
188d7904a70SLorenzo Chelini   for (int64_t loopIdx : perm) {
189d7904a70SLorenzo Chelini     if (currentPositionTileLoops.count(loopIdx))
190d7904a70SLorenzo Chelini       outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191d7904a70SLorenzo Chelini   }
192d7904a70SLorenzo Chelini   return outerDimsPerm;
193d7904a70SLorenzo Chelini }
194d7904a70SLorenzo Chelini 
/// Returns a tuple for packed operand and indexing_map with the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///      outs(%init : tensor<?x?xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///      %4 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %4 : f32
///  } -> tensor<?x?xf32>
///  %1 = tensor.pack %0
///    inner_dims_pos = [0, 1]
///    inner_tiles = [8, 2]
///    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
///  Taking the first input operand as an example, the inner tile size of d1 is
///  8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
///  affine_map<(d1, d3)>` will be returned.
///
///  %pack = tensor.pack %arg0
///    inner_dims_pos = [0]
///    inner_tiles = [8]
///    into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  // The packed domain has one extra (point) loop per tiled dim, appended
  // after the original loops.
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  // Only tiled dims actually referenced by this operand get an inner tile;
  // the operand's indexing map grows one point-dim expr per such tile.
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relates to pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  // Materialize the pack of this operand into a fresh destination tensor.
  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
3000f297cadSHanhan Wang 
3019f242404SLorenzo Chelini /// Pack a genericOp and return it.
3029f242404SLorenzo Chelini static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
3036bb0ab0dSLorenzo Chelini                                Value dest, AffineMap packedOutIndexingMap,
3046bb0ab0dSLorenzo Chelini                                const PackInfo &packInfo) {
3056bb0ab0dSLorenzo Chelini   Location loc = genericOp.getLoc();
3066bb0ab0dSLorenzo Chelini   SmallVector<Value> inputOperands;
3076bb0ab0dSLorenzo Chelini   SmallVector<AffineMap> indexingMaps;
3086bb0ab0dSLorenzo Chelini   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
3096bb0ab0dSLorenzo Chelini     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
3106bb0ab0dSLorenzo Chelini         rewriter, loc, packInfo, genericOp, inputOperand);
3116bb0ab0dSLorenzo Chelini     inputOperands.push_back(packedOperand);
3126bb0ab0dSLorenzo Chelini     indexingMaps.push_back(packedIndexingMap);
3136bb0ab0dSLorenzo Chelini   }
3146bb0ab0dSLorenzo Chelini 
3156bb0ab0dSLorenzo Chelini   int64_t numInnerLoops = packInfo.getNumTiledLoops();
3166bb0ab0dSLorenzo Chelini   SmallVector<utils::IteratorType> iterTypes =
3176bb0ab0dSLorenzo Chelini       genericOp.getIteratorTypesArray();
3186bb0ab0dSLorenzo Chelini   iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
3196bb0ab0dSLorenzo Chelini 
3206bb0ab0dSLorenzo Chelini   indexingMaps.push_back(packedOutIndexingMap);
3216bb0ab0dSLorenzo Chelini 
3226bb0ab0dSLorenzo Chelini   auto newGenericOp = rewriter.create<linalg::GenericOp>(
3236bb0ab0dSLorenzo Chelini       loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
3246bb0ab0dSLorenzo Chelini       /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
3256bb0ab0dSLorenzo Chelini   rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
3266bb0ab0dSLorenzo Chelini                              newGenericOp.getRegion().begin());
3276bb0ab0dSLorenzo Chelini   return newGenericOp;
3286bb0ab0dSLorenzo Chelini }
3296bb0ab0dSLorenzo Chelini 
/// Bubbles up tensor.pack op through a producer generic op. This
/// swap pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
///
///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                          iterator_types = ["parallel", "parallel"]}
///         ins(%arg0 : tensor<?x?xf32>)
///         outs(%2 : tensor<?x?xf32>) {
///       ^bb0(%arg3: f32, %arg4: f32):
///         %4 = arith.addf %arg3, %arg3 : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?xf32>
///     %4 = tensor.pack %3
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %0 = affine.apply #map()[%dim]
///     %1 = affine.apply #map1()[%dim_0]
///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///     %pack = tensor.pack %arg0
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///     } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  // Only bubble when the pack's source is produced by a generic.
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail-out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  // create a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  // if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  // Map the pack onto the generic's iteration domain via its init operand.
  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}
4530f297cadSHanhan Wang 
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    // The packed generic now yields the packed result; replace the pack op
    // with it.
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  // User-provided callback deciding whether propagation is applied.
  ControlPropagationFn controlFn;
};
4756bb0ab0dSLorenzo Chelini 
476886294a2SQuinn Dawkins /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
477886294a2SQuinn Dawkins /// add as many zero padding dimensions in `high` and `low` based on the number
478886294a2SQuinn Dawkins /// of point loops.
479886294a2SQuinn Dawkins class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
480886294a2SQuinn Dawkins public:
481886294a2SQuinn Dawkins   BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
482886294a2SQuinn Dawkins       : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
483886294a2SQuinn Dawkins 
484886294a2SQuinn Dawkins   LogicalResult matchAndRewrite(tensor::PackOp packOp,
485886294a2SQuinn Dawkins                                 PatternRewriter &rewriter) const override {
486886294a2SQuinn Dawkins     auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
487886294a2SQuinn Dawkins     if (!padOp)
488886294a2SQuinn Dawkins       return failure();
489886294a2SQuinn Dawkins 
490886294a2SQuinn Dawkins     // User controlled propagation function.
49104fc471fSHan-Chung Wang     if (!controlFn(&packOp.getSourceMutable()))
492886294a2SQuinn Dawkins       return failure();
493886294a2SQuinn Dawkins 
494886294a2SQuinn Dawkins     // TODO: Enable padding when the padding values are the same.
495886294a2SQuinn Dawkins     if (packOp.getPaddingValue())
496886294a2SQuinn Dawkins       return failure();
497886294a2SQuinn Dawkins 
498886294a2SQuinn Dawkins     // Fail for non-constant padding values. The body of the pad could
499886294a2SQuinn Dawkins     // depend on the padding indices and/or properties of the padded
500886294a2SQuinn Dawkins     // tensor so for now we fail.
501886294a2SQuinn Dawkins     // TODO: Support non-constant padding values.
502886294a2SQuinn Dawkins     Value paddingVal = padOp.getConstantPaddingValue();
503886294a2SQuinn Dawkins     if (!paddingVal)
504886294a2SQuinn Dawkins       return failure();
505886294a2SQuinn Dawkins 
506886294a2SQuinn Dawkins     if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
507886294a2SQuinn Dawkins       return failure();
508886294a2SQuinn Dawkins 
509886294a2SQuinn Dawkins     ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
510886294a2SQuinn Dawkins 
511886294a2SQuinn Dawkins     // Bail out if one of the padded dimension is a tiled one.
512886294a2SQuinn Dawkins     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
513886294a2SQuinn Dawkins     llvm::SmallBitVector innerDims(paddedDims.size());
514886294a2SQuinn Dawkins     for (int64_t dim : innerDimsPos)
515886294a2SQuinn Dawkins       innerDims.flip(dim);
516886294a2SQuinn Dawkins     if (paddedDims.anyCommon(innerDims))
517886294a2SQuinn Dawkins       return failure();
518886294a2SQuinn Dawkins 
519886294a2SQuinn Dawkins     Location loc = padOp->getLoc();
520886294a2SQuinn Dawkins     OpBuilder::InsertionGuard guard(rewriter);
521886294a2SQuinn Dawkins     rewriter.setInsertionPoint(padOp);
522886294a2SQuinn Dawkins 
5234ad96785SQuinn Dawkins     ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
5244ad96785SQuinn Dawkins     SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
525886294a2SQuinn Dawkins     auto empty = tensor::PackOp::createDestinationTensor(
5264ad96785SQuinn Dawkins         rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
527886294a2SQuinn Dawkins         outerDimsPerm);
5284ad96785SQuinn Dawkins     auto sourcePack = rewriter.create<tensor::PackOp>(
5294ad96785SQuinn Dawkins         loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
530886294a2SQuinn Dawkins         /*padding=*/std::nullopt, outerDimsPerm);
531886294a2SQuinn Dawkins 
532886294a2SQuinn Dawkins     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
533886294a2SQuinn Dawkins     SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534886294a2SQuinn Dawkins     SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535886294a2SQuinn Dawkins     if (!outerDimsPerm.empty()) {
536886294a2SQuinn Dawkins       applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537886294a2SQuinn Dawkins       applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538886294a2SQuinn Dawkins     }
539886294a2SQuinn Dawkins     // The tiled dimensions were verified to be unpadded above, so here we
540886294a2SQuinn Dawkins     // just append 0 for the inner tile dimensions.
541886294a2SQuinn Dawkins     size_t pointLoopsSize = innerDimsPos.size();
542886294a2SQuinn Dawkins     lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543886294a2SQuinn Dawkins     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544886294a2SQuinn Dawkins 
545886294a2SQuinn Dawkins     auto newPadOp = rewriter.create<tensor::PadOp>(
5464ad96785SQuinn Dawkins         loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
547886294a2SQuinn Dawkins         padOp.getNofold());
5484ad96785SQuinn Dawkins 
5494ad96785SQuinn Dawkins     // If the pad has more than one user, create an unpack on the new pad to
5504ad96785SQuinn Dawkins     // replace the other uses.
5514ad96785SQuinn Dawkins     if (!padOp->hasOneUse()) {
5524ad96785SQuinn Dawkins       auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
5534ad96785SQuinn Dawkins           rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
5544ad96785SQuinn Dawkins       Value unpackedPad = rewriter.create<tensor::UnPackOp>(
5554ad96785SQuinn Dawkins           loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
5564ad96785SQuinn Dawkins       rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
5574ad96785SQuinn Dawkins     }
5584ad96785SQuinn Dawkins 
5594ad96785SQuinn Dawkins     // Replace the pack with the new pad.
560886294a2SQuinn Dawkins     rewriter.replaceOp(packOp, newPadOp.getResult());
5614ad96785SQuinn Dawkins 
562886294a2SQuinn Dawkins     return success();
563886294a2SQuinn Dawkins   }
564886294a2SQuinn Dawkins 
565886294a2SQuinn Dawkins private:
566886294a2SQuinn Dawkins   ControlPropagationFn controlFn;
567886294a2SQuinn Dawkins };
568886294a2SQuinn Dawkins 
5690c1c0d53SJerry Wu /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
5700c1c0d53SJerry Wu ///
5710c1c0d53SJerry Wu /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
5720c1c0d53SJerry Wu /// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
5730c1c0d53SJerry Wu /// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
5740c1c0d53SJerry Wu /// non-unit projected dims in pos [2, 3] is 2.
5750c1c0d53SJerry Wu ///
5760c1c0d53SJerry Wu /// If all candidates in a reassociation are unit dims, it chooses the
5770c1c0d53SJerry Wu /// inner-most dim pos.
5780c1c0d53SJerry Wu static SmallVector<int64_t>
5790c1c0d53SJerry Wu projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
5800c1c0d53SJerry Wu                                  ArrayRef<ReassociationIndices> reassocIndices,
5810c1c0d53SJerry Wu                                  ArrayRef<int64_t> targetShape) {
5820c1c0d53SJerry Wu   SmallVector<int64_t> projectedDimsPos;
5830c1c0d53SJerry Wu   for (auto pos : dimsPos) {
5840c1c0d53SJerry Wu     // In the case all dims are unit, this will return the inner-most one.
5850c1c0d53SJerry Wu     int64_t projectedPos = reassocIndices[pos].back();
5860c1c0d53SJerry Wu     for (auto i : llvm::reverse(reassocIndices[pos])) {
5870c1c0d53SJerry Wu       int64_t dim = targetShape[i];
5880c1c0d53SJerry Wu       if (dim > 1 || ShapedType::isDynamic(dim)) {
5890c1c0d53SJerry Wu         projectedPos = i;
5900c1c0d53SJerry Wu         break;
5910c1c0d53SJerry Wu       }
5920c1c0d53SJerry Wu     }
5930c1c0d53SJerry Wu     projectedDimsPos.push_back(projectedPos);
5940c1c0d53SJerry Wu   }
5950c1c0d53SJerry Wu   return projectedDimsPos;
5960c1c0d53SJerry Wu }
5970c1c0d53SJerry Wu 
5980c1c0d53SJerry Wu /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
5990c1c0d53SJerry Wu static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
6000c1c0d53SJerry Wu                                        ArrayRef<int64_t> shape,
6010c1c0d53SJerry Wu                                        ArrayRef<int64_t> tileSizes) {
6020c1c0d53SJerry Wu   for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
6030c1c0d53SJerry Wu     int64_t dim = shape[pos];
6040c1c0d53SJerry Wu     if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
6050c1c0d53SJerry Wu       return false;
6060c1c0d53SJerry Wu   }
6070c1c0d53SJerry Wu   return true;
6080c1c0d53SJerry Wu }
6090c1c0d53SJerry Wu 
6100c1c0d53SJerry Wu /// Permutate the reassociation indices and reindex them in the sequence order.
6110c1c0d53SJerry Wu /// Returns the next dim pos in the sequence.
6120c1c0d53SJerry Wu ///
6130c1c0d53SJerry Wu /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
6140c1c0d53SJerry Wu /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
6150c1c0d53SJerry Wu /// [[0], [1, 2]].
6160c1c0d53SJerry Wu static int64_t applyPermutationAndReindexReassoc(
6170c1c0d53SJerry Wu     SmallVector<ReassociationIndices> &reassocIndices,
6180c1c0d53SJerry Wu     ArrayRef<int64_t> permutation) {
619002e8192Syifeizh2   if (!permutation.empty())
6200c1c0d53SJerry Wu     applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
6210c1c0d53SJerry Wu   int64_t nextPos = 0;
6220c1c0d53SJerry Wu   for (ReassociationIndices &indices : reassocIndices) {
6230c1c0d53SJerry Wu     for (auto &index : indices) {
6240c1c0d53SJerry Wu       index = nextPos;
6250c1c0d53SJerry Wu       nextPos += 1;
6260c1c0d53SJerry Wu     }
6270c1c0d53SJerry Wu   }
6280c1c0d53SJerry Wu   return nextPos;
6290c1c0d53SJerry Wu }
6300c1c0d53SJerry Wu 
6310c1c0d53SJerry Wu /// Bubble up pack op through collapse shape op when the packed dims can be
6320c1c0d53SJerry Wu /// projected to the dims before collapsing. This is possible when the inner
6330c1c0d53SJerry Wu /// tile sizes can divide the projected dims.
6340c1c0d53SJerry Wu ///
6350c1c0d53SJerry Wu /// For example:
6360c1c0d53SJerry Wu ///
6370c1c0d53SJerry Wu /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
6380c1c0d53SJerry Wu ///     : tensor<?x16x4xf32> into tensor<?x4xf32>
6390c1c0d53SJerry Wu /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
6400c1c0d53SJerry Wu ///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
6410c1c0d53SJerry Wu ///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
6420c1c0d53SJerry Wu ///
6430c1c0d53SJerry Wu /// can be transformed into:
6440c1c0d53SJerry Wu ///
6450c1c0d53SJerry Wu /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
6460c1c0d53SJerry Wu ///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
6470c1c0d53SJerry Wu ///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
6480c1c0d53SJerry Wu /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
6490c1c0d53SJerry Wu ///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
6500c1c0d53SJerry Wu static LogicalResult
6510c1c0d53SJerry Wu bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
6520c1c0d53SJerry Wu                                    tensor::PackOp packOp,
6530c1c0d53SJerry Wu                                    PatternRewriter &rewriter) {
6540c1c0d53SJerry Wu   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
6550c1c0d53SJerry Wu   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
6560c1c0d53SJerry Wu   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
6570c1c0d53SJerry Wu 
6580c1c0d53SJerry Wu   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
6590c1c0d53SJerry Wu   SmallVector<ReassociationIndices> reassocIndices =
6600c1c0d53SJerry Wu       collapseOp.getReassociationIndices();
6610c1c0d53SJerry Wu   // Project inner tile pos to the dim pos before collapsing. For example, if
6620c1c0d53SJerry Wu   // dims [x, y] is collapsed into [z], packing on dim z can be projected back
6630c1c0d53SJerry Wu   // to pack on dim y.
6640c1c0d53SJerry Wu   //
6650c1c0d53SJerry Wu   // Project to inner-most non-unit dims to increase the chance that they can be
6660c1c0d53SJerry Wu   // divided by the inner tile sizes. This is correct because for [..., x, 1],
6670c1c0d53SJerry Wu   // packing on dim 1 is equivalent to packing on dim x.
6680c1c0d53SJerry Wu   SmallVector<int64_t> projectedInnerDimsPos =
6690c1c0d53SJerry Wu       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
6700c1c0d53SJerry Wu 
6710c1c0d53SJerry Wu   if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
6720c1c0d53SJerry Wu                                   innerTileSizes)) {
6730c1c0d53SJerry Wu     return failure();
6740c1c0d53SJerry Wu   }
6750c1c0d53SJerry Wu   // Expand the outer dims permutation with the associated source dims for the
6760c1c0d53SJerry Wu   // new permutation after bubbling. This is because moving a collapsed dim is
6770c1c0d53SJerry Wu   // equivalent to moving the associated source dims together.
6780c1c0d53SJerry Wu   SmallVector<int64_t> newOuterDimsPerm;
6790c1c0d53SJerry Wu   for (auto outerPos : outerDimsPerm) {
6800c1c0d53SJerry Wu     newOuterDimsPerm.insert(newOuterDimsPerm.end(),
6810c1c0d53SJerry Wu                             reassocIndices[outerPos].begin(),
6820c1c0d53SJerry Wu                             reassocIndices[outerPos].end());
6830c1c0d53SJerry Wu   }
6840c1c0d53SJerry Wu 
6850c1c0d53SJerry Wu   auto emptyOp = tensor::PackOp::createDestinationTensor(
6860c1c0d53SJerry Wu       rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
6870c1c0d53SJerry Wu       projectedInnerDimsPos, newOuterDimsPerm);
6880c1c0d53SJerry Wu   auto newPackOp = rewriter.create<tensor::PackOp>(
6890c1c0d53SJerry Wu       packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
6900c1c0d53SJerry Wu       packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
6910c1c0d53SJerry Wu 
6920c1c0d53SJerry Wu   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
6930c1c0d53SJerry Wu   // First apply the permutation on the reassociations of the outer dims.
6940c1c0d53SJerry Wu   // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
6950c1c0d53SJerry Wu   // -> [[0], [1, 2]]
6960c1c0d53SJerry Wu   int64_t nextPos =
6970c1c0d53SJerry Wu       applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
6980c1c0d53SJerry Wu   // Then add direct mapping for the inner tile dims.
6990c1c0d53SJerry Wu   for (size_t i = 0; i < innerDimsPos.size(); ++i) {
7000c1c0d53SJerry Wu     newReassocIndices.push_back({nextPos});
7010c1c0d53SJerry Wu     nextPos += 1;
7020c1c0d53SJerry Wu   }
7030c1c0d53SJerry Wu 
7040c1c0d53SJerry Wu   auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
7050c1c0d53SJerry Wu       collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
7060c1c0d53SJerry Wu   rewriter.replaceOp(packOp, newCollapseOp);
7070c1c0d53SJerry Wu 
7080c1c0d53SJerry Wu   return success();
7090c1c0d53SJerry Wu }
7100c1c0d53SJerry Wu 
711a945f55dSAdam Siemieniuk /// Project dimsPos to their collapsed positions in the reassocIndices.
712a945f55dSAdam Siemieniuk ///
713a945f55dSAdam Siemieniuk /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
714a945f55dSAdam Siemieniuk /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
715a945f55dSAdam Siemieniuk /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
716a945f55dSAdam Siemieniuk /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
717a945f55dSAdam Siemieniuk static SmallVector<int64_t>
718a945f55dSAdam Siemieniuk projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
719a945f55dSAdam Siemieniuk                              ArrayRef<ReassociationIndices> reassocIndices) {
720a945f55dSAdam Siemieniuk   SmallVector<int64_t> projectedPos;
721a945f55dSAdam Siemieniuk 
722a945f55dSAdam Siemieniuk   // Map each dimension to the position of corresponding reassociation index.
723a945f55dSAdam Siemieniuk   for (auto pos : dimsPos) {
724a945f55dSAdam Siemieniuk     for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
725a945f55dSAdam Siemieniuk       // If the dimension is present in the current indices group, the group
726a945f55dSAdam Siemieniuk       // position within the reassociation map is the desired projected
727a945f55dSAdam Siemieniuk       // dimension position.
728*165f4535SKazu Hirata       if (llvm::is_contained(indices, pos)) {
729a945f55dSAdam Siemieniuk         projectedPos.push_back(idx);
730a945f55dSAdam Siemieniuk         break;
731a945f55dSAdam Siemieniuk       }
732a945f55dSAdam Siemieniuk     }
733a945f55dSAdam Siemieniuk   }
734a945f55dSAdam Siemieniuk   assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
735a945f55dSAdam Siemieniuk 
736a945f55dSAdam Siemieniuk   return projectedPos;
737a945f55dSAdam Siemieniuk }
738a945f55dSAdam Siemieniuk 
739a945f55dSAdam Siemieniuk /// Bubble up pack op through expand shape op.
740a945f55dSAdam Siemieniuk ///
741a945f55dSAdam Siemieniuk /// For example:
742a945f55dSAdam Siemieniuk ///
743a945f55dSAdam Siemieniuk /// %expand = tensor.expand_shape %in [[0], [1, 2]]
744a945f55dSAdam Siemieniuk ///     : tensor<?x64xf32> into tensor<?x4x16xf32>
745a945f55dSAdam Siemieniuk /// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
746a945f55dSAdam Siemieniuk ///     inner_dims_pos = [2] inner_tiles = [8] into %empty
747a945f55dSAdam Siemieniuk ///     : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
748a945f55dSAdam Siemieniuk ///
749a945f55dSAdam Siemieniuk /// can be transformed into:
750a945f55dSAdam Siemieniuk ///
751a945f55dSAdam Siemieniuk /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
752a945f55dSAdam Siemieniuk ///     inner_dims_pos = [1] inner_tiles = [8] into %empty
753a945f55dSAdam Siemieniuk ///     : tensor<?x64xf32> -> tensor<?x8x8xf32>
754a945f55dSAdam Siemieniuk /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
755a945f55dSAdam Siemieniuk ///     : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
756a945f55dSAdam Siemieniuk static LogicalResult
757a945f55dSAdam Siemieniuk bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
758a945f55dSAdam Siemieniuk                                  tensor::PackOp packOp,
759a945f55dSAdam Siemieniuk                                  PatternRewriter &rewriter) {
760a945f55dSAdam Siemieniuk   // Outer dimensions permutation is not supported currently.
761a945f55dSAdam Siemieniuk   // TODO: Handle outer_dims_perm variants.
762a945f55dSAdam Siemieniuk   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
763a945f55dSAdam Siemieniuk   if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
764a945f55dSAdam Siemieniuk     return rewriter.notifyMatchFailure(packOp,
765a945f55dSAdam Siemieniuk                                        "non-identity outer dims perm NYI");
766a945f55dSAdam Siemieniuk   }
767a945f55dSAdam Siemieniuk 
768a945f55dSAdam Siemieniuk   // Validate dimensions' relations between shape expansion and packing.
769a945f55dSAdam Siemieniuk   SmallVector<ReassociationIndices, 4> reassoc =
770a945f55dSAdam Siemieniuk       expandOp.getReassociationIndices();
771a945f55dSAdam Siemieniuk   ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
772a945f55dSAdam Siemieniuk   llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
773a945f55dSAdam Siemieniuk                                        packInnerDims.end());
774a945f55dSAdam Siemieniuk 
775a945f55dSAdam Siemieniuk   for (auto [idx, indices] : llvm::enumerate(reassoc)) {
776a945f55dSAdam Siemieniuk     // For each expand_shape reassociation, figure out which dimensions get
777a945f55dSAdam Siemieniuk     // packed if any.
778a945f55dSAdam Siemieniuk     llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
779a945f55dSAdam Siemieniuk     llvm::SetVector<int64_t> packedDims =
780a945f55dSAdam Siemieniuk         llvm::set_intersection(packDimsPos, expandDimPos);
781a945f55dSAdam Siemieniuk 
782a945f55dSAdam Siemieniuk     // The expanded dimension is not packed so, it does not affect moving pack
783a945f55dSAdam Siemieniuk     // before shape expansion - simply continue.
784a945f55dSAdam Siemieniuk     if (packedDims.empty())
785a945f55dSAdam Siemieniuk       continue;
786a945f55dSAdam Siemieniuk     // Shape expansion cannot be propagated when multiple expanded dimension are
787a945f55dSAdam Siemieniuk     // packed - in this case operation reordering would affect final element
788a945f55dSAdam Siemieniuk     // positions and/or shapes can no longer be projected.
789a945f55dSAdam Siemieniuk     if (packedDims.size() != 1)
790a945f55dSAdam Siemieniuk       return rewriter.notifyMatchFailure(
791a945f55dSAdam Siemieniuk           packOp, "only one of the expanded dimensions can be packed");
792a945f55dSAdam Siemieniuk     // Only the inner-most expanded dimension should be packed. Otherwise,
793a945f55dSAdam Siemieniuk     // elements order will be affected after operation reordering.
794a945f55dSAdam Siemieniuk     if (packedDims.front() != indices.back())
795a945f55dSAdam Siemieniuk       return rewriter.notifyMatchFailure(
796a945f55dSAdam Siemieniuk           packOp, "can only pack the inner-most expanded dimension");
797a945f55dSAdam Siemieniuk   }
798a945f55dSAdam Siemieniuk 
799a945f55dSAdam Siemieniuk   // Project pack.inner_dims_pos to positions before shape expansion.
800a945f55dSAdam Siemieniuk   SmallVector<int64_t> projectedInnerDimsPos =
801a945f55dSAdam Siemieniuk       projectDimsPosIntoReassocPos(packInnerDims, reassoc);
802a945f55dSAdam Siemieniuk 
803a945f55dSAdam Siemieniuk   // Project the shape expansion to new packed shape.
804a945f55dSAdam Siemieniuk   // The pack.outer_dims_perm is restricted to identity so, the permutation can
805a945f55dSAdam Siemieniuk   // be omitted for simplicity.
806a945f55dSAdam Siemieniuk   // TODO: Account for outer dimensions permutation.
807a945f55dSAdam Siemieniuk   //
808a945f55dSAdam Siemieniuk   // If reassociation is not possible, then reordering cannot happen.
809a945f55dSAdam Siemieniuk   // This can be caused by pack padding affecting previously expanded
810a945f55dSAdam Siemieniuk   // dimensions or packing extending dimensions.
811a945f55dSAdam Siemieniuk   RankedTensorType newPackType = tensor::PackOp::inferPackedType(
812a945f55dSAdam Siemieniuk       expandOp.getSrcType(), packOp.getStaticInnerTiles(),
813a945f55dSAdam Siemieniuk       projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
814a945f55dSAdam Siemieniuk   auto reassocExpand =
815a945f55dSAdam Siemieniuk       getReassociationIndicesForReshape(newPackType, packOp.getDestType());
816a945f55dSAdam Siemieniuk   if (!reassocExpand)
817a945f55dSAdam Siemieniuk     return rewriter.notifyMatchFailure(
818a945f55dSAdam Siemieniuk         packOp, "could not reassociate dims after bubbling up");
819a945f55dSAdam Siemieniuk 
820a945f55dSAdam Siemieniuk   Value destTensor = tensor::PackOp::createDestinationTensor(
821a945f55dSAdam Siemieniuk       rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
822a945f55dSAdam Siemieniuk       projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
823a945f55dSAdam Siemieniuk   Value packedVal = rewriter.create<tensor::PackOp>(
824a945f55dSAdam Siemieniuk       packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
825a945f55dSAdam Siemieniuk       packOp.getMixedTiles(), packOp.getPaddingValue(),
826a945f55dSAdam Siemieniuk       /*outerDimsPerm=*/SmallVector<int64_t>{});
827a945f55dSAdam Siemieniuk 
828a945f55dSAdam Siemieniuk   Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
829a945f55dSAdam Siemieniuk       packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
830a945f55dSAdam Siemieniuk   rewriter.replaceOp(packOp, newExpandOp);
831a945f55dSAdam Siemieniuk 
832a945f55dSAdam Siemieniuk   return success();
833a945f55dSAdam Siemieniuk }
834a945f55dSAdam Siemieniuk 
8350c1c0d53SJerry Wu class BubbleUpPackOpThroughReshapeOp final
8360c1c0d53SJerry Wu     : public OpRewritePattern<tensor::PackOp> {
8370c1c0d53SJerry Wu public:
8380c1c0d53SJerry Wu   BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
8390c1c0d53SJerry Wu       : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
8400c1c0d53SJerry Wu 
8410c1c0d53SJerry Wu   LogicalResult matchAndRewrite(tensor::PackOp packOp,
8420c1c0d53SJerry Wu                                 PatternRewriter &rewriter) const override {
8430c1c0d53SJerry Wu     Operation *srcOp = packOp.getSource().getDefiningOp();
8440c1c0d53SJerry Wu     // Currently only support when the pack op is the only user.
8450c1c0d53SJerry Wu     if (!srcOp || !(srcOp->getNumResults() == 1) ||
8460c1c0d53SJerry Wu         !srcOp->getResult(0).hasOneUse()) {
8470c1c0d53SJerry Wu       return failure();
8480c1c0d53SJerry Wu     }
8490c1c0d53SJerry Wu     // Currently only support static inner tile sizes.
8500c1c0d53SJerry Wu     if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
8510c1c0d53SJerry Wu           return ShapedType::isDynamic(size);
8520c1c0d53SJerry Wu         })) {
8530c1c0d53SJerry Wu       return failure();
8540c1c0d53SJerry Wu     }
8550c1c0d53SJerry Wu 
8560c1c0d53SJerry Wu     // User controlled propagation function.
85704fc471fSHan-Chung Wang     if (!controlFn(&packOp.getSourceMutable()))
8580c1c0d53SJerry Wu       return failure();
8590c1c0d53SJerry Wu 
8600c1c0d53SJerry Wu     return TypeSwitch<Operation *, LogicalResult>(srcOp)
8610c1c0d53SJerry Wu         .Case([&](tensor::CollapseShapeOp op) {
8620c1c0d53SJerry Wu           return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
8630c1c0d53SJerry Wu         })
864a945f55dSAdam Siemieniuk         .Case([&](tensor::ExpandShapeOp op) {
865a945f55dSAdam Siemieniuk           return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
866a945f55dSAdam Siemieniuk         })
8670c1c0d53SJerry Wu         .Default([](Operation *) { return failure(); });
8680c1c0d53SJerry Wu   }
8690c1c0d53SJerry Wu 
8700c1c0d53SJerry Wu private:
8710c1c0d53SJerry Wu   ControlPropagationFn controlFn;
8720c1c0d53SJerry Wu };
8730c1c0d53SJerry Wu 
8740c1c0d53SJerry Wu /// Push down unpack op through expand shape op when the packed dims can be
8750c1c0d53SJerry Wu /// projected to the dims after expanding. This is possible when the inner tile
8760c1c0d53SJerry Wu /// sizes can divide the projected dims.
8770c1c0d53SJerry Wu ///
8780c1c0d53SJerry Wu /// For example:
8790c1c0d53SJerry Wu ///
8800c1c0d53SJerry Wu /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
8810c1c0d53SJerry Wu ///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
8820c1c0d53SJerry Wu ///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
8830c1c0d53SJerry Wu /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
8840c1c0d53SJerry Wu ///     : tensor<?x256xf32> into tensor<?x256x256xf32>
8850c1c0d53SJerry Wu ///
8860c1c0d53SJerry Wu /// can be transformed into:
8870c1c0d53SJerry Wu ///
8880c1c0d53SJerry Wu /// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
8890c1c0d53SJerry Wu ///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
8900c1c0d53SJerry Wu /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
8910c1c0d53SJerry Wu ///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
8920c1c0d53SJerry Wu ///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
89304fc471fSHan-Chung Wang static LogicalResult pushDownUnPackOpThroughExpandShape(
89404fc471fSHan-Chung Wang     tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
89504fc471fSHan-Chung Wang     PatternRewriter &rewriter, ControlPropagationFn controlFn) {
89604fc471fSHan-Chung Wang   // User controlled propagation function.
89704fc471fSHan-Chung Wang   if (!controlFn(&expandOp.getSrcMutable()))
89804fc471fSHan-Chung Wang     return failure();
89904fc471fSHan-Chung Wang 
9000c1c0d53SJerry Wu   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
9010c1c0d53SJerry Wu   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
9020c1c0d53SJerry Wu   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
9030c1c0d53SJerry Wu 
904d2353695SPeiming Liu   auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
90597069a86SGaurav Shukla   if (!expandTy)
90697069a86SGaurav Shukla     return failure();
90797069a86SGaurav Shukla   ArrayRef<int64_t> dstShape = expandTy.getShape();
9080c1c0d53SJerry Wu   SmallVector<ReassociationIndices> reassocIndices =
9090c1c0d53SJerry Wu       expandOp.getReassociationIndices();
9100c1c0d53SJerry Wu   // Project inner tile pos to the dim pos after expanding. For example, if dims
9110c1c0d53SJerry Wu   // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
9120c1c0d53SJerry Wu   // on dim y.
9130c1c0d53SJerry Wu   //
9140c1c0d53SJerry Wu   // Project to inner-most non-unit dims to increase the chance that they can be
9150c1c0d53SJerry Wu   // divided by the inner tile sizes. This is correct because for [..., x, 1],
9160c1c0d53SJerry Wu   // unpacking on dim 1 is equivalent to unpacking on dim x.
9170c1c0d53SJerry Wu   SmallVector<int64_t> projectedInnerDimsPos =
9180c1c0d53SJerry Wu       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
9190c1c0d53SJerry Wu 
9200c1c0d53SJerry Wu   if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
9210c1c0d53SJerry Wu                                   innerTileSizes)) {
9220c1c0d53SJerry Wu     return failure();
9230c1c0d53SJerry Wu   }
9240c1c0d53SJerry Wu   // Expand the outer dims permutation with the associated expanded dims for the
9250c1c0d53SJerry Wu   // new permutation after pushing. This is because moving a source dim is
9260c1c0d53SJerry Wu   // equivalent to moving the associated expanded dims together.
9270c1c0d53SJerry Wu   SmallVector<int64_t> newOuterDimsPerm;
9280c1c0d53SJerry Wu   for (auto outerPos : outerDimsPerm) {
9290c1c0d53SJerry Wu     newOuterDimsPerm.insert(newOuterDimsPerm.end(),
9300c1c0d53SJerry Wu                             reassocIndices[outerPos].begin(),
9310c1c0d53SJerry Wu                             reassocIndices[outerPos].end());
9320c1c0d53SJerry Wu   }
9330c1c0d53SJerry Wu 
9340c1c0d53SJerry Wu   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
9350c1c0d53SJerry Wu   // First apply the permutation on the reassociations of the outer dims.
9360c1c0d53SJerry Wu   // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
9370c1c0d53SJerry Wu   // -> [[0], [1, 2]]
9380c1c0d53SJerry Wu   int64_t nextPos =
9390c1c0d53SJerry Wu       applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
9400c1c0d53SJerry Wu   // Then add direct mapping for the inner tile dims.
9410c1c0d53SJerry Wu   for (size_t i = 0; i < innerDimsPos.size(); ++i) {
9420c1c0d53SJerry Wu     newReassocIndices.push_back({nextPos});
9430c1c0d53SJerry Wu     nextPos += 1;
9440c1c0d53SJerry Wu   }
9450c1c0d53SJerry Wu 
94697069a86SGaurav Shukla   RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
94797069a86SGaurav Shukla       expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
9480c1c0d53SJerry Wu   auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
9490c1c0d53SJerry Wu       expandOp.getLoc(), newExpandType, unPackOp.getSource(),
9500c1c0d53SJerry Wu       newReassocIndices);
9510c1c0d53SJerry Wu 
9520c1c0d53SJerry Wu   auto emptyOp = tensor::UnPackOp::createDestinationTensor(
9530c1c0d53SJerry Wu       rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
9540c1c0d53SJerry Wu       projectedInnerDimsPos, newOuterDimsPerm);
9550c1c0d53SJerry Wu   auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
9560c1c0d53SJerry Wu       unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
9570c1c0d53SJerry Wu       projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
9580c1c0d53SJerry Wu   rewriter.replaceOp(expandOp, newUnPackOp);
9590c1c0d53SJerry Wu 
9600c1c0d53SJerry Wu   return success();
9610c1c0d53SJerry Wu }
9620c1c0d53SJerry Wu 
9630c1c0d53SJerry Wu class PushDownUnPackOpThroughReshapeOp final
9640c1c0d53SJerry Wu     : public OpRewritePattern<tensor::UnPackOp> {
9650c1c0d53SJerry Wu public:
9660c1c0d53SJerry Wu   PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
9670c1c0d53SJerry Wu                                    ControlPropagationFn fun)
9680c1c0d53SJerry Wu       : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
9690c1c0d53SJerry Wu   }
9700c1c0d53SJerry Wu 
9710c1c0d53SJerry Wu   LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
9720c1c0d53SJerry Wu                                 PatternRewriter &rewriter) const override {
9730c1c0d53SJerry Wu     Value result = unPackOp.getResult();
9740c1c0d53SJerry Wu     // Currently only support unpack op with the single user.
9750c1c0d53SJerry Wu     if (!result.hasOneUse()) {
9760c1c0d53SJerry Wu       return failure();
9770c1c0d53SJerry Wu     }
9780c1c0d53SJerry Wu     // Currently only support static inner tile sizes.
9790c1c0d53SJerry Wu     if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
9800c1c0d53SJerry Wu           return ShapedType::isDynamic(size);
9810c1c0d53SJerry Wu         })) {
9820c1c0d53SJerry Wu       return failure();
9830c1c0d53SJerry Wu     }
9840c1c0d53SJerry Wu 
9850c1c0d53SJerry Wu     Operation *consumerOp = *result.user_begin();
9860c1c0d53SJerry Wu     return TypeSwitch<Operation *, LogicalResult>(consumerOp)
9870c1c0d53SJerry Wu         .Case([&](tensor::ExpandShapeOp op) {
98804fc471fSHan-Chung Wang           return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
98904fc471fSHan-Chung Wang                                                     controlFn);
9900c1c0d53SJerry Wu         })
9910c1c0d53SJerry Wu         .Default([](Operation *) { return failure(); });
9920c1c0d53SJerry Wu   }
9930c1c0d53SJerry Wu 
9940c1c0d53SJerry Wu private:
9950c1c0d53SJerry Wu   ControlPropagationFn controlFn;
9960c1c0d53SJerry Wu };
9970c1c0d53SJerry Wu 
9989f242404SLorenzo Chelini // TODO: Relax this restriction. We should unpack a generic op also
9996bb0ab0dSLorenzo Chelini // in the presence of multiple unpack ops as producers.
10006bb0ab0dSLorenzo Chelini /// Return the unpacked operand, if present, for the current generic op.
10016bb0ab0dSLorenzo Chelini static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
10026bb0ab0dSLorenzo Chelini   OpOperand *unPackedOperand = nullptr;
10036bb0ab0dSLorenzo Chelini   for (OpOperand &operand : genericOp->getOpOperands()) {
10046bb0ab0dSLorenzo Chelini     auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
10056bb0ab0dSLorenzo Chelini     if (!unPackOp)
10066bb0ab0dSLorenzo Chelini       continue;
10076bb0ab0dSLorenzo Chelini     if (unPackedOperand)
10086bb0ab0dSLorenzo Chelini       return failure();
10096bb0ab0dSLorenzo Chelini     unPackedOperand = &operand;
10106bb0ab0dSLorenzo Chelini   }
10116bb0ab0dSLorenzo Chelini   if (!unPackedOperand)
10126bb0ab0dSLorenzo Chelini     return failure();
10136bb0ab0dSLorenzo Chelini   return unPackedOperand;
10146bb0ab0dSLorenzo Chelini }
10156bb0ab0dSLorenzo Chelini 
/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on packed domain; pack ops are created for input
/// and output operands. A tensor.unpack op is inserted right after the packed
/// generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                          inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///      outs(%1 : tensor<12x56x56x64xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                       inner_dims_pos = [3] inner_tiles = [32] into %0
///
/// Returns the new (packed) generic op and the value that should replace the
/// original generic's result (either the new result itself or the result of
/// the trailing unpack), or failure if the propagation does not apply.
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  // Only single-result generics are handled: the packed init/result pair is
  // rebuilt for DPS init operand 0 below.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Ops with gather semantics use data-dependent indexing; packing the
  // iteration domain would change which elements are accessed.
  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  // Give the client a chance to veto this propagation.
  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  // destPack is null when the init did not need packing (output unaffected).
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  // Reuse the layout of the pack that was created for the init operand so
  // that the trailing unpack exactly inverts it.
  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
                                    destPack.getSource(), innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
11186bb0ab0dSLorenzo Chelini 
1119b4563ee1SQuinn Dawkins // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1120b4563ee1SQuinn Dawkins struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1121b4563ee1SQuinn Dawkins public:
1122b4563ee1SQuinn Dawkins   PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1123b4563ee1SQuinn Dawkins                                    ControlPropagationFn fun)
1124b4563ee1SQuinn Dawkins       : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
11256bb0ab0dSLorenzo Chelini 
11266bb0ab0dSLorenzo Chelini   LogicalResult matchAndRewrite(GenericOp genericOp,
11276bb0ab0dSLorenzo Chelini                                 PatternRewriter &rewriter) const override {
112804fc471fSHan-Chung Wang     auto genericAndRepl =
112904fc471fSHan-Chung Wang         pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
11306bb0ab0dSLorenzo Chelini     if (failed(genericAndRepl))
11316bb0ab0dSLorenzo Chelini       return failure();
11326bb0ab0dSLorenzo Chelini     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
11336bb0ab0dSLorenzo Chelini     return success();
11346bb0ab0dSLorenzo Chelini   }
1135b4563ee1SQuinn Dawkins 
1136b4563ee1SQuinn Dawkins private:
1137b4563ee1SQuinn Dawkins   ControlPropagationFn controlFn;
11386bb0ab0dSLorenzo Chelini };
11396bb0ab0dSLorenzo Chelini 
114030d542f9SLorenzo Chelini /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
114130d542f9SLorenzo Chelini /// add as many zero padding dimensions in `high` and `low` based on the number
114230d542f9SLorenzo Chelini /// of point loops.
114330d542f9SLorenzo Chelini struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1144b4563ee1SQuinn Dawkins   PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1145b4563ee1SQuinn Dawkins       : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
114630d542f9SLorenzo Chelini 
114730d542f9SLorenzo Chelini   LogicalResult matchAndRewrite(tensor::PadOp padOp,
114830d542f9SLorenzo Chelini                                 PatternRewriter &rewriter) const override {
114930d542f9SLorenzo Chelini     tensor::UnPackOp unpackOp =
115030d542f9SLorenzo Chelini         padOp.getSource().getDefiningOp<tensor::UnPackOp>();
115130d542f9SLorenzo Chelini     if (!unpackOp)
115230d542f9SLorenzo Chelini       return failure();
115330d542f9SLorenzo Chelini 
115404fc471fSHan-Chung Wang     if (!controlFn(&padOp.getSourceMutable()))
1155b4563ee1SQuinn Dawkins       return failure();
1156b4563ee1SQuinn Dawkins 
115730d542f9SLorenzo Chelini     Location loc = padOp.getLoc();
115830d542f9SLorenzo Chelini     // Bail out if one of the padded dimension is a tiled one.
115930d542f9SLorenzo Chelini     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
116030d542f9SLorenzo Chelini     ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
116130d542f9SLorenzo Chelini     llvm::SmallBitVector innerDims(paddedDims.size());
116230d542f9SLorenzo Chelini     for (int64_t dim : innerDimsPos)
116330d542f9SLorenzo Chelini       innerDims.flip(dim);
116430d542f9SLorenzo Chelini     if (paddedDims.anyCommon(innerDims))
116530d542f9SLorenzo Chelini       return failure();
116630d542f9SLorenzo Chelini 
116730d542f9SLorenzo Chelini     Value paddingVal = padOp.getConstantPaddingValue();
116830d542f9SLorenzo Chelini     if (!paddingVal)
116930d542f9SLorenzo Chelini       return failure();
117030d542f9SLorenzo Chelini 
117130d542f9SLorenzo Chelini     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
117230d542f9SLorenzo Chelini     ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
117330d542f9SLorenzo Chelini     SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
117430d542f9SLorenzo Chelini     SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
117530d542f9SLorenzo Chelini     if (!outerDimsPerm.empty()) {
117630d542f9SLorenzo Chelini       applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
117730d542f9SLorenzo Chelini       applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
117830d542f9SLorenzo Chelini     }
117930d542f9SLorenzo Chelini     // Add zero padding for the point loops.
118030d542f9SLorenzo Chelini     size_t pointLoopsSize = innerDimsPos.size();
118130d542f9SLorenzo Chelini     lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
118230d542f9SLorenzo Chelini     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
118330d542f9SLorenzo Chelini 
118430d542f9SLorenzo Chelini     auto newPadOp = rewriter.create<tensor::PadOp>(
118530d542f9SLorenzo Chelini         loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
118630d542f9SLorenzo Chelini         paddingVal, padOp.getNofold());
118730d542f9SLorenzo Chelini 
118830d542f9SLorenzo Chelini     // Inject the tensor.unpack right after the packed padOp.
118930d542f9SLorenzo Chelini     Value outputUnPack = rewriter.create<tensor::EmptyOp>(
119030d542f9SLorenzo Chelini         loc, padOp.getResultType().getShape(),
119130d542f9SLorenzo Chelini         padOp.getResultType().getElementType());
119230d542f9SLorenzo Chelini 
119330d542f9SLorenzo Chelini     Value replacement = rewriter.create<tensor::UnPackOp>(
119430d542f9SLorenzo Chelini         loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
119530d542f9SLorenzo Chelini         unpackOp.getMixedTiles(), outerDimsPerm);
119630d542f9SLorenzo Chelini     rewriter.replaceOp(padOp, replacement);
119730d542f9SLorenzo Chelini     return success();
119830d542f9SLorenzo Chelini   }
1199b4563ee1SQuinn Dawkins 
1200b4563ee1SQuinn Dawkins private:
1201b4563ee1SQuinn Dawkins   ControlPropagationFn controlFn;
120230d542f9SLorenzo Chelini };
120330d542f9SLorenzo Chelini 
12040f297cadSHanhan Wang } // namespace
12050f297cadSHanhan Wang 
12060f297cadSHanhan Wang void mlir::linalg::populateDataLayoutPropagationPatterns(
1207b4563ee1SQuinn Dawkins     RewritePatternSet &patterns,
1208b4563ee1SQuinn Dawkins     const ControlPropagationFn &controlPackUnPackPropagation) {
1209886294a2SQuinn Dawkins   patterns
1210886294a2SQuinn Dawkins       .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
12110c1c0d53SJerry Wu               BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
12120c1c0d53SJerry Wu               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1213b4563ee1SQuinn Dawkins           patterns.getContext(), controlPackUnPackPropagation);
12140f297cadSHanhan Wang }
1215