//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass for testing fusion of elementwise operations in // Linalg, mainly linalg options. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; static void addOperands(Operation *op, SetVector &operandSet) { if (!op) return; TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { SmallVector inputOperands = linalgOp.getDpsInputs(); operandSet.insert(inputOperands.begin(), inputOperands.end()); }) .Default([&](Operation *operation) { operandSet.insert(operation->operand_begin(), operation->operand_end()); }); } template static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); if (!producer) return false; Operation *consumer = fusedOperand->getOwner(); SetVector fusedOpOperands; if (producer->getNumResults() != 1) return false; addOperands(consumer, fusedOpOperands); fusedOpOperands.remove(producer->getResult(0)); addOperands(producer, fusedOpOperands); return fusedOpOperands.size() <= limit; } namespace { /// Pattern to test fusion of producer with consumer, even if producer has /// multiple uses. struct TestMultiUseProducerFusion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { OpOperand *fusableOperand = nullptr; for (OpOperand &operand : genericOp->getOpOperands()) { if (linalg::areElementwiseOpsFusable(&operand)) { fusableOperand = &operand; break; } } if (!fusableOperand) { return rewriter.notifyMatchFailure(genericOp, "no fusable operand found"); } std::optional fusionResult = linalg::fuseElementwiseOps(rewriter, fusableOperand); if (!fusionResult) return rewriter.notifyMatchFailure(genericOp, "fusion failed"); for (auto [origValue, replacement] : fusionResult->replacements) { rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) { return use.getOwner() != genericOp.getOperation(); }); } rewriter.eraseOp(genericOp); return success(); } }; struct TestLinalgElementwiseFusion : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) TestLinalgElementwiseFusion() = default; TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-linalg-elementwise-fusion-patterns"; } StringRef getDescription() const final { return "Test Linalg element wise operation fusion patterns"; } Option fuseGenericOps{ *this, "fuse-generic-ops", llvm::cl::desc("Test fusion of generic operations."), llvm::cl::init(false)}; Option fuseGenericOpsControl{ *this, "fuse-generic-ops-control", llvm::cl::desc( "Test fusion of generic operations with a control function."), llvm::cl::init(false)}; Option fuseWithReshapeByExpansion{ *this, "fuse-with-reshape-by-expansion", llvm::cl::desc( "Test fusion of generic operations with reshape by expansion"), llvm::cl::init(false)}; Option controlFuseByExpansion{ *this, "control-fusion-by-expansion", llvm::cl::desc( "Test controlling fusion of reshape with generic op by expansion"), llvm::cl::init(false)}; Option fuseWithReshapeByCollapsing{ *this, "fuse-with-reshape-by-collapsing", llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " "collapse the iteration space of the consumer"), llvm::cl::init(false)}; Option fuseWithReshapeByCollapsingWithControlFn{ *this, "fuse-with-reshape-by-collapsing-control", llvm::cl::desc("Test controlling the linalg expand_shape -> generic " "fusion patterns that " "collapse the iteration space of the consumer"), llvm::cl::init(false)}; Option fuseMultiUseProducer{ *this, "fuse-multiuse-producer", llvm::cl::desc("Test fusion of producer ops with multiple uses"), llvm::cl::init(false)}; ListOption collapseDimensions{ *this, "collapse-dimensions-control", llvm::cl::desc("Test controlling dimension collapse pattern")}; void runOnOperation() override { MLIRContext *context = &this->getContext(); func::FuncOp funcOp = this->getOperation(); if (fuseGenericOps) { RewritePatternSet fusionPatterns(context); auto controlFn = [](OpOperand *operand) { return true; }; linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(fusionPatterns)))) return signalPassFailure(); return; } if (fuseGenericOpsControl) { RewritePatternSet fusionPatterns(context); linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, setFusedOpOperandLimit<4>); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(fusionPatterns)))) return signalPassFailure(); return; } if (fuseWithReshapeByExpansion) { RewritePatternSet fusionPatterns(context); linalg::populateFoldReshapeOpsByExpansionPatterns( fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; }); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(fusionPatterns)))) return signalPassFailure(); return; } if (controlFuseByExpansion) { RewritePatternSet fusionPatterns(context); linalg::ControlFusionFn controlReshapeFusionFn = [](OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); if (!producer) return false; if (auto collapseOp = dyn_cast(producer)) { if (!collapseOp.getSrc().getDefiningOp()) { return false; } } Operation *consumer = fusedOperand->getOwner(); if (auto expandOp = dyn_cast(consumer)) { if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); auto linalgOp = dyn_cast(use.getOwner()); if (linalgOp && linalgOp.isDpsInit(&use)) return true; } return false; } return true; }; linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, controlReshapeFusionFn); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(fusionPatterns)))) return signalPassFailure(); return; } if (fuseWithReshapeByCollapsing) { RewritePatternSet patterns(context); linalg::populateFoldReshapeOpsByCollapsingPatterns( patterns, [](OpOperand * /*fusedOperand */) { return true; }); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } if (fuseWithReshapeByCollapsingWithControlFn) { RewritePatternSet patterns(context); linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool { Operation *producer = fusedOperand->get().getDefiningOp(); if (isa(producer)) { // Skip fusing the first operand. return fusedOperand->getOperandNumber(); } return true; }; linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } if (fuseMultiUseProducer) { RewritePatternSet patterns(context); patterns.insert(context); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } if (!collapseDimensions.empty()) { SmallVector dims(collapseDimensions.begin(), collapseDimensions.end()); linalg::GetCollapsableDimensionsFn collapseFn = [&dims](linalg::LinalgOp op) { SmallVector reassociations; reassociations.emplace_back(dims); return reassociations; }; RewritePatternSet patterns(context); linalg::populateCollapseDimensions(patterns, collapseFn); if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } } }; } // namespace namespace mlir { namespace test { void registerTestLinalgElementwiseFusion() { PassRegistration(); } } // namespace test } // namespace mlir