//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

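/// Returns true if the generic op has gather semantics, i.e. its body reads
/// from arbitrary positions via tensor.extract or depends on the iteration
/// indices via linalg.index. Propagating a layout change through such an op
/// could change the values its body observes, so the patterns below bail out.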
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map the packing metadata of a
// pack or unpack op onto the iteration domain of a Linalg op.
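//
// As a hypothetical illustration (not taken from a test): for a generic op
// with iteration domain (d0, d1), an identity result map, and a pack op with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 2] on that result, the
// constructed PackInfo would be:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0 -> 8, 1 -> 2}
//   tileToPointMapping      = {0 -> 2, 1 -> 3}  // point loops d2 and d3.
//   outerDimsOnDomainPerm   = []                // no outer_dims_perm given.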
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // The tiled dimension positions on the iteration domain, following the order
  // of inner_dims_pos on the pack/unpack op.
  SmallVector<int64_t> tiledDimsPos;
  // The mapping from a tiled dimension of the iteration domain to its tile
  // size.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a tiled dimension of the iteration domain to the
  // corresponding inner (point) loop dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

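/// Extracts a PackInfo for `opOperand` of `genericOp` from the given pack or
/// unpack op. Fails if a tiled dimension does not appear as a plain, parallel
/// AffineDimExpr in the indexing maps, or if the outer dims permutation would
/// permute a non-AffineDimExpr result.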
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in a
  // numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
  // perm               : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  // currentPositionTileLoops: dim | pos
  //                           d2  | 0
  //                           d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm`
  // to be used, here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, under
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///      outs(%init : tensor<?x?xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///      %4 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %4 : f32
///  } -> tensor<?x?xf32>
///  %1 = tensor.pack %0
///    inner_dims_pos = [0, 1]
///    inner_tiles = [8, 2]
///    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
///  Taking the first input operand as an example, the inner tile size of d0 is
///  8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
///  will be returned.
///
///  %pack = tensor.pack %arg0
///    inner_dims_pos = [0]
///    inner_tiles = [8]
///    into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
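/// Each input operand is replaced by its packed counterpart (created with
/// getOrCreatePackedViewOfOperand), the iterator types are extended with one
/// parallel point loop per tiled dimension, and the original region is cloned
/// into the new op.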
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
///
///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                          iterator_types = ["parallel", "parallel"]}
///         ins(%arg0 : tensor<?x?xf32>)
///         outs(%2 : tensor<?x?xf32>) {
///       ^bb0(%arg3: f32, %arg4: f32):
///         %4 = arith.addf %arg3, %arg3 : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?xf32>
///     %4 = tensor.pack %3
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %0 = affine.apply #map()[%dim]
///     %1 = affine.apply #map1()[%dim_0]
///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///     %pack = tensor.pack %arg0
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///     } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail-out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  // create a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case, we can only
  // proceed if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// append as many zero-padding entries to `high` and `low` as there are point
/// loops.
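///
/// As an illustrative sketch (the shapes and names here are hypothetical, not
/// taken from a test), a rewrite of the form
///
///   %padded = tensor.pad %arg0 low[0, 0] high[0, 2]
///       : tensor<?x14xf32> to tensor<?x16xf32>
///   %packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x16xf32> -> tensor<?x16x8xf32>
///
/// becomes
///
///   %packed = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [8]
///       into %empty2 : tensor<?x14xf32> -> tensor<?x14x8xf32>
///   %padded = tensor.pad %packed low[0, 0, 0] high[0, 2, 0]
///       : tensor<?x14x8xf32> to tensor<?x16x8xf32>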
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    if (!padOp.getResult().hasOneUse())
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
        outerDimsPerm);
    Value packedSource = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
        padOp.getNofold());
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
/// inner-most non-unit projected dim in pos [0, 1] is 1. And for pos 2, the
/// inner-most non-unit projected dim in pos [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
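/// For example (illustrative values): with dimsPos = [1], shape = [?, 16, 4],
/// and tileSizes = [8], this returns true since 16 % 8 == 0; with
/// tileSizes = [5], or if shape[1] were dynamic, it returns false.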
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in sequential order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///     : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to packing on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x,
  // 1], packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

/// Project dimsPos to their collapsed positions in the reassocIndices.
///
/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;

  // Map each dimension to the position of corresponding reassociation index.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      // If the dimension is present in the current indices group, the group
      // position within the reassociation map is the desired projected
      // dimension position.
      if (llvm::any_of(indices,
                       [&](int64_t expandDim) { return expandDim == pos; })) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

  return projectedPos;
}

/// Bubble up pack op through expand shape op.
///
/// For example:
///
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
///     : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = tensor.pack %expand
///     inner_dims_pos = [2] inner_tiles = [8] into %empty
///     : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in
///     inner_dims_pos = [1] inner_tiles = [8] into %empty
///     : tensor<?x64xf32> -> tensor<?x8x8xf32>
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
///     : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 tensor::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  // TODO: Handle outer_dims_perm variants.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  // Validate dimensions' relations between shape expansion and packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // The expanded dimensions are not packed, so they do not affect moving the
    // pack before the shape expansion; simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - in this case operation reordering would affect final
    // element positions and/or shapes can no longer be projected.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension should be packed. Otherwise,
    // element order would be affected after operation reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion to the new packed shape.
  // The pack.outer_dims_perm is restricted to identity, so the permutation can
  // be omitted for simplicity.
  // TODO: Account for outer dimensions permutation.
  //
  // If reassociation is not possible, then reordering cannot happen.
  // This can be caused by pack padding affecting previously expanded
  // dimensions or packing extending dimensions.
  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}

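/// Wrapper pattern that bubbles up a tensor.pack op through its defining
/// reshape op (tensor.collapse_shape or tensor.expand_shape), dispatching to
/// the corresponding rewrite above. It requires the reshape result to have a
/// single use and the pack op to use static inner tiles.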
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supported when the pack op is the only user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only static inner tile sizes are supported.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///     : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult pushDownUnPackOpThroughExpandShape(
    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // z is expanded into [x, y], unpacking on dim z can be projected to unpacking
  // on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x,
  // 1], unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for
  // the new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

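/// Wrapper pattern that pushes down a tensor.unpack op through its single
/// consumer when that consumer is a reshape op (currently only
/// tensor.expand_shape), dispatching to the rewrite above. It requires the
/// unpack op to use static inner tiles.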
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only supported when the unpack op has a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only static inner tile sizes are supported.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                          inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///      outs(%1 : tensor<12x56x56x64xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

/// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append as many zero-padding entries to `high` and `low` as there are point
/// loops.
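///
/// As an illustrative sketch (the shapes and names here are hypothetical, not
/// taken from a test), a rewrite of the form
///
///   %unpacked = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x14x8xf32> -> tensor<?x14xf32>
///   %padded = tensor.pad %unpacked low[0, 0] high[0, 2]
///       : tensor<?x14xf32> to tensor<?x16xf32>
///
/// becomes
///
///   %padded = tensor.pad %arg0 low[0, 0, 0] high[0, 2, 0]
///       : tensor<?x14x8xf32> to tensor<?x16x8xf32>
///   %unpacked = tensor.unpack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %empty2 : tensor<?x16x8xf32> -> tensor<?x16xf32>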
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

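/// Example usage (an illustrative sketch, assuming a greedy rewrite driver
/// inside a pass and an allow-all control function; names other than the
/// populate function are placeholders):
///
///   RewritePatternSet patterns(ctx);
///   linalg::populateDataLayoutPropagationPatterns(
///       patterns, [](OpOperand *) { return true; });
///   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
///     signalPassFailure();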
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}