13fef2d26SRiver Riddle //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file implements a pass for testing fusion of elementwise operations in 103fef2d26SRiver Riddle // Linalg, mainly linalg options. 113fef2d26SRiver Riddle // 123fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 133fef2d26SRiver Riddle 14eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 1536550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 163fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 183fef2d26SRiver Riddle #include "mlir/Pass/PassManager.h" 193fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 203fef2d26SRiver Riddle #include "llvm/ADT/TypeSwitch.h" 213fef2d26SRiver Riddle 225e50dd04SRiver Riddle using namespace mlir; 233fef2d26SRiver Riddle 243fef2d26SRiver Riddle static void addOperands(Operation *op, SetVector<Value> &operandSet) { 253fef2d26SRiver Riddle if (!op) 263fef2d26SRiver Riddle return; 273fef2d26SRiver Riddle TypeSwitch<Operation *, void>(op) 283fef2d26SRiver Riddle .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 290b2197b0SMatthias Springer SmallVector<Value> inputOperands = linalgOp.getDpsInputs(); 309f815cb5STobias Gysi operandSet.insert(inputOperands.begin(), inputOperands.end()); 313fef2d26SRiver Riddle }) 323fef2d26SRiver Riddle .Default([&](Operation *operation) { 333fef2d26SRiver Riddle operandSet.insert(operation->operand_begin(), operation->operand_end()); 343fef2d26SRiver Riddle }); 353fef2d26SRiver Riddle } 363fef2d26SRiver Riddle 373fef2d26SRiver Riddle template <int limit = 3> 38a7bfdc23SMahesh Ravishankar static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { 39a7bfdc23SMahesh Ravishankar Operation *producer = fusedOperand->get().getDefiningOp(); 40a7bfdc23SMahesh Ravishankar if (!producer) 413fef2d26SRiver Riddle return false; 42a7bfdc23SMahesh Ravishankar 43a7bfdc23SMahesh Ravishankar Operation *consumer = fusedOperand->getOwner(); 44a7bfdc23SMahesh Ravishankar SetVector<Value> fusedOpOperands; 45a7bfdc23SMahesh Ravishankar if (producer->getNumResults() != 1) 46a7bfdc23SMahesh Ravishankar return false; 47a7bfdc23SMahesh Ravishankar addOperands(consumer, fusedOpOperands); 48a7bfdc23SMahesh Ravishankar fusedOpOperands.remove(producer->getResult(0)); 49a7bfdc23SMahesh Ravishankar addOperands(producer, fusedOpOperands); 503fef2d26SRiver Riddle return fusedOpOperands.size() <= limit; 513fef2d26SRiver Riddle } 523fef2d26SRiver Riddle 533fef2d26SRiver Riddle namespace { 5469011a2aSMahesh Ravishankar 5569011a2aSMahesh Ravishankar /// Pattern to test fusion of producer with consumer, even if producer has 5669011a2aSMahesh Ravishankar /// multiple uses. 5769011a2aSMahesh Ravishankar struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> { 5869011a2aSMahesh Ravishankar using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; 5969011a2aSMahesh Ravishankar 6069011a2aSMahesh Ravishankar LogicalResult matchAndRewrite(linalg::GenericOp genericOp, 6169011a2aSMahesh Ravishankar PatternRewriter &rewriter) const override { 6269011a2aSMahesh Ravishankar OpOperand *fusableOperand = nullptr; 6369011a2aSMahesh Ravishankar for (OpOperand &operand : genericOp->getOpOperands()) { 6469011a2aSMahesh Ravishankar if (linalg::areElementwiseOpsFusable(&operand)) { 6569011a2aSMahesh Ravishankar fusableOperand = &operand; 6669011a2aSMahesh Ravishankar break; 6769011a2aSMahesh Ravishankar } 6869011a2aSMahesh Ravishankar } 6969011a2aSMahesh Ravishankar if (!fusableOperand) { 7069011a2aSMahesh Ravishankar return rewriter.notifyMatchFailure(genericOp, "no fusable operand found"); 7169011a2aSMahesh Ravishankar } 7269011a2aSMahesh Ravishankar std::optional<linalg::ElementwiseOpFusionResult> fusionResult = 7369011a2aSMahesh Ravishankar linalg::fuseElementwiseOps(rewriter, fusableOperand); 7469011a2aSMahesh Ravishankar if (!fusionResult) 7501890942SMahesh Ravishankar return rewriter.notifyMatchFailure(genericOp, "fusion failed"); 7669011a2aSMahesh Ravishankar for (auto [origValue, replacement] : fusionResult->replacements) { 772c40a0a6SMatthias Springer rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) { 7869011a2aSMahesh Ravishankar return use.getOwner() != genericOp.getOperation(); 7969011a2aSMahesh Ravishankar }); 8069011a2aSMahesh Ravishankar } 8169011a2aSMahesh Ravishankar rewriter.eraseOp(genericOp); 8269011a2aSMahesh Ravishankar return success(); 8369011a2aSMahesh Ravishankar } 8469011a2aSMahesh Ravishankar }; 8569011a2aSMahesh Ravishankar 863fef2d26SRiver Riddle struct TestLinalgElementwiseFusion 8758ceae95SRiver Riddle : public PassWrapper<TestLinalgElementwiseFusion, 8858ceae95SRiver Riddle OperationPass<func::FuncOp>> { 895e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) 905e50dd04SRiver Riddle 912abd7f13SMahesh Ravishankar TestLinalgElementwiseFusion() = default; 922abd7f13SMahesh Ravishankar TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 932abd7f13SMahesh Ravishankar : PassWrapper(pass) {} 943fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 954c48f016SMatthias Springer registry.insert<affine::AffineDialect, linalg::LinalgDialect, 964c48f016SMatthias Springer memref::MemRefDialect, tensor::TensorDialect>(); 973fef2d26SRiver Riddle } 98b5e22e6dSMehdi Amini StringRef getArgument() const final { 99b5e22e6dSMehdi Amini return "test-linalg-elementwise-fusion-patterns"; 100b5e22e6dSMehdi Amini } 101b5e22e6dSMehdi Amini StringRef getDescription() const final { 102b5e22e6dSMehdi Amini return "Test Linalg element wise operation fusion patterns"; 103b5e22e6dSMehdi Amini } 1043fef2d26SRiver Riddle 1052abd7f13SMahesh Ravishankar Option<bool> fuseGenericOps{ 1062abd7f13SMahesh Ravishankar *this, "fuse-generic-ops", 1072abd7f13SMahesh Ravishankar llvm::cl::desc("Test fusion of generic operations."), 1082abd7f13SMahesh Ravishankar llvm::cl::init(false)}; 1092abd7f13SMahesh Ravishankar 1102d4b9986SMahesh Ravishankar Option<bool> fuseGenericOpsControl{ 1112d4b9986SMahesh Ravishankar *this, "fuse-generic-ops-control", 1122d4b9986SMahesh Ravishankar llvm::cl::desc( 1132d4b9986SMahesh Ravishankar "Test fusion of generic operations with a control function."), 1142d4b9986SMahesh Ravishankar llvm::cl::init(false)}; 1152d4b9986SMahesh Ravishankar 1160c090dccSMahesh Ravishankar Option<bool> fuseWithReshapeByExpansion{ 1170c090dccSMahesh Ravishankar *this, "fuse-with-reshape-by-expansion", 1180c090dccSMahesh Ravishankar llvm::cl::desc( 1190c090dccSMahesh Ravishankar "Test fusion of generic operations with reshape by expansion"), 1200c090dccSMahesh Ravishankar llvm::cl::init(false)}; 1210c090dccSMahesh Ravishankar 1222abd7f13SMahesh Ravishankar Option<bool> controlFuseByExpansion{ 1232abd7f13SMahesh Ravishankar *this, "control-fusion-by-expansion", 1242abd7f13SMahesh Ravishankar llvm::cl::desc( 1252abd7f13SMahesh Ravishankar "Test controlling fusion of reshape with generic op by expansion"), 1262abd7f13SMahesh Ravishankar llvm::cl::init(false)}; 1272abd7f13SMahesh Ravishankar 1282c58cde0SMahesh Ravishankar Option<bool> fuseWithReshapeByCollapsing{ 1292c58cde0SMahesh Ravishankar *this, "fuse-with-reshape-by-collapsing", 1302c58cde0SMahesh Ravishankar llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 1312c58cde0SMahesh Ravishankar "collapse the iteration space of the consumer"), 1322c58cde0SMahesh Ravishankar llvm::cl::init(false)}; 1332c58cde0SMahesh Ravishankar 1342c58cde0SMahesh Ravishankar Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 1352c58cde0SMahesh Ravishankar *this, "fuse-with-reshape-by-collapsing-control", 1362c58cde0SMahesh Ravishankar llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 1372c58cde0SMahesh Ravishankar "fusion patterns that " 1382c58cde0SMahesh Ravishankar "collapse the iteration space of the consumer"), 1392c58cde0SMahesh Ravishankar llvm::cl::init(false)}; 14069011a2aSMahesh Ravishankar 14169011a2aSMahesh Ravishankar Option<bool> fuseMultiUseProducer{ 14269011a2aSMahesh Ravishankar *this, "fuse-multiuse-producer", 14369011a2aSMahesh Ravishankar llvm::cl::desc("Test fusion of producer ops with multiple uses"), 14469011a2aSMahesh Ravishankar llvm::cl::init(false)}; 14569011a2aSMahesh Ravishankar 14683c65fbcSThomas Raoux ListOption<int64_t> collapseDimensions{ 14783c65fbcSThomas Raoux *this, "collapse-dimensions-control", 14883c65fbcSThomas Raoux llvm::cl::desc("Test controlling dimension collapse pattern")}; 1492c58cde0SMahesh Ravishankar 15041574554SRiver Riddle void runOnOperation() override { 1513fef2d26SRiver Riddle MLIRContext *context = &this->getContext(); 15258ceae95SRiver Riddle func::FuncOp funcOp = this->getOperation(); 1537568f710SMahesh Ravishankar 1542abd7f13SMahesh Ravishankar if (fuseGenericOps) { 1552abd7f13SMahesh Ravishankar RewritePatternSet fusionPatterns(context); 1562d4b9986SMahesh Ravishankar auto controlFn = [](OpOperand *operand) { return true; }; 1572d4b9986SMahesh Ravishankar linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); 158*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), 15969011a2aSMahesh Ravishankar std::move(fusionPatterns)))) 16069011a2aSMahesh Ravishankar return signalPassFailure(); 1612d4b9986SMahesh Ravishankar return; 1622d4b9986SMahesh Ravishankar } 1632d4b9986SMahesh Ravishankar 1642d4b9986SMahesh Ravishankar if (fuseGenericOpsControl) { 1652d4b9986SMahesh Ravishankar RewritePatternSet fusionPatterns(context); 1662291705dSMahesh Ravishankar linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, 1672291705dSMahesh Ravishankar setFusedOpOperandLimit<4>); 1683fef2d26SRiver Riddle 169*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), 17069011a2aSMahesh Ravishankar std::move(fusionPatterns)))) 17169011a2aSMahesh Ravishankar return signalPassFailure(); 1722abd7f13SMahesh Ravishankar return; 173b546f434SMaheshRavishankar } 174b546f434SMaheshRavishankar 1750c090dccSMahesh Ravishankar if (fuseWithReshapeByExpansion) { 1760c090dccSMahesh Ravishankar RewritePatternSet fusionPatterns(context); 1770c090dccSMahesh Ravishankar linalg::populateFoldReshapeOpsByExpansionPatterns( 178a7bfdc23SMahesh Ravishankar fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; }); 179*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), 1800c090dccSMahesh Ravishankar std::move(fusionPatterns)))) 1810c090dccSMahesh Ravishankar return signalPassFailure(); 1820c090dccSMahesh Ravishankar return; 1830c090dccSMahesh Ravishankar } 1840c090dccSMahesh Ravishankar 1852abd7f13SMahesh Ravishankar if (controlFuseByExpansion) { 186b546f434SMaheshRavishankar RewritePatternSet fusionPatterns(context); 187b546f434SMaheshRavishankar 1882291705dSMahesh Ravishankar linalg::ControlFusionFn controlReshapeFusionFn = 189a7bfdc23SMahesh Ravishankar [](OpOperand *fusedOperand) { 190a7bfdc23SMahesh Ravishankar Operation *producer = fusedOperand->get().getDefiningOp(); 191a7bfdc23SMahesh Ravishankar if (!producer) 192a7bfdc23SMahesh Ravishankar return false; 193a7bfdc23SMahesh Ravishankar 194a7bfdc23SMahesh Ravishankar if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) { 19504235d07SJacques Pienaar if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) { 196b546f434SMaheshRavishankar return false; 197b546f434SMaheshRavishankar } 198b546f434SMaheshRavishankar } 199a7bfdc23SMahesh Ravishankar 200a7bfdc23SMahesh Ravishankar Operation *consumer = fusedOperand->getOwner(); 201a7bfdc23SMahesh Ravishankar if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) { 202b546f434SMaheshRavishankar if (expandOp->hasOneUse()) { 203b546f434SMaheshRavishankar OpOperand &use = *expandOp->getUses().begin(); 204b546f434SMaheshRavishankar auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 205b4db15a9SAlexander Belyaev if (linalgOp && linalgOp.isDpsInit(&use)) 206b546f434SMaheshRavishankar return true; 207b546f434SMaheshRavishankar } 2080c090dccSMahesh Ravishankar return false; 209b546f434SMaheshRavishankar } 2100c090dccSMahesh Ravishankar return true; 211b546f434SMaheshRavishankar }; 212b546f434SMaheshRavishankar 213b546f434SMaheshRavishankar linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 214b546f434SMaheshRavishankar controlReshapeFusionFn); 215*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), 21669011a2aSMahesh Ravishankar std::move(fusionPatterns)))) 21769011a2aSMahesh Ravishankar return signalPassFailure(); 2182abd7f13SMahesh Ravishankar return; 219b5e22e6dSMehdi Amini } 2203fef2d26SRiver Riddle 2212c58cde0SMahesh Ravishankar if (fuseWithReshapeByCollapsing) { 2222c58cde0SMahesh Ravishankar RewritePatternSet patterns(context); 2232291705dSMahesh Ravishankar linalg::populateFoldReshapeOpsByCollapsingPatterns( 224a7bfdc23SMahesh Ravishankar patterns, [](OpOperand * /*fusedOperand */) { return true; }); 225*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 22669011a2aSMahesh Ravishankar return signalPassFailure(); 22769011a2aSMahesh Ravishankar return; 2282c58cde0SMahesh Ravishankar } 2292c58cde0SMahesh Ravishankar 2302c58cde0SMahesh Ravishankar if (fuseWithReshapeByCollapsingWithControlFn) { 2312c58cde0SMahesh Ravishankar RewritePatternSet patterns(context); 232a7bfdc23SMahesh Ravishankar linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool { 233a7bfdc23SMahesh Ravishankar Operation *producer = fusedOperand->get().getDefiningOp(); 234a7bfdc23SMahesh Ravishankar if (isa<tensor::ExpandShapeOp>(producer)) { 2352c58cde0SMahesh Ravishankar // Skip fusing the first operand. 236a7bfdc23SMahesh Ravishankar return fusedOperand->getOperandNumber(); 2372c58cde0SMahesh Ravishankar } 2382c58cde0SMahesh Ravishankar return true; 2392c58cde0SMahesh Ravishankar }; 2402c58cde0SMahesh Ravishankar linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 241*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 24269011a2aSMahesh Ravishankar return signalPassFailure(); 24369011a2aSMahesh Ravishankar return; 24469011a2aSMahesh Ravishankar } 24569011a2aSMahesh Ravishankar 24669011a2aSMahesh Ravishankar if (fuseMultiUseProducer) { 24769011a2aSMahesh Ravishankar RewritePatternSet patterns(context); 24869011a2aSMahesh Ravishankar patterns.insert<TestMultiUseProducerFusion>(context); 249*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 25069011a2aSMahesh Ravishankar return signalPassFailure(); 25169011a2aSMahesh Ravishankar return; 2522c58cde0SMahesh Ravishankar } 25383c65fbcSThomas Raoux 25483c65fbcSThomas Raoux if (!collapseDimensions.empty()) { 25583c65fbcSThomas Raoux SmallVector<int64_t, 2> dims(collapseDimensions.begin(), 25683c65fbcSThomas Raoux collapseDimensions.end()); 25783c65fbcSThomas Raoux linalg::GetCollapsableDimensionsFn collapseFn = 2585c3ed392SAviad Cohen [&dims](linalg::LinalgOp op) { 25983c65fbcSThomas Raoux SmallVector<ReassociationIndices> reassociations; 26083c65fbcSThomas Raoux reassociations.emplace_back(dims); 26183c65fbcSThomas Raoux return reassociations; 26283c65fbcSThomas Raoux }; 26383c65fbcSThomas Raoux RewritePatternSet patterns(context); 26483c65fbcSThomas Raoux linalg::populateCollapseDimensions(patterns, collapseFn); 265*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 26669011a2aSMahesh Ravishankar return signalPassFailure(); 26769011a2aSMahesh Ravishankar return; 26883c65fbcSThomas Raoux } 2692abd7f13SMahesh Ravishankar } 2703fef2d26SRiver Riddle }; 2712abd7f13SMahesh Ravishankar 2723fef2d26SRiver Riddle } // namespace 2733fef2d26SRiver Riddle 2745e50dd04SRiver Riddle namespace mlir { 2753fef2d26SRiver Riddle namespace test { 2763fef2d26SRiver Riddle void registerTestLinalgElementwiseFusion() { 277b5e22e6dSMehdi Amini PassRegistration<TestLinalgElementwiseFusion>(); 2783fef2d26SRiver Riddle } 2793fef2d26SRiver Riddle } // namespace test 2803fef2d26SRiver Riddle } // namespace mlir 281