xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
1 //===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 
13 using namespace mlir;
14 using namespace mlir::linalg;
15 
16 /// Return `true` if the `result` of an operation `genericOp` is dead.
isResultValueDead(linalg::GenericOp genericOp,OpResult result)17 static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
18   if (!result.use_empty())
19     return false;
20   // If out operand not used in payload, we can drop it.
21   OpOperand *outputOpOperand =
22       genericOp.getDpsInitOperand(result.getResultNumber());
23   if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24     return true;
25 
26   // The out operand that is part of a payload can be dropped if
27   // these conditions are met:
28   // - Result from out operand is dead.
29   // - User of arg is yield.
30   // - outArg data is not being used by other outArgs.
31 
32   // Check block arg and cycle from out operand has a single use.
33   BlockArgument outputArg =
34       genericOp.getRegionOutputArgs()[result.getResultNumber()];
35   if (!outputArg.hasOneUse())
36     return false;
37   Operation *argUserOp = *outputArg.user_begin();
38 
39   // Check argUser has no other use.
40   if (!argUserOp->use_empty())
41     return false;
42 
43   // Check that argUser is a yield.
44   auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
45   if (!yieldOp)
46     return false;
47 
48   // Check outArg data is not being used by other outArgs.
49   if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50     return false;
51 
52   return true;
53 }
54 
55 namespace {
56 
57 struct DeduplicateAndRemoveDeadOperandsAndResults
58     : public OpRewritePattern<GenericOp> {
DeduplicateAndRemoveDeadOperandsAndResults__anona48d586a0111::DeduplicateAndRemoveDeadOperandsAndResults59   DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
60                                              bool removeOutputs)
61       : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
62 
matchAndRewrite__anona48d586a0111::DeduplicateAndRemoveDeadOperandsAndResults63   LogicalResult matchAndRewrite(GenericOp genericOp,
64                                 PatternRewriter &rewriter) const override {
65     // Create a map from argument position in the original op to the argument
66     // position in the new op. If the argument is dropped it wont have an entry.
67     SmallVector<OpOperand *> droppedOpOperands;
68 
69     // Information needed to build the new op.
70     SmallVector<Value> newInputOperands, newOutputOperands;
71     SmallVector<AffineMap> newIndexingMaps;
72 
73     // Gather information about duplicate input operands.
74     llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
75         deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
76                                  newIndexingMaps);
77 
78     // Gather information about the dropped outputs.
79     llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
80         deduplicateOutputOperands(genericOp, droppedOpOperands,
81                                   newOutputOperands, newIndexingMaps);
82 
83     // Check if there is any change to operands.
84     if (newInputOperands.size() + newOutputOperands.size() ==
85         genericOp->getNumOperands())
86       return failure();
87 
88     // Create the new op with the body being empty.
89     Location loc = genericOp.getLoc();
90     SmallVector<Type> newResultTypes;
91     for (Value v : newOutputOperands)
92       if (isa<TensorType>(v.getType()))
93         newResultTypes.push_back(v.getType());
94     auto newOp = rewriter.create<GenericOp>(
95         loc, newResultTypes, newInputOperands, newOutputOperands,
96         rewriter.getAffineMapArrayAttr(newIndexingMaps),
97         genericOp.getIteratorTypes(), genericOp.getDocAttr(),
98         genericOp.getLibraryCallAttr(),
99         [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
100           return;
101         });
102     // Copy over unknown attributes. They might be load bearing for some flow.
103     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
104     for (NamedAttribute kv : genericOp->getAttrs())
105       if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
106         newOp->setAttr(kv.getName(), kv.getValue());
107 
108     // Fix up the payload of the canonicalized operation.
109     populateOpPayload(genericOp, newOp, origInsToNewInsPos,
110                       origOutsToNewOutsPos, rewriter);
111 
112     // Replace all live uses of the op.
113     SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
114     for (const auto &result : llvm::enumerate(genericOp.getResults())) {
115       auto it = origOutsToNewOutsPos.find(result.index());
116       if (it == origOutsToNewOutsPos.end())
117         continue;
118       replacementsVals[result.index()] = newOp.getResult(it->second);
119     }
120     rewriter.replaceOp(genericOp, replacementsVals);
121     return success();
122   }
123 
124 private:
125   /// If unset, outputs are not modified by this pattern.
126   bool removeOutputs;
127 
128   // Deduplicate input operands, and return the
129   // - Mapping from operand position in the original op, to operand position in
130   // the canonicalized op.
131   // - The preserved input operands list (by reference).
132   llvm::SmallDenseMap<unsigned, unsigned>
deduplicateInputOperands__anona48d586a0111::DeduplicateAndRemoveDeadOperandsAndResults133   deduplicateInputOperands(GenericOp genericOp,
134                            SmallVector<OpOperand *> &droppedOpOperands,
135                            SmallVector<Value> &newInputOperands,
136                            SmallVector<AffineMap> &newIndexingMaps) const {
137     llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
138     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
139     for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
140       OpOperand *inputOpOperand = en.value();
141       // Check if operand is dead and if dropping the indexing map makes the
142       // loops to shape computation invalid.
143       if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
144         // Add the current operands to the list of potentially droppable
145         // operands. If it cannot be dropped, this needs to be popped back.
146         droppedOpOperands.push_back(inputOpOperand);
147         if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
148           continue;
149         droppedOpOperands.pop_back();
150       }
151 
152       // Check if this operand is a duplicate.
153       AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
154       auto it = dedupedInputs.find(
155           std::make_pair(inputOpOperand->get(), indexingMap));
156       if (it != dedupedInputs.end()) {
157         origToNewPos[en.index()] = it->second;
158         droppedOpOperands.push_back(inputOpOperand);
159         continue;
160       }
161 
162       // This is a preserved argument.
163       origToNewPos[en.index()] = newInputOperands.size();
164       dedupedInputs[{inputOpOperand->get(), indexingMap}] =
165           newInputOperands.size();
166       newInputOperands.push_back(inputOpOperand->get());
167       newIndexingMaps.push_back(indexingMap);
168     }
169     return origToNewPos;
170   }
171 
172   // Deduplicate output operands, and return the
173   // - Mapping from operand position in the original op, to operand position in
174   // the canonicalized op.
175   // - The preserved output operands list (by reference).
176   llvm::SmallDenseMap<unsigned, unsigned>
deduplicateOutputOperands__anona48d586a0111::DeduplicateAndRemoveDeadOperandsAndResults177   deduplicateOutputOperands(GenericOp genericOp,
178                             SmallVector<OpOperand *> &droppedOpOperands,
179                             SmallVector<Value> &newOutputOperands,
180                             SmallVector<AffineMap> &newIndexingMaps) const {
181     llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
182     llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
183         dedupedOutpts;
184     // If the op doesn't have tensor semantics or outputs should not be removed,
185     // keep all the outputs as preserved.
186     if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
187       for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
188         origToNewPos[en.index()] = newOutputOperands.size();
189         newOutputOperands.push_back(en.value().get());
190         newIndexingMaps.push_back(
191             genericOp.getMatchingIndexingMap(&en.value()));
192       }
193       return origToNewPos;
194     }
195     // Output argument can be dropped if the result has
196     // - no users, and
197     // - it is not used in the payload, and
198     // - the corresponding indexing maps are not needed for loop bound
199     //   computation.
200     auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
201     for (const auto &outputOpOperand :
202          llvm::enumerate(genericOp.getDpsInitsMutable())) {
203       OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
204       AffineMap indexingMap =
205           genericOp.getMatchingIndexingMap(&outputOpOperand.value());
206       auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
207                                  yieldOp->getOperand(outputOpOperand.index()));
208       if (isResultValueDead(genericOp, result)) {
209         // Check if the opoperand can be dropped without affecting loop
210         // bound computation. Add the operand to the list of dropped op
211         // operand for checking. If it cannot be dropped, need to pop the
212         // value back.
213         droppedOpOperands.push_back(&outputOpOperand.value());
214         if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
215           continue;
216         }
217         droppedOpOperands.pop_back();
218       }
219 
220       if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
221         // The out operand can also be dropped if it is computed redundantly
222         // by another result, the conditions for that are
223         // - The same operand is used as the out operand
224         // - The same indexing map is used
225         // - The same yield value is used.
226         auto it = dedupedOutpts.find(key);
227         if (it != dedupedOutpts.end()) {
228           origToNewPos[outputOpOperand.index()] = it->second;
229           droppedOpOperands.push_back(&outputOpOperand.value());
230           continue;
231         }
232       }
233 
234       origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
235       dedupedOutpts[key] = newOutputOperands.size();
236       newOutputOperands.push_back(outputOpOperand.value().get());
237       newIndexingMaps.push_back(
238           genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
239     }
240     return origToNewPos;
241   }
242 
243   // Populate the body of the canonicalized operation.
populateOpPayload__anona48d586a0111::DeduplicateAndRemoveDeadOperandsAndResults244   void populateOpPayload(
245       GenericOp genericOp, GenericOp newOp,
246       const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
247       const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
248       PatternRewriter &rewriter) const {
249     // Merge the body of the original op with the new op.
250     Block *newOpBlock = &newOp.getRegion().front();
251     assert(newOpBlock->empty() && "expected new op to have an empty payload");
252     Block *origOpBlock = &genericOp.getRegion().front();
253     SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
254 
255     // Replace all arguments in the original op, with arguments from the
256     // canonicalized op.
257     auto updateReplacements =
258         [&](SmallVector<OpOperand *> &origOperands,
259             SmallVector<OpOperand *> &newOperands,
260             const llvm::SmallDenseMap<unsigned, unsigned> &map) {
261           for (const auto &origOperand : llvm::enumerate(origOperands)) {
262             auto it = map.find(origOperand.index());
263             if (it == map.end())
264               continue;
265             OpOperand *newOperand = newOperands[it->second];
266             replacements[origOperand.value()->getOperandNumber()] =
267                 newOpBlock->getArgument(newOperand->getOperandNumber());
268           }
269         };
270 
271     SmallVector<OpOperand *> origInputOperands =
272         genericOp.getDpsInputOperands();
273     SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
274     updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
275 
276     SmallVector<OpOperand *> origOutputOperands =
277         llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
278                                         [](OpOperand &o) { return &o; }));
279     SmallVector<OpOperand *> newOutputOperands =
280         llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
281                                         [](OpOperand &o) { return &o; }));
282     updateReplacements(origOutputOperands, newOutputOperands,
283                        origOutsToNewOutsPos);
284 
285     // Drop the unused yield args.
286     if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
287       OpBuilder::InsertionGuard g(rewriter);
288       YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
289       rewriter.setInsertionPoint(origYieldOp);
290 
291       SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
292       for (const auto &yieldOpOperands :
293            llvm::enumerate(origYieldOp.getValues())) {
294         auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
295         if (it == origOutsToNewOutsPos.end())
296           continue;
297         newYieldVals[it->second] = yieldOpOperands.value();
298       }
299       rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
300     }
301 
302     rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
303   }
304 };
305 
306 /// Remove unused cycles.
307 /// We can remove unused cycle within a payload of generic region
308 /// if these conditions are met:
309 /// - Result from out operand is dead.
310 /// - Block arg from out operand has a single use in the %cycle
311 /// instruction.
312 /// - Cycle has a single use and it is in yield.
313 struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
314   using OpRewritePattern<GenericOp>::OpRewritePattern;
315 
matchAndRewrite__anona48d586a0111::RemoveUnusedCycleInGenericOp316   LogicalResult matchAndRewrite(GenericOp genericOp,
317                                 PatternRewriter &rewriter) const override {
318 
319     // If the op doesnt have tensor semantics, preserve the outputs as is.
320     if (!genericOp.hasPureTensorSemantics())
321       return failure();
322 
323     bool hasRemovedCycles = false;
324     // Iterate over output operands and remove any unused cycles.
325     for (const auto &outputOpOperand :
326          llvm::enumerate(genericOp.getDpsInits())) {
327 
328       // Check that result from out operand is dead.
329       Value result = genericOp.getResult(outputOpOperand.index());
330       if (!result.use_empty())
331         continue;
332 
333       // Check that outputArg has one use in cycle.
334       BlockArgument outputArg =
335           genericOp.getRegionOutputArgs()[outputOpOperand.index()];
336       if (!outputArg.hasOneUse())
337         continue;
338 
339       // Check cycle has at most one use.
340       Operation *cycleOp = *outputArg.user_begin();
341       if (!cycleOp->hasOneUse())
342         continue;
343 
344       // Check that the cycleUser is a yield.
345       Operation *cycleUserOp = *cycleOp->user_begin();
346       if (!isa<linalg::YieldOp>(cycleUserOp))
347         continue;
348 
349       // Check that argIndex matches yieldIndex, else data is being used.
350       if (cycleUserOp->getOperand(outputOpOperand.index()) !=
351           cycleOp->getResult(0))
352         continue;
353 
354       // Directly replace the cycle with the blockArg such that
355       // Deduplicate pattern can eliminate it along with unused yield.
356       rewriter.replaceOp(cycleOp, outputArg);
357       rewriter.modifyOpInPlace(genericOp, [] {});
358       hasRemovedCycles = true;
359     }
360 
361     if (hasRemovedCycles) {
362       return success();
363     }
364 
365     return failure();
366   }
367 };
368 
369 /// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
370 /// ```
371 /// linalg.generic ins(%a, %b, %a, %b) outs(%a)
372 /// ^bb0(%in0, %in1, %in2, %in3, %out1)
373 /// ```
374 /// Assuming that all %a and %b have the same index map:
375 /// * All uses of %in0 and %in2 are replaced with %out1
376 /// * All uses of %in1 are replaced with %in3
377 /// This pattern can enable additional canonicalizations: In the above example,
378 /// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
379 /// can be folded away. This pattern does not modify uses of output block args.
380 struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
381   using OpRewritePattern<GenericOp>::OpRewritePattern;
382 
matchAndRewrite__anona48d586a0111::FoldDuplicateInputBbArgs383   LogicalResult matchAndRewrite(GenericOp genericOp,
384                                 PatternRewriter &rewriter) const override {
385     // Find replacement bbArgs for all input bbArg.
386     DenseMap<int, int> replacements;
387     for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
388       // Skip bbArgs that have no uses.
389       if (genericOp.getBody()->getArgument(i).getUses().empty())
390         continue;
391       // Find replacement bbArg. This can be an input or an output bbArg.
392       for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
393         if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
394             genericOp.getIndexingMapsArray()[i] ==
395                 genericOp.getIndexingMapsArray()[j]) {
396           replacements[i] = j;
397           break;
398         }
399       }
400     }
401 
402     // Stop here if no replacements were found.
403     if (replacements.empty())
404       return failure();
405 
406     // Rewrite the op.
407     rewriter.modifyOpInPlace(genericOp, [&]() {
408       for (auto [before, after] : replacements) {
409         BlockArgument bbArg = genericOp.getBody()->getArgument(before);
410         BlockArgument replacement = genericOp.getBody()->getArgument(after);
411         rewriter.replaceAllUsesWith(bbArg, replacement);
412       }
413     });
414 
415     return success();
416   }
417 };
418 
419 } // namespace
420 
populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet & patterns)421 void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
422     RewritePatternSet &patterns) {
423   patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
424       patterns.getContext(), /*removeOutputs=*/true);
425   patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
426 }
427 
populateEraseUnnecessaryInputsPatterns(RewritePatternSet & patterns)428 void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
429     RewritePatternSet &patterns) {
430   patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
431       patterns.getContext(), /*removeOutputs=*/false);
432   patterns.insert<FoldDuplicateInputBbArgs>(patterns.getContext());
433 }
434