xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
13fef2d26SRiver Riddle //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file implements a pass for testing fusion of elementwise operations in
103fef2d26SRiver Riddle // Linalg, mainly linalg options.
113fef2d26SRiver Riddle //
123fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
133fef2d26SRiver Riddle 
14eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
1536550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
163fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
183fef2d26SRiver Riddle #include "mlir/Pass/PassManager.h"
193fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
203fef2d26SRiver Riddle #include "llvm/ADT/TypeSwitch.h"
213fef2d26SRiver Riddle 
225e50dd04SRiver Riddle using namespace mlir;
233fef2d26SRiver Riddle 
243fef2d26SRiver Riddle static void addOperands(Operation *op, SetVector<Value> &operandSet) {
253fef2d26SRiver Riddle   if (!op)
263fef2d26SRiver Riddle     return;
273fef2d26SRiver Riddle   TypeSwitch<Operation *, void>(op)
283fef2d26SRiver Riddle       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
290b2197b0SMatthias Springer         SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
309f815cb5STobias Gysi         operandSet.insert(inputOperands.begin(), inputOperands.end());
313fef2d26SRiver Riddle       })
323fef2d26SRiver Riddle       .Default([&](Operation *operation) {
333fef2d26SRiver Riddle         operandSet.insert(operation->operand_begin(), operation->operand_end());
343fef2d26SRiver Riddle       });
353fef2d26SRiver Riddle }
363fef2d26SRiver Riddle 
373fef2d26SRiver Riddle template <int limit = 3>
38a7bfdc23SMahesh Ravishankar static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
39a7bfdc23SMahesh Ravishankar   Operation *producer = fusedOperand->get().getDefiningOp();
40a7bfdc23SMahesh Ravishankar   if (!producer)
413fef2d26SRiver Riddle     return false;
42a7bfdc23SMahesh Ravishankar 
43a7bfdc23SMahesh Ravishankar   Operation *consumer = fusedOperand->getOwner();
44a7bfdc23SMahesh Ravishankar   SetVector<Value> fusedOpOperands;
45a7bfdc23SMahesh Ravishankar   if (producer->getNumResults() != 1)
46a7bfdc23SMahesh Ravishankar     return false;
47a7bfdc23SMahesh Ravishankar   addOperands(consumer, fusedOpOperands);
48a7bfdc23SMahesh Ravishankar   fusedOpOperands.remove(producer->getResult(0));
49a7bfdc23SMahesh Ravishankar   addOperands(producer, fusedOpOperands);
503fef2d26SRiver Riddle   return fusedOpOperands.size() <= limit;
513fef2d26SRiver Riddle }
523fef2d26SRiver Riddle 
533fef2d26SRiver Riddle namespace {
5469011a2aSMahesh Ravishankar 
5569011a2aSMahesh Ravishankar /// Pattern to test fusion of producer with consumer, even if producer has
5669011a2aSMahesh Ravishankar /// multiple uses.
5769011a2aSMahesh Ravishankar struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
5869011a2aSMahesh Ravishankar   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
5969011a2aSMahesh Ravishankar 
6069011a2aSMahesh Ravishankar   LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
6169011a2aSMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
6269011a2aSMahesh Ravishankar     OpOperand *fusableOperand = nullptr;
6369011a2aSMahesh Ravishankar     for (OpOperand &operand : genericOp->getOpOperands()) {
6469011a2aSMahesh Ravishankar       if (linalg::areElementwiseOpsFusable(&operand)) {
6569011a2aSMahesh Ravishankar         fusableOperand = &operand;
6669011a2aSMahesh Ravishankar         break;
6769011a2aSMahesh Ravishankar       }
6869011a2aSMahesh Ravishankar     }
6969011a2aSMahesh Ravishankar     if (!fusableOperand) {
7069011a2aSMahesh Ravishankar       return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
7169011a2aSMahesh Ravishankar     }
7269011a2aSMahesh Ravishankar     std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
7369011a2aSMahesh Ravishankar         linalg::fuseElementwiseOps(rewriter, fusableOperand);
7469011a2aSMahesh Ravishankar     if (!fusionResult)
7501890942SMahesh Ravishankar       return rewriter.notifyMatchFailure(genericOp, "fusion failed");
7669011a2aSMahesh Ravishankar     for (auto [origValue, replacement] : fusionResult->replacements) {
772c40a0a6SMatthias Springer       rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
7869011a2aSMahesh Ravishankar         return use.getOwner() != genericOp.getOperation();
7969011a2aSMahesh Ravishankar       });
8069011a2aSMahesh Ravishankar     }
8169011a2aSMahesh Ravishankar     rewriter.eraseOp(genericOp);
8269011a2aSMahesh Ravishankar     return success();
8369011a2aSMahesh Ravishankar   }
8469011a2aSMahesh Ravishankar };
8569011a2aSMahesh Ravishankar 
863fef2d26SRiver Riddle struct TestLinalgElementwiseFusion
8758ceae95SRiver Riddle     : public PassWrapper<TestLinalgElementwiseFusion,
8858ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
895e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
905e50dd04SRiver Riddle 
912abd7f13SMahesh Ravishankar   TestLinalgElementwiseFusion() = default;
922abd7f13SMahesh Ravishankar   TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
932abd7f13SMahesh Ravishankar       : PassWrapper(pass) {}
943fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
954c48f016SMatthias Springer     registry.insert<affine::AffineDialect, linalg::LinalgDialect,
964c48f016SMatthias Springer                     memref::MemRefDialect, tensor::TensorDialect>();
973fef2d26SRiver Riddle   }
98b5e22e6dSMehdi Amini   StringRef getArgument() const final {
99b5e22e6dSMehdi Amini     return "test-linalg-elementwise-fusion-patterns";
100b5e22e6dSMehdi Amini   }
101b5e22e6dSMehdi Amini   StringRef getDescription() const final {
102b5e22e6dSMehdi Amini     return "Test Linalg element wise operation fusion patterns";
103b5e22e6dSMehdi Amini   }
1043fef2d26SRiver Riddle 
1052abd7f13SMahesh Ravishankar   Option<bool> fuseGenericOps{
1062abd7f13SMahesh Ravishankar       *this, "fuse-generic-ops",
1072abd7f13SMahesh Ravishankar       llvm::cl::desc("Test fusion of generic operations."),
1082abd7f13SMahesh Ravishankar       llvm::cl::init(false)};
1092abd7f13SMahesh Ravishankar 
1102d4b9986SMahesh Ravishankar   Option<bool> fuseGenericOpsControl{
1112d4b9986SMahesh Ravishankar       *this, "fuse-generic-ops-control",
1122d4b9986SMahesh Ravishankar       llvm::cl::desc(
1132d4b9986SMahesh Ravishankar           "Test fusion of generic operations with a control function."),
1142d4b9986SMahesh Ravishankar       llvm::cl::init(false)};
1152d4b9986SMahesh Ravishankar 
1160c090dccSMahesh Ravishankar   Option<bool> fuseWithReshapeByExpansion{
1170c090dccSMahesh Ravishankar       *this, "fuse-with-reshape-by-expansion",
1180c090dccSMahesh Ravishankar       llvm::cl::desc(
1190c090dccSMahesh Ravishankar           "Test fusion of generic operations with reshape by expansion"),
1200c090dccSMahesh Ravishankar       llvm::cl::init(false)};
1210c090dccSMahesh Ravishankar 
1222abd7f13SMahesh Ravishankar   Option<bool> controlFuseByExpansion{
1232abd7f13SMahesh Ravishankar       *this, "control-fusion-by-expansion",
1242abd7f13SMahesh Ravishankar       llvm::cl::desc(
1252abd7f13SMahesh Ravishankar           "Test controlling fusion of reshape with generic op by expansion"),
1262abd7f13SMahesh Ravishankar       llvm::cl::init(false)};
1272abd7f13SMahesh Ravishankar 
1282c58cde0SMahesh Ravishankar   Option<bool> fuseWithReshapeByCollapsing{
1292c58cde0SMahesh Ravishankar       *this, "fuse-with-reshape-by-collapsing",
1302c58cde0SMahesh Ravishankar       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
1312c58cde0SMahesh Ravishankar                      "collapse the iteration space of the consumer"),
1322c58cde0SMahesh Ravishankar       llvm::cl::init(false)};
1332c58cde0SMahesh Ravishankar 
1342c58cde0SMahesh Ravishankar   Option<bool> fuseWithReshapeByCollapsingWithControlFn{
1352c58cde0SMahesh Ravishankar       *this, "fuse-with-reshape-by-collapsing-control",
1362c58cde0SMahesh Ravishankar       llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
1372c58cde0SMahesh Ravishankar                      "fusion patterns that "
1382c58cde0SMahesh Ravishankar                      "collapse the iteration space of the consumer"),
1392c58cde0SMahesh Ravishankar       llvm::cl::init(false)};
14069011a2aSMahesh Ravishankar 
14169011a2aSMahesh Ravishankar   Option<bool> fuseMultiUseProducer{
14269011a2aSMahesh Ravishankar       *this, "fuse-multiuse-producer",
14369011a2aSMahesh Ravishankar       llvm::cl::desc("Test fusion of producer ops with multiple uses"),
14469011a2aSMahesh Ravishankar       llvm::cl::init(false)};
14569011a2aSMahesh Ravishankar 
14683c65fbcSThomas Raoux   ListOption<int64_t> collapseDimensions{
14783c65fbcSThomas Raoux       *this, "collapse-dimensions-control",
14883c65fbcSThomas Raoux       llvm::cl::desc("Test controlling dimension collapse pattern")};
1492c58cde0SMahesh Ravishankar 
15041574554SRiver Riddle   void runOnOperation() override {
1513fef2d26SRiver Riddle     MLIRContext *context = &this->getContext();
15258ceae95SRiver Riddle     func::FuncOp funcOp = this->getOperation();
1537568f710SMahesh Ravishankar 
1542abd7f13SMahesh Ravishankar     if (fuseGenericOps) {
1552abd7f13SMahesh Ravishankar       RewritePatternSet fusionPatterns(context);
1562d4b9986SMahesh Ravishankar       auto controlFn = [](OpOperand *operand) { return true; };
1572d4b9986SMahesh Ravishankar       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
158*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(),
15969011a2aSMahesh Ravishankar                                        std::move(fusionPatterns))))
16069011a2aSMahesh Ravishankar         return signalPassFailure();
1612d4b9986SMahesh Ravishankar       return;
1622d4b9986SMahesh Ravishankar     }
1632d4b9986SMahesh Ravishankar 
1642d4b9986SMahesh Ravishankar     if (fuseGenericOpsControl) {
1652d4b9986SMahesh Ravishankar       RewritePatternSet fusionPatterns(context);
1662291705dSMahesh Ravishankar       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
1672291705dSMahesh Ravishankar                                                    setFusedOpOperandLimit<4>);
1683fef2d26SRiver Riddle 
169*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(),
17069011a2aSMahesh Ravishankar                                        std::move(fusionPatterns))))
17169011a2aSMahesh Ravishankar         return signalPassFailure();
1722abd7f13SMahesh Ravishankar       return;
173b546f434SMaheshRavishankar     }
174b546f434SMaheshRavishankar 
1750c090dccSMahesh Ravishankar     if (fuseWithReshapeByExpansion) {
1760c090dccSMahesh Ravishankar       RewritePatternSet fusionPatterns(context);
1770c090dccSMahesh Ravishankar       linalg::populateFoldReshapeOpsByExpansionPatterns(
178a7bfdc23SMahesh Ravishankar           fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
179*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(),
1800c090dccSMahesh Ravishankar                                        std::move(fusionPatterns))))
1810c090dccSMahesh Ravishankar         return signalPassFailure();
1820c090dccSMahesh Ravishankar       return;
1830c090dccSMahesh Ravishankar     }
1840c090dccSMahesh Ravishankar 
1852abd7f13SMahesh Ravishankar     if (controlFuseByExpansion) {
186b546f434SMaheshRavishankar       RewritePatternSet fusionPatterns(context);
187b546f434SMaheshRavishankar 
1882291705dSMahesh Ravishankar       linalg::ControlFusionFn controlReshapeFusionFn =
189a7bfdc23SMahesh Ravishankar           [](OpOperand *fusedOperand) {
190a7bfdc23SMahesh Ravishankar             Operation *producer = fusedOperand->get().getDefiningOp();
191a7bfdc23SMahesh Ravishankar             if (!producer)
192a7bfdc23SMahesh Ravishankar               return false;
193a7bfdc23SMahesh Ravishankar 
194a7bfdc23SMahesh Ravishankar             if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
19504235d07SJacques Pienaar               if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
196b546f434SMaheshRavishankar                 return false;
197b546f434SMaheshRavishankar               }
198b546f434SMaheshRavishankar             }
199a7bfdc23SMahesh Ravishankar 
200a7bfdc23SMahesh Ravishankar             Operation *consumer = fusedOperand->getOwner();
201a7bfdc23SMahesh Ravishankar             if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
202b546f434SMaheshRavishankar               if (expandOp->hasOneUse()) {
203b546f434SMaheshRavishankar                 OpOperand &use = *expandOp->getUses().begin();
204b546f434SMaheshRavishankar                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
205b4db15a9SAlexander Belyaev                 if (linalgOp && linalgOp.isDpsInit(&use))
206b546f434SMaheshRavishankar                   return true;
207b546f434SMaheshRavishankar               }
2080c090dccSMahesh Ravishankar               return false;
209b546f434SMaheshRavishankar             }
2100c090dccSMahesh Ravishankar             return true;
211b546f434SMaheshRavishankar           };
212b546f434SMaheshRavishankar 
213b546f434SMaheshRavishankar       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
214b546f434SMaheshRavishankar                                                         controlReshapeFusionFn);
215*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(),
21669011a2aSMahesh Ravishankar                                        std::move(fusionPatterns))))
21769011a2aSMahesh Ravishankar         return signalPassFailure();
2182abd7f13SMahesh Ravishankar       return;
219b5e22e6dSMehdi Amini     }
2203fef2d26SRiver Riddle 
2212c58cde0SMahesh Ravishankar     if (fuseWithReshapeByCollapsing) {
2222c58cde0SMahesh Ravishankar       RewritePatternSet patterns(context);
2232291705dSMahesh Ravishankar       linalg::populateFoldReshapeOpsByCollapsingPatterns(
224a7bfdc23SMahesh Ravishankar           patterns, [](OpOperand * /*fusedOperand */) { return true; });
225*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
22669011a2aSMahesh Ravishankar         return signalPassFailure();
22769011a2aSMahesh Ravishankar       return;
2282c58cde0SMahesh Ravishankar     }
2292c58cde0SMahesh Ravishankar 
2302c58cde0SMahesh Ravishankar     if (fuseWithReshapeByCollapsingWithControlFn) {
2312c58cde0SMahesh Ravishankar       RewritePatternSet patterns(context);
232a7bfdc23SMahesh Ravishankar       linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
233a7bfdc23SMahesh Ravishankar         Operation *producer = fusedOperand->get().getDefiningOp();
234a7bfdc23SMahesh Ravishankar         if (isa<tensor::ExpandShapeOp>(producer)) {
2352c58cde0SMahesh Ravishankar           // Skip fusing the first operand.
236a7bfdc23SMahesh Ravishankar           return fusedOperand->getOperandNumber();
2372c58cde0SMahesh Ravishankar         }
2382c58cde0SMahesh Ravishankar         return true;
2392c58cde0SMahesh Ravishankar       };
2402c58cde0SMahesh Ravishankar       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
241*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
24269011a2aSMahesh Ravishankar         return signalPassFailure();
24369011a2aSMahesh Ravishankar       return;
24469011a2aSMahesh Ravishankar     }
24569011a2aSMahesh Ravishankar 
24669011a2aSMahesh Ravishankar     if (fuseMultiUseProducer) {
24769011a2aSMahesh Ravishankar       RewritePatternSet patterns(context);
24869011a2aSMahesh Ravishankar       patterns.insert<TestMultiUseProducerFusion>(context);
249*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
25069011a2aSMahesh Ravishankar         return signalPassFailure();
25169011a2aSMahesh Ravishankar       return;
2522c58cde0SMahesh Ravishankar     }
25383c65fbcSThomas Raoux 
25483c65fbcSThomas Raoux     if (!collapseDimensions.empty()) {
25583c65fbcSThomas Raoux       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
25683c65fbcSThomas Raoux                                    collapseDimensions.end());
25783c65fbcSThomas Raoux       linalg::GetCollapsableDimensionsFn collapseFn =
2585c3ed392SAviad Cohen           [&dims](linalg::LinalgOp op) {
25983c65fbcSThomas Raoux             SmallVector<ReassociationIndices> reassociations;
26083c65fbcSThomas Raoux             reassociations.emplace_back(dims);
26183c65fbcSThomas Raoux             return reassociations;
26283c65fbcSThomas Raoux           };
26383c65fbcSThomas Raoux       RewritePatternSet patterns(context);
26483c65fbcSThomas Raoux       linalg::populateCollapseDimensions(patterns, collapseFn);
265*09dfc571SJacques Pienaar       if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
26669011a2aSMahesh Ravishankar         return signalPassFailure();
26769011a2aSMahesh Ravishankar       return;
26883c65fbcSThomas Raoux     }
2692abd7f13SMahesh Ravishankar   }
2703fef2d26SRiver Riddle };
2712abd7f13SMahesh Ravishankar 
2723fef2d26SRiver Riddle } // namespace
2733fef2d26SRiver Riddle 
2745e50dd04SRiver Riddle namespace mlir {
2753fef2d26SRiver Riddle namespace test {
2763fef2d26SRiver Riddle void registerTestLinalgElementwiseFusion() {
277b5e22e6dSMehdi Amini   PassRegistration<TestLinalgElementwiseFusion>();
2783fef2d26SRiver Riddle }
2793fef2d26SRiver Riddle } // namespace test
2803fef2d26SRiver Riddle } // namespace mlir
281