//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of dimensions in `shape` that are either dynamic or
/// greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if at most one dimension in the non-packed domain has a
/// size greater than 1 and the packing happens only on that dimension.
/// Note: this method should only be used by the pack/unpack-to-reshape
/// conversions. It assumes that a non-unit inner tile size is used only by
/// the non-unit dimension.
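/// For example (illustrative values, not from this file): a source shape
/// [1, 1, 32] with inner tile sizes [8] qualifies, while a source shape
/// [16, 32] or inner tile sizes [4, 8] would be rejected, since each has two
/// non-unit entries.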
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dim");
  }
  // A non-unit inner tile size must be used by the non-unit dimension. If not,
  // it will fail on getting reassociation maps.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tile");
  }
  return success();
}

/// If `linalgOp` represents a transpose, either a linalg.transpose op or a
/// transpose-like linalg.generic op, returns the permutation vector for the
/// transpose. Otherwise, returns failure.
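/// For example (an illustrative sketch; the values are hypothetical, not
/// from this file), the following linalg.generic is recognized as a
/// transpose with permutation [1, 0]:
///
///   %0 = linalg.generic
///     {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
///                       affine_map<(d0, d1) -> (d0, d1)>],
///      iterator_types = ["parallel", "parallel"]}
///     ins(%in : tensor<8x16xf32>) outs(%init : tensor<16x8xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     linalg.yield %a : f32
///   } -> tensor<16x8xf32>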
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor, or packing only along the innermost
/// dimension, can be expressed as a tensor.expand_shape op.
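/// For example (illustrative shapes; the exact expand_shape syntax may vary
/// across MLIR versions):
///
///   %pack = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<64xf32> -> tensor<8x8xf32>
///
/// is rewritten to:
///
///   %expanded = tensor.expand_shape %src [[0, 1]]
///       : tensor<64xf32> into tensor<8x8xf32>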
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if the pack op packs only the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm to be empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

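/// Unpacking a one-dimensional tensor, or unpacking only along the innermost
/// dimension, can be expressed as a tensor.collapse_shape op. For example
/// (illustrative shapes; the exact syntax may vary across MLIR versions):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<8x8xf32> -> tensor<64xf32>
///
/// is rewritten to:
///
///   %collapsed = tensor.collapse_shape %src [[0, 1]]
///       : tensor<8x8xf32> into tensor<64xf32>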
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if the unpack op unpacks only the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm to be empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` when the pad op has zero low padding,
/// is not marked `nofold`, and its constant padding value either matches the
/// pack op's padding value or the pack op has no padding value.
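/// For example (hypothetical values for illustration):
///
///   %padded = tensor.pad %src low[0] high[2] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
///   } : tensor<30xf32> to tensor<32xf32>
///   %pack = tensor.pack %padded padding_value(%cst : f32)
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<32xf32> -> tensor<4x8xf32>
///
/// folds into a single tensor.pack on %src, which performs the padding with
/// %cst itself:
///
///   %pack = tensor.pack %src padding_value(%cst : f32)
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<30xf32> -> tensor<4x8xf32>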
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
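/// For example (hypothetical values for illustration):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<4x8xf32> -> tensor<32xf32>
///   %slice = tensor.extract_slice %unpack[0] [30] [1]
///       : tensor<32xf32> to tensor<30xf32>
///
/// folds into an unpack that writes directly into a new tensor<30xf32>
/// destination, dropping the padded tail.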
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

/// Applies `permutation` to `inVec` and stores the result in `resVec`.
/// `inVec` may be empty, in which case the mapping is one-to-one with the
/// permutation. `rank` bounds the permutation: if any of the first `rank`
/// entries of `permutation` is greater than or equal to `rank`, the function
/// returns false. For example, permutation {1, 0, 3, 2} with rank 2 is
/// allowed since the values in permutation[:rank] do not exceed the rank,
/// whereas permutation {1, 3, 0, 2} is not allowed since `3` exceeds the
/// rank in that range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {
  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold `pack` -> `transpose` into `pack` since `pack` already has transpose
/// semantics.
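/// For example (hypothetical values for illustration):
///
///   %pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [8]
///       into %dest : tensor<16x32xf32> -> tensor<16x4x8xf32>
///   %transposed = linalg.transpose ins(%pack : tensor<16x4x8xf32>)
///       outs(%init : tensor<4x16x8xf32>) permutation = [1, 0, 2]
///
/// folds into a single tensor.pack with a permuted outer_dims_perm:
///
///   %pack = tensor.pack %src outer_dims_perm = [1, 0]
///       inner_dims_pos = [1] inner_tiles = [8]
///       into %dest2 : tensor<16x32xf32> -> tensor<4x16x8xf32>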
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Apply the transposition to the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold `transpose` -> `pack` into `pack` since `pack` already has transpose
/// semantics.
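/// For example (hypothetical values for illustration):
///
///   %transposed = linalg.transpose ins(%src : tensor<32x16xf32>)
///       outs(%init : tensor<16x32xf32>) permutation = [1, 0]
///   %pack = tensor.pack %transposed inner_dims_pos = [1] inner_tiles = [8]
///       into %dest : tensor<16x32xf32> -> tensor<16x4x8xf32>
///
/// folds into a single tensor.pack on %src:
///
///   %pack = tensor.pack %src outer_dims_perm = [1, 0]
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest2 : tensor<32x16xf32> -> tensor<16x4x8xf32>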
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec because the
    // number of inner dims and the permutation rank are not necessarily equal.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold `unpack` -> `transpose` into `unpack` since `unpack` already has
/// transpose semantics.
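/// For example (hypothetical values for illustration):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<4x16x8xf32> -> tensor<32x16xf32>
///   %transposed = linalg.transpose ins(%unpack : tensor<32x16xf32>)
///       outs(%init : tensor<16x32xf32>) permutation = [1, 0]
///
/// folds into a single tensor.unpack that writes into the transpose's
/// destination:
///
///   %unpack = tensor.unpack %src outer_dims_perm = [1, 0]
///       inner_dims_pos = [1] inner_tiles = [8]
///       into %init : tensor<4x16x8xf32> -> tensor<16x32xf32>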
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec because the
    // number of inner dims and the permutation rank are not necessarily equal.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold `transpose` -> `unpack` into `unpack` since `unpack` already has
/// transpose semantics.
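/// For example (hypothetical values for illustration):
///
///   %transposed = linalg.transpose ins(%src : tensor<16x4x8xf32>)
///       outs(%init : tensor<4x16x8xf32>) permutation = [1, 0, 2]
///   %unpack = tensor.unpack %transposed inner_dims_pos = [0]
///       inner_tiles = [8] into %dest : tensor<4x16x8xf32> -> tensor<32x16xf32>
///
/// folds into a single tensor.unpack on %src:
///
///   %unpack = tensor.unpack %src outer_dims_perm = [1, 0]
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest2 : tensor<16x4x8xf32> -> tensor<32x16xf32>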
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Apply the transposition to the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};
} // namespace

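/// A minimal usage sketch for the populate functions below (hypothetical
/// caller; `getContext`, `getOperation`, and the greedy driver come from a
/// typical pass, not from this file):
///
///   RewritePatternSet patterns(&getContext());
///   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));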
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

} // namespace tensor
} // namespace mlir