//===- DecomposeGenericByUnfoldingPermutation.cpp -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, `broadcast`, or a mixture of the two folded in, into an
/// explicit transpose and broadcast. Having them folded into the
/// linalg.generic is a good optimization, but sometimes we may want to
/// unwrap, i.e. `unfold`, them as explicit transpose and broadcast ops. This
/// rewrite pattern does that for each input operand. It is useful, for
/// instance, when trying to recognize named ops.
///
/// The transpose, broadcast, or mixture of both, is expressed in the affine
/// map of the operand. Technically, such a map is a `projected permutation`.
///
/// Example:
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
///    %res = linalg.generic
///       { indexing_maps = [#projection, #identity, #identity],
///       iterator_types = ["parallel", "parallel", "parallel",
///                         "parallel", "parallel"]}
///       ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///       outs(%z : tensor<5x9x7x8x10xf32>) {
///         ^bb0(%in: f32, %in_1: f32, %out: f32):
///              %div = arith.divf %in, %in_1 : f32
///              linalg.yield %div : f32
///    } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. This
/// can be unfolded as:
///
/// ```mlir
///   ...
///   %x_trans = linalg.transpose
///                   ins(%x : tensor<7x8x9xf32>)
///                   outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
///   ...
///   %x_trans_bc = linalg.broadcast
///                   ins(%x_trans : tensor<9x7x8xf32>)
///                   outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
///   %2 = linalg.div
///           ins(%x_trans_bc, %y :
///                  tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///           outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has been 'specialized' into a linalg.div op.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced (pre-broadcast) dimensions,
/// as the broadcast has not happened yet. Also, the sizes of the broadcast
/// dimensions of a linalg.generic come from the other operands (those not
/// broadcast along that particular dimension). We work this out by computing
/// the (hyperrectangular) shape of the linalg.generic iteration space from
/// the shapes of all the operands, both inputs and outputs.
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns a pair `(transpose-permutation, broadcast-dimensions)`, where
/// either entry is empty if that part is not needed.
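/// For example, `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` yields
/// the transpose permutation `[2, 0, 1]` and broadcast dimensions `[0, 4]`.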
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // As the map is a projection, the transpose only concerns the (smaller)
  // set of dimensions that appear in the map's results; the rest are
  // broadcast.
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  std::sort(sortedResMap.begin(), sortedResMap.end());
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());
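  // e.g. for results (d2, d3, d1), [2, 3, 1] differs from the sorted
  // [1, 2, 3], so a transpose is present.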

  // Walk the sorted map result to determine which dimensions are broadcasted.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }
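  // e.g. with five loop dims and results (d2, d3, d1), dims 0 and 4 never
  // appear, so broadcast = [0, 4].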

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];
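    // e.g. minorResult [2, 3, 1] remaps to [1, 2, 0].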

    // Calculate the permutation for the transpose.
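    // e.g. remappedResult [1, 2, 0] yields permutation [2, 0, 1].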
    for (int64_t i = 0; i < minorSize; ++i) {
      permutation[remappedResult[i]] = i;
    }
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
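  // Only handle all-parallel generics with pure tensor semantics; generics
  // that are a single input-output or a bare yield are left alone.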
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it cannot
  // be decomposed into a mere transpose and broadcast. The requirement that
  // all maps be `projected permutation` may be over-restrictive, but since we
  // need to determine the shape of the iteration space as well, reject the op
  // if any map violates the assumption.
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing a linalg.generic involves creating `tensor.empty` ops, which
  // can have dynamic shapes, but then we would have to work out which operand
  // can supply that runtime value (via tensor.dim). Leaving it as a future
  // TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if it is transposed,
  // broadcast, or a mix of the two, as expressed by the operand's affine map.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t j = 0; j < map.getNumResults(); ++j)
        transposedShape[j] = inputRTType.getShape()[permutation[j]];
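      // e.g. input shape 7x8x9 with permutation [2, 0, 1] gives 9x7x8.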

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, outputShape, inputRTType.getElementType());

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);
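      // e.g. 9x7x8 broadcast along dims [0, 4] gives 5x9x7x8x10, the full
      // iteration-space shape.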

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
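    // The operand now indexes the full iteration space directly, so its map
    // becomes the identity.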
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (isChanged) {
    SmallVector<Value> operands = op->getOperands();
    ValueRange operandsRef(operands);

    auto newOp = rewriter.create<linalg::GenericOp>(
        /*location=*/op.getLoc(),
        /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/newInitValues,
        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
        /*indexingMaps=*/newMap,
        /*iteratorTypes=*/op.getIteratorTypesArray());

    newOp.getRegion().takeBody(op->getRegion(0));
    rewriter.replaceOp(op, newOp->getResults());
  }
  // Report whether anything changed; returning success without modifying the
  // IR would make the greedy rewrite driver re-apply this pattern forever.
  return success(isChanged);
}

} // namespace

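// Registers the pattern above. Callers typically apply these patterns through
// a greedy pattern rewrite driver, e.g. as part of linalg generic-op
// specialization.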
void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}