xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly the options and control functions of the Linalg fusion
// patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 
24 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
25   if (!op)
26     return;
27   TypeSwitch<Operation *, void>(op)
28       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
29         SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
30         operandSet.insert(inputOperands.begin(), inputOperands.end());
31       })
32       .Default([&](Operation *operation) {
33         operandSet.insert(operation->operand_begin(), operation->operand_end());
34       });
35 }
36 
37 template <int limit = 3>
38 static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
39   Operation *producer = fusedOperand->get().getDefiningOp();
40   if (!producer)
41     return false;
42 
43   Operation *consumer = fusedOperand->getOwner();
44   SetVector<Value> fusedOpOperands;
45   if (producer->getNumResults() != 1)
46     return false;
47   addOperands(consumer, fusedOpOperands);
48   fusedOpOperands.remove(producer->getResult(0));
49   addOperands(producer, fusedOpOperands);
50   return fusedOpOperands.size() <= limit;
51 }
52 
53 namespace {
54 
55 /// Pattern to test fusion of producer with consumer, even if producer has
56 /// multiple uses.
57 struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
58   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
59 
60   LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
61                                 PatternRewriter &rewriter) const override {
62     OpOperand *fusableOperand = nullptr;
63     for (OpOperand &operand : genericOp->getOpOperands()) {
64       if (linalg::areElementwiseOpsFusable(&operand)) {
65         fusableOperand = &operand;
66         break;
67       }
68     }
69     if (!fusableOperand) {
70       return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
71     }
72     std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
73         linalg::fuseElementwiseOps(rewriter, fusableOperand);
74     if (!fusionResult)
75       return rewriter.notifyMatchFailure(genericOp, "fusion failed");
76     for (auto [origValue, replacement] : fusionResult->replacements) {
77       rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
78         return use.getOwner() != genericOp.getOperation();
79       });
80     }
81     rewriter.eraseOp(genericOp);
82     return success();
83   }
84 };
85 
86 struct TestLinalgElementwiseFusion
87     : public PassWrapper<TestLinalgElementwiseFusion,
88                          OperationPass<func::FuncOp>> {
89   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
90 
91   TestLinalgElementwiseFusion() = default;
92   TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
93       : PassWrapper(pass) {}
94   void getDependentDialects(DialectRegistry &registry) const override {
95     registry.insert<affine::AffineDialect, linalg::LinalgDialect,
96                     memref::MemRefDialect, tensor::TensorDialect>();
97   }
98   StringRef getArgument() const final {
99     return "test-linalg-elementwise-fusion-patterns";
100   }
101   StringRef getDescription() const final {
102     return "Test Linalg element wise operation fusion patterns";
103   }
104 
105   Option<bool> fuseGenericOps{
106       *this, "fuse-generic-ops",
107       llvm::cl::desc("Test fusion of generic operations."),
108       llvm::cl::init(false)};
109 
110   Option<bool> fuseGenericOpsControl{
111       *this, "fuse-generic-ops-control",
112       llvm::cl::desc(
113           "Test fusion of generic operations with a control function."),
114       llvm::cl::init(false)};
115 
116   Option<bool> fuseWithReshapeByExpansion{
117       *this, "fuse-with-reshape-by-expansion",
118       llvm::cl::desc(
119           "Test fusion of generic operations with reshape by expansion"),
120       llvm::cl::init(false)};
121 
122   Option<bool> controlFuseByExpansion{
123       *this, "control-fusion-by-expansion",
124       llvm::cl::desc(
125           "Test controlling fusion of reshape with generic op by expansion"),
126       llvm::cl::init(false)};
127 
128   Option<bool> fuseWithReshapeByCollapsing{
129       *this, "fuse-with-reshape-by-collapsing",
130       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
131                      "collapse the iteration space of the consumer"),
132       llvm::cl::init(false)};
133 
134   Option<bool> fuseWithReshapeByCollapsingWithControlFn{
135       *this, "fuse-with-reshape-by-collapsing-control",
136       llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
137                      "fusion patterns that "
138                      "collapse the iteration space of the consumer"),
139       llvm::cl::init(false)};
140 
141   Option<bool> fuseMultiUseProducer{
142       *this, "fuse-multiuse-producer",
143       llvm::cl::desc("Test fusion of producer ops with multiple uses"),
144       llvm::cl::init(false)};
145 
146   ListOption<int64_t> collapseDimensions{
147       *this, "collapse-dimensions-control",
148       llvm::cl::desc("Test controlling dimension collapse pattern")};
149 
150   void runOnOperation() override {
151     MLIRContext *context = &this->getContext();
152     func::FuncOp funcOp = this->getOperation();
153 
154     if (fuseGenericOps) {
155       RewritePatternSet fusionPatterns(context);
156       auto controlFn = [](OpOperand *operand) { return true; };
157       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
158       if (failed(applyPatternsGreedily(funcOp.getBody(),
159                                        std::move(fusionPatterns))))
160         return signalPassFailure();
161       return;
162     }
163 
164     if (fuseGenericOpsControl) {
165       RewritePatternSet fusionPatterns(context);
166       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
167                                                    setFusedOpOperandLimit<4>);
168 
169       if (failed(applyPatternsGreedily(funcOp.getBody(),
170                                        std::move(fusionPatterns))))
171         return signalPassFailure();
172       return;
173     }
174 
175     if (fuseWithReshapeByExpansion) {
176       RewritePatternSet fusionPatterns(context);
177       linalg::populateFoldReshapeOpsByExpansionPatterns(
178           fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
179       if (failed(applyPatternsGreedily(funcOp.getBody(),
180                                        std::move(fusionPatterns))))
181         return signalPassFailure();
182       return;
183     }
184 
185     if (controlFuseByExpansion) {
186       RewritePatternSet fusionPatterns(context);
187 
188       linalg::ControlFusionFn controlReshapeFusionFn =
189           [](OpOperand *fusedOperand) {
190             Operation *producer = fusedOperand->get().getDefiningOp();
191             if (!producer)
192               return false;
193 
194             if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
195               if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
196                 return false;
197               }
198             }
199 
200             Operation *consumer = fusedOperand->getOwner();
201             if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
202               if (expandOp->hasOneUse()) {
203                 OpOperand &use = *expandOp->getUses().begin();
204                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
205                 if (linalgOp && linalgOp.isDpsInit(&use))
206                   return true;
207               }
208               return false;
209             }
210             return true;
211           };
212 
213       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
214                                                         controlReshapeFusionFn);
215       if (failed(applyPatternsGreedily(funcOp.getBody(),
216                                        std::move(fusionPatterns))))
217         return signalPassFailure();
218       return;
219     }
220 
221     if (fuseWithReshapeByCollapsing) {
222       RewritePatternSet patterns(context);
223       linalg::populateFoldReshapeOpsByCollapsingPatterns(
224           patterns, [](OpOperand * /*fusedOperand */) { return true; });
225       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
226         return signalPassFailure();
227       return;
228     }
229 
230     if (fuseWithReshapeByCollapsingWithControlFn) {
231       RewritePatternSet patterns(context);
232       linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
233         Operation *producer = fusedOperand->get().getDefiningOp();
234         if (isa<tensor::ExpandShapeOp>(producer)) {
235           // Skip fusing the first operand.
236           return fusedOperand->getOperandNumber();
237         }
238         return true;
239       };
240       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
241       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
242         return signalPassFailure();
243       return;
244     }
245 
246     if (fuseMultiUseProducer) {
247       RewritePatternSet patterns(context);
248       patterns.insert<TestMultiUseProducerFusion>(context);
249       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
250         return signalPassFailure();
251       return;
252     }
253 
254     if (!collapseDimensions.empty()) {
255       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
256                                    collapseDimensions.end());
257       linalg::GetCollapsableDimensionsFn collapseFn =
258           [&dims](linalg::LinalgOp op) {
259             SmallVector<ReassociationIndices> reassociations;
260             reassociations.emplace_back(dims);
261             return reassociations;
262           };
263       RewritePatternSet patterns(context);
264       linalg::populateCollapseDimensions(patterns, collapseFn);
265       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
266         return signalPassFailure();
267       return;
268     }
269   }
270 };
271 
272 } // namespace
273 
274 namespace mlir {
275 namespace test {
276 void registerTestLinalgElementwiseFusion() {
277   PassRegistration<TestLinalgElementwiseFusion>();
278 }
279 } // namespace test
280 } // namespace mlir
281