//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation required by the linalg op.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... %arg0, %arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This reduces the amount of computation performed by the linalg operation.
///
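/// As an illustrative sketch (hypothetical values %A, %B and %C, not a
/// verbatim test case), a `linalg.matmul` whose result is sliced:
///
/// ```mlir
/// %0 = linalg.matmul ins(%A, %B : tensor<128x64xf32>, tensor<64x256xf32>)
///                    outs(%C : tensor<128x256xf32>) -> tensor<128x256xf32>
/// %1 = tensor.extract_slice %0[16, 32] [8, 8] [1, 1]
///        : tensor<128x256xf32> to tensor<8x8xf32>
/// ```
///
/// is rewritten to compute only the requested tile; the parallel dimensions
/// are sliced while the reduction dimension keeps its full extent:
///
/// ```mlir
/// %sA = tensor.extract_slice %A[16, 0] [8, 64] [1, 1]
///         : tensor<128x64xf32> to tensor<8x64xf32>
/// %sB = tensor.extract_slice %B[0, 32] [64, 8] [1, 1]
///         : tensor<64x256xf32> to tensor<64x8xf32>
/// %sC = tensor.extract_slice %C[16, 32] [8, 8] [1, 1]
///         : tensor<128x256xf32> to tensor<8x8xf32>
/// %1 = linalg.matmul ins(%sA, %sB : tensor<8x64xf32>, tensor<64x8xf32>)
///                    outs(%sC : tensor<8x8xf32>) -> tensor<8x8xf32>
/// ```
///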
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }

    // TODO: we might relax this if we want heuristics to detect that all uses
    // are a small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }

    if (linalgOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected linalg op to have pure tensor semantics");
    }

    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }

    OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }

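    // Compute the loop upper bounds of the linalg op by applying its
    // shapes-to-loops map to the dimension sizes of all operands.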
    auto linalgLoc = linalgOp.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }
    SmallVector<OpFoldResult> sizeBounds =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);

    // The offsets and sizes from the slice operation only give the tile size
    // of the output. Use those to compute the tile sizes and offsets of the
    // loops. For loops not used to access the output, set the tile sizes to
    // the loop bounds and set the offsets to 0.
    SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
                                          rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> tileSizes = sizeBounds;
    for (auto const &result : enumerate(indexingMap.getResults())) {
      unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
      tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
      tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
    }

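    // Create tensor.extract_slice ops for the operands using the tile offsets
    // and sizes computed above (full-size slices along dimensions that are
    // not restricted by the original extract_slice).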
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands =
        makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                        tileOffsets, tileSizes, sizeBounds,
                        /*omitPartialTileCheck=*/true);

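    // The cloned op must directly produce the sliced result type, so take the
    // result types from the tiled init operands.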
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
      resultTensorTypes.push_back(
          tiledOperands[opOperand.getOperandNumber()].getType());

    Operation *newOp =
        clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace

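// A minimal usage sketch (an assumption for illustration, not part of this
// file): a pass holding a function-like op `funcOp` could apply the pattern
// with the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateBubbleUpExtractSliceOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//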
void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<BubbleUpExtractSliceOpPattern>(context);
}