//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   the newly created results of the first GenericOp. It has the same number
///   of yields as the original op.
/// - If the result of the peeled operation was yielded by the original
///   GenericOp the uses of the corresponding results will be replaced with the
///   result of the first GenericOp created.
///
///  Example
///
/// ```mlir
///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///      %0 = <s0> %b0, %b1 : ...
///      %1 = <s1> %0, %b2 : ...
///      linalg.yield %0, %1 : ...
///  } -> (..., ...)
///  return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1, %init : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0, %..., %0 : ...
///  } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %1 = <s1> %b3, %b2 : ...
///      linalg.yield %..., %1 : ...
///  } -> (..., ...)
///  return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///      outs(%init : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0 : ...
///  } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///      outs(%init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %1 = <s1> %b1, %b0 : ...
///      linalg.yield %1 : ...
///  } -> ...
///  return %op0, %op1
/// ```
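///
/// Note that the intermediate form carries dead operands and results; they are
/// expected to be cleaned up by canonicalization and by the patterns added via
/// `populateEraseUnusedOperandsAndResultsPatterns` (see
/// `populateDecomposeLinalgOpsPattern` below).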
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operations.
  /// Like the peeled op, the created op has an empty region; the statements
  /// remaining in the original op's region are spliced into it by the caller.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the loop ranges (i.e. the iteration domain) of a
/// generic op.
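/// For example (illustrative): for a generic op whose sole operand is a
/// `tensor<4x?xf32>` accessed through an identity indexing map, the returned
/// range is {4, %d}, with the static extent folded to an attribute and %d a
/// `tensor.dim` of the operand.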
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  IRRewriter rewriter(b);
  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                          allShapesSizes);
}

/// Helper method to permute the list of `values` based on the `map`.
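/// For example (illustrative), with `map = (d0, d1, d2) -> (d1, d2, d0)` and
/// `values = [a, b, c]`: result i of the map names the slot that values[i]
/// moves to, so the returned vector is `[c, a, b]`.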
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                        AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return cast<AffineDimExpr>(expr).getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

/// Get zero value for an element type.
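/// For example (illustrative), an i32 element type produces
/// `arith.constant 0 : i32` and an f32 element type produces
/// `arith.constant 0.000000e+00 : f32`.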
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = cast<FloatType>(elementType);
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  // Add as many new results as the number of results of the peeled scalar op.
  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
    // If the result is yielded by the original op, use the operand, indexing
    // map and result type that correspond to the yielded value.
    std::optional<unsigned> resultNumber;
    for (auto *user : scalarOpResult.getUsers()) {
      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
        // Find the first use of the `scalarOpResult` in the yield op.
        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
          if (yieldOperand.get() == scalarOpResult) {
            resultNumber = yieldOperand.getOperandNumber();
            break;
          }
        }
        assert(resultNumber && "unable to find use of a value in its user");
        break;
      }
    }
    if (resultNumber) {
      newInitValues.push_back(
          genericOp.getDpsInitOperand(*resultNumber)->get());
      OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
      newResultTypes.push_back(result.getType());
      peeledGenericOpIndexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(result));
      continue;
    }

    // Fall back path, use a `tensor.empty` and identity indexing map.
    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
    newInitValues.push_back(emptyTensor);
    newResultTypes.push_back(emptyTensor.getType());
    peeledGenericOpIndexingMaps.push_back(indexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputs();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
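  // Note: the no-op body builder below intentionally leaves the region empty;
  // `matchAndRewrite` splices the peeled statement in and constructs the
  // terminator afterwards.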
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append the newly added results of the peeledGenericOp as `ins` operands
  /// of the residual generic op.
  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps
  /// as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
        return genericOp.getMatchingIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(result));
  }
  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too, but that requires allocation for intermediates. Punt on this
  // for now.
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // Do not materialize any new ops inside of the decomposed LinalgOp,
        // as that would trigger another application of the rewrite pattern
        // (infinite loop).
        OpBuilder::InsertionGuard g(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace uses of block arguments of the original
  /// operation with the block arguments of the newly created operations.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track which values are yielded.
  /// If any of those come from the peeled scalar operation, the uses of the
  /// corresponding result have to be remapped to the result of the generic op
  /// for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
  // Add the patterns to clean up the dead operands and results.
  if (removeDeadArgsAndResults)
    populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
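
// Usage sketch (illustrative only, not part of this file): the pattern set is
// typically driven to a fixed point with the greedy rewriter from inside a
// pass; `runOnOperation` below is a hypothetical pass method.
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDecomposeLinalgOpsPattern(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }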
391