1 //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include <optional>
14
15 using namespace mlir;
16 using namespace mlir::linalg;
17
18 namespace {
19
20 /// Pattern to decompose a GenericOp that has more than two statements
21 /// into one GenericOp with the first statement (i.e. peeled operation), and
22 /// a second GenericOp with the remaining statements (i.e. residual operations).
23
24 /// - The result of the first GenericOp has the same shape as the iteration
25 /// space of the GenericOp. The body of the op yields as many values as the
26 /// original op plus all the results of the peeled operation.
27 /// - The second GenericOp has as many operands as the original operation plus
28 /// all the results of the first Generic Op. It has the same number of yields as
29 /// the original op.
30 /// - If the result of the peeled operation was yielded by the original
31 /// GenericOp the uses of the corresponding results will be replaced with the
32 /// result of the first GenericOp created.
33 ///
34 /// Example
35 ///
36 /// ```mlir
37 /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
38 /// outs(%init0, %init1 : ...) {
39 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
40 /// %0 = <s0> %b0, %b1 : ...
41 /// %1 = <s1> %0, %b2 : ...
42 /// linalg.yield %0, %1 : ...
43 /// } -> (..., ...)
44 /// return %result#0, %result#1
45 /// ```
46 ///
47 /// gets split into
48 ///
49 /// ```mlir
50 /// %init = tensor.empty ...
51 /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
52 /// outs(%init0, %init1, %init : ...)
53 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
54 /// %0 = <s0> %b0, %b1 : ...
55 /// linalg.yield %0, %..., %0 : ...
56 /// } -> (..., ..., ...)
57 /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
58 /// outs(%init0, %init1 : ...) {
59 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
60 /// %1 = <s1> %b3, %b2 : ...
61 /// linalg.yield %..., %1 : ...
62 /// } -> (..., ...)
63 /// return %op0#0, %op1#1
64 /// ```
65 ///
66 /// After canonicalization this is expected to be
67 ///
68 /// ```mlir
69 /// %init = tensor.empty ...
70 /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
71 /// outs(%init : ...)
72 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
73 /// %0 = <s0> %b0, %b1 : ...
74 /// linalg.yield %0 : ...
75 /// } -> ...
76 /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
77 /// outs(%init1 : ...) {
78 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
79 /// %1 = <s1> %b1, %b0 : ...
80 /// linalg.yield %..., %1 : ...
81 /// } -> ...
82 /// return %op0, %op1
83 /// ```
84 struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
85 using OpRewritePattern<GenericOp>::OpRewritePattern;
86
87 LogicalResult matchAndRewrite(GenericOp genericOp,
88 PatternRewriter &rewriter) const override;
89
90 private:
91 /// Helper method to create a generic op for the peeled scalar operation. The
92 /// created op has an empty region.
93 GenericOp createPeeledGenericOp(GenericOp genericOp,
94 PatternRewriter &rewriter) const;
95
96 /// Helper method to create a generic op for the residual scalar operation.
97 /// The created op has the same region as the original op.
98 GenericOp createResidualGenericOp(GenericOp genericOp,
99 GenericOp peeledGenericOp,
100 PatternRewriter &rewriter) const;
101 };
102 } // namespace
103
104 /// Helper method to compute the range of a generic op.
getGenericOpLoopRange(OpBuilder & b,GenericOp op)105 static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
106 GenericOp op) {
107 OpBuilder::InsertionGuard g(b);
108 b.setInsertionPoint(op);
109 Location loc = op.getLoc();
110 auto allShapesSizes =
111 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
112 AffineMap map = op.getShapesToLoopsMap();
113 IRRewriter rewriter(b);
114 return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
115 allShapesSizes);
116 }
117
118 /// Helper method to permute the list of `values` based on the `map`.
permuteValues(ArrayRef<OpFoldResult> values,AffineMap map)119 SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
120 AffineMap map) {
121 assert(map.isPermutation());
122 SmallVector<OpFoldResult> permutedValues(values.size());
123 for (const auto &position :
124 llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
125 return cast<AffineDimExpr>(expr).getPosition();
126 })))
127 permutedValues[position.value()] = values[position.index()];
128 return permutedValues;
129 }
130
131 /// Get zero value for an element type.
getZero(OpBuilder & b,Location loc,Type elementType)132 static Value getZero(OpBuilder &b, Location loc, Type elementType) {
133 assert(elementType.isIntOrIndexOrFloat() &&
134 "expected scalar type while computing zero value");
135 if (isa<IntegerType>(elementType))
136 return b.create<arith::ConstantIntOp>(loc, 0, elementType);
137 if (elementType.isIndex())
138 return b.create<arith::ConstantIndexOp>(loc, 0);
139 // Assume float.
140 auto floatType = cast<FloatType>(elementType);
141 return b.create<arith::ConstantFloatOp>(
142 loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 }
144
145 GenericOp
createPeeledGenericOp(GenericOp genericOp,PatternRewriter & rewriter) const146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
147 PatternRewriter &rewriter) const {
148 Block *body = genericOp.getBody();
149 Operation *peeledScalarOperation = &(*body->begin());
150 SmallVector<AffineMap> peeledGenericOpIndexingMaps =
151 genericOp.getIndexingMapsArray();
152
153 /// Compute the loop ranges for operation. This is the shape of the result of
154 /// the generic op for the peeled operation.
155 Location loc = genericOp.getLoc();
156 SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
157 SmallVector<Value> newInitValues;
158 SmallVector<Type> newResultTypes;
159
160 // Add as many new results as the number of results of the peeled scalar op.
161 for (auto scalarOpResult : peeledScalarOperation->getResults()) {
162 // If the result is yielded by the original op, use the operand, indexing
163 // map and result type that correspond to the yielded value.
164
165 std::optional<unsigned> resultNumber;
166 for (auto *user : scalarOpResult.getUsers()) {
167 if (auto yieldOp = dyn_cast<YieldOp>(user)) {
168 // Find the first use of the `scalarOpResult` in the yield op.
169 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170 if (yieldOperand.get() == scalarOpResult) {
171 resultNumber = yieldOperand.getOperandNumber();
172 break;
173 }
174 }
175 assert(resultNumber && "unable to find use of a value in its user");
176 break;
177 }
178 }
179 if (resultNumber) {
180 newInitValues.push_back(
181 genericOp.getDpsInitOperand(*resultNumber)->get());
182 OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
183 newResultTypes.push_back(result.getType());
184 peeledGenericOpIndexingMaps.push_back(
185 genericOp.getIndexingMapMatchingResult(result));
186 continue;
187 }
188
189 // Fall back path, use an `init_tensor` and identity indexing map.
190 AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
191 Value emptyTensor =
192 rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
193 newInitValues.push_back(emptyTensor);
194 newResultTypes.push_back(emptyTensor.getType());
195 peeledGenericOpIndexingMaps.push_back(indexingMap);
196 }
197
198 /// Create the peeled generic op with an empty body.
199 SmallVector<Value> outsOperands = genericOp.getOutputs();
200 outsOperands.append(newInitValues.begin(), newInitValues.end());
201 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
202 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203 auto indexingMapAttr =
204 rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
205 return rewriter.create<GenericOp>(
206 loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207 genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
208 [](OpBuilder, Location, ValueRange) {});
209 }
210
211 GenericOp
createResidualGenericOp(GenericOp genericOp,GenericOp peeledGenericOp,PatternRewriter & rewriter) const212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213 GenericOp peeledGenericOp,
214 PatternRewriter &rewriter) const {
215 /// Append all results from the peeledGenericOps as `ins` operand for the
216 /// residual generic op.
217 SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
218 unsigned origNumResults = genericOp.getNumResults();
219 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
220 SmallVector<Value> extraIns;
221 for (auto resultNum :
222 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
223 extraIns.push_back(peeledGenericOp->getResult(resultNum));
224 residualGenericOpOperands.append(extraIns);
225
226 /// Add indexing maps for the newly added operands. Use the same map
227 /// as those used for the new results of the peeledGenericOp.
228 auto indexingMaps = llvm::to_vector(
229 llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
230 return genericOp.getMatchingIndexingMap(operand);
231 }));
232 for (auto resultNum :
233 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
234 OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
235 indexingMaps.push_back(
236 peeledGenericOp.getIndexingMapMatchingResult(result));
237 }
238 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
239 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
240
241 auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
242 return rewriter.create<GenericOp>(
243 genericOp->getLoc(), genericOp->getResultTypes(),
244 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245 genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
246 [](OpBuilder, Location, ValueRange) {});
247 }
248
249 LogicalResult
matchAndRewrite(GenericOp genericOp,PatternRewriter & rewriter) const250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
251 PatternRewriter &rewriter) const {
252 /// For now only match on operations where the iterator types are all parallel
253 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
254 return rewriter.notifyMatchFailure(genericOp,
255 "unhandled decomposition of operation "
256 "with non-parallel iterator types");
257 }
258 // TODO: this could be generalized to handle `linalg.generic` with buffer
259 // operands too but requires allocation for intermediates. Punt on this for
260 // now.
261 if (!genericOp.hasPureTensorSemantics()) {
262 return rewriter.notifyMatchFailure(
263 genericOp, "only operations with tensor semantics are handled");
264 }
265
266 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
267 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
268 })) {
269 return rewriter.notifyMatchFailure(
270 genericOp, "unhandled decomposition of generic op with out operand not "
271 "accessed using a permutation");
272 }
273
274 /// If the op has only a single statement (apart from the yield), do nothing.
275 Block *body = genericOp.getBody();
276 if (body->getOperations().size() <= 2) {
277 return rewriter.notifyMatchFailure(genericOp,
278 "operation has less than 3 statements");
279 }
280
281 /// Check that the peeled statement has a scalar element type.
282 if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
283 [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
284 return rewriter.notifyMatchFailure(
285 &(*body->getOperations().begin()),
286 "expected return type to be only int, index or float");
287 }
288
289 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290 GenericOp residualGenericOp =
291 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
292
293 /// Move the first statement of the original operation into the body of the
294 /// generic op for the peeled operation.
295 Block *peeledGenericOpBody = peeledGenericOp.getBody();
296 Block *residualGenericOpBody = residualGenericOp.getBody();
297 assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
298 "expected split generic ops to have empty region");
299 peeledGenericOpBody->getOperations().splice(
300 peeledGenericOpBody->begin(), body->getOperations(), body->begin());
301 residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
302 body->getOperations());
303
304 Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
305 auto *yieldOp = residualGenericOpBody->getTerminator();
306 {
307 // Yield all the result of the peeled scalar operation.
308 OpBuilder::InsertionGuard g(rewriter);
309 rewriter.setInsertionPointToEnd(peeledGenericOpBody);
310 SmallVector<Value> yieldedVals;
311 for (auto origYield : yieldOp->getOperands()) {
312 if (origYield.getDefiningOp() == peeledScalarOperation) {
313 yieldedVals.push_back(origYield);
314 } else {
315 // Do not materialize any new ops inside of the decomposed LinalgOp,
316 // as that would trigger another application of the rewrite pattern
317 // (infinite loop).
318 OpBuilder::InsertionGuard g(rewriter);
319 rewriter.setInsertionPoint(peeledGenericOp);
320 yieldedVals.push_back(
321 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
322 }
323 }
324 yieldedVals.append(llvm::to_vector(
325 llvm::map_range(peeledScalarOperation->getResults(),
326 [](OpResult opr) -> Value { return opr; })));
327 rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
328 }
329
330 /// In the split operations, replace block arguments uses that refer to
331 /// original operation to the block arguments of the newly created operation.
332 unsigned origNumInputs = genericOp.getNumDpsInputs();
333 for (const auto &inputBlockArg :
334 llvm::enumerate(genericOp.getBody()->getArguments())) {
335 Value residualOpReplacementArg =
336 residualGenericOpBody->getArgument(inputBlockArg.index());
337 rewriter.replaceUsesWithIf(
338 inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
339 return use.getOwner()->getBlock() == residualGenericOpBody;
340 });
341
342 Value peeledOpReplacementArg =
343 peeledGenericOpBody->getArgument(inputBlockArg.index());
344 rewriter.replaceUsesWithIf(
345 inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
346 return use.getOwner()->getBlock() == peeledGenericOpBody;
347 });
348 }
349
350 /// Before fixing up the residual operation, track what values are yielded. If
351 /// any of those are from the peeled scalar operation, the uses of the
352 /// corresponding result have to be remapped to result of the generic op for
353 /// the peeled operation.
354 SmallVector<Value> replacements;
355 for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
356 OpResult opr = dyn_cast<OpResult>(yieldValue.value());
357 if (!opr || opr.getOwner() != peeledScalarOperation)
358 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
359 else
360 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
361 }
362
363 /// Update all uses of the peeled scalar operation results in the residual op
364 /// to the newly added arguments.
365 {
366 SmallVector<Value> scalarReplacements;
367 unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
368 scalarReplacements.reserve(peeledScalarOpNumResults);
369 for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
370 scalarReplacements.push_back(
371 residualGenericOpBody->getArgument(num + origNumInputs));
372 bool allUsesReplaced = false;
373 rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
374 residualGenericOpBody, &allUsesReplaced);
375 assert(!allUsesReplaced &&
376 "peeled scalar operation is erased when it wasnt expected to be");
377 }
378
379 // Replace the original operation
380 rewriter.replaceOp(genericOp, replacements);
381 return success();
382 }
383
populateDecomposeLinalgOpsPattern(RewritePatternSet & patterns,bool removeDeadArgsAndResults)384 void mlir::linalg::populateDecomposeLinalgOpsPattern(
385 RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
386 patterns.insert<DecomposeLinalgOp>(patterns.getContext());
387 // Add the patterns to clean up the dead operands and results.
388 if (removeDeadArgsAndResults)
389 populateEraseUnusedOperandsAndResultsPatterns(patterns);
390 }
391