//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

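/// Returns true if the body of the generic op reads coordinates directly,
/// i.e., contains a tensor.extract or linalg.index op. Such gather semantics
/// currently block pack/unpack propagation.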
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// The struct contains the information needed to map a pack/unpack op onto the
// iteration domain of Linalg ops.
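//
// For example (illustrative), packing the result of a 2-loop elementwise
// generic with inner_dims_pos = [0, 1] and inner_tiles = [8, 2] yields:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0: 8, 1: 2}
//   tileToPointMapping      = {0: 2, 1: 3}  // point loops d2 and d3 are appended
//   outerDimsOnDomainPerm   = []            // no outer_dims_perm on the pack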
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of iteration domain to the corresponding inner
  // tiling dimension on iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

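/// Infers the PackInfo for the iteration domain of `genericOp` from the
/// pack/unpack op `packOrUnPackOp` attached to `opOperand`. Fails if a tiled
/// dimension is not a parallel loop accessed through plain AffineDimExprs, or
/// if the outer_dims_perm would move a non-AffineDimExpr result.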
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!expr.template isa<AffineDimExpr>())
      return failure();
    int64_t domainDimPos =
        exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !expr.isa<AffineDimExpr>();
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in a
  // numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = permutedExpr.template dyn_cast<AffineDimExpr>()) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
  // perm               : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  // currentPositionTileLoops: dim | pos
  //                           d2  | 0
  //                           d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm`
  // to be used, here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, under
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand is returned together with the updated
/// indexing map. Otherwise, the packed operand and the updated indexing map
/// are returned. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///      outs(%init : tensor<?x?xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///      %4 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %4 : f32
///  } -> tensor<?x?xf32>
///  %1 = tensor.pack %0
///    inner_dims_pos = [0, 1]
///    inner_tiles = [8, 2]
///    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
///  Taking the first input operand as an example, its only dimension maps to
///  d0, whose inner tile size is 8. Thus, the below operation and
///  `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
///
///  %pack = tensor.pack %arg0
///    inner_dims_pos = [0]
///    inner_tiles = [8]
///    into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = exprs[i].dyn_cast<AffineDimExpr>()) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(exprs[i].isa<AffineConstantExpr>() &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack an element-wise genericOp and return it.
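/// The new generic reads the packed views of the original inputs, writes into
/// `dest` using `packedOutIndexingMap`, and gains one extra parallel (point)
/// loop per tiled dimension in `packInfo`. The original body region is cloned
/// unchanged.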
static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
                                   Value dest, AffineMap packedOutIndexingMap,
                                   const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op, i.e., swaps
/// pack(generic) into generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                          iterator_types = ["parallel", "parallel"]}
///         ins(%arg0 : tensor<?x?xf32>)
///         outs(%2 : tensor<?x?xf32>) {
///       ^bb0(%arg3: f32, %arg4: f32):
///         %4 = arith.addf %arg3, %arg3 : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?xf32>
///     %4 = tensor.pack %3
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %0 = affine.apply #map()[%dim]
///     %1 = affine.apply #map1()[%dim_0]
///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///     %pack = tensor.pack %arg0
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///     } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               ControlPropagationFn controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail-out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  // create a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  // if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // We'll replace the init operand with the destination of the pack op if the
  // init operand has no users in the body of the linalg.generic (pure
  // elementwise). If it has users, we need to pack the init operand too and
  // replace the init with the packing result.
  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
                   ? packOpDest
                   : packedOutOperand;

  return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
                           *packInfo);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should be able to handle an elementwise op
// that has multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through an elementwise generic op. The new
/// generic op works on the packed domain; pack ops are created for input
/// and output operands. A tensor.unpack op is inserted right after the packed
/// generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                          inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///      outs(%1 : tensor<12x56x56x64xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
                                             packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>();
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append as many zero padding entries to `high` and `low` as there are point
/// loops, and to pad the packed source directly.
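///
/// For example (an illustrative sketch assuming an NCHWc-style layout with
/// inner_tiles = [32] on the channel dimension; value names are hypothetical):
///
///   %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32]
///     into %empty : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
///   %pad = tensor.pad %unpack low[0, 1, 1, 0] high[0, 1, 1, 0] { ... }
///     : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
///
/// becomes a pad of the packed source (the pad amounts are permuted by
/// outer_dims_perm and extended with zeros for the point loop), followed by an
/// unpack of the padded result:
///
///   %pad = tensor.pad %arg0 low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0] { ... }
///     : tensor<1x2x56x56x32xf32> to tensor<1x2x58x58x32xf32>
///   %unpack = tensor.unpack %pad outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32]
///     into %new_empty : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>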
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

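/// Populate `patterns` with the data layout propagation patterns defined
/// above: bubbling tensor.pack up through elementwise linalg.generic ops and
/// pushing tensor.unpack down through elementwise linalg.generic and
/// tensor.pad ops. `controlPackUnPackPropagation` lets callers veto the
/// propagation on a per-op basis.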
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}
695