1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace mlir; 23 24 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 25 if (!op) 26 return; 27 TypeSwitch<Operation *, void>(op) 28 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 29 SmallVector<Value> inputOperands = linalgOp.getDpsInputs(); 30 operandSet.insert(inputOperands.begin(), inputOperands.end()); 31 }) 32 .Default([&](Operation *operation) { 33 operandSet.insert(operation->operand_begin(), operation->operand_end()); 34 }); 35 } 36 37 template <int limit = 3> 38 static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { 39 Operation *producer = fusedOperand->get().getDefiningOp(); 40 if (!producer) 41 return false; 42 43 Operation *consumer = fusedOperand->getOwner(); 44 SetVector<Value> fusedOpOperands; 45 if (producer->getNumResults() != 1) 46 return false; 47 addOperands(consumer, fusedOpOperands); 48 fusedOpOperands.remove(producer->getResult(0)); 49 addOperands(producer, fusedOpOperands); 50 return fusedOpOperands.size() <= limit; 51 } 52 53 namespace { 54 55 /// Pattern to test fusion of producer with consumer, even if producer has 56 /// multiple uses. 57 struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> { 58 using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; 59 60 LogicalResult matchAndRewrite(linalg::GenericOp genericOp, 61 PatternRewriter &rewriter) const override { 62 OpOperand *fusableOperand = nullptr; 63 for (OpOperand &operand : genericOp->getOpOperands()) { 64 if (linalg::areElementwiseOpsFusable(&operand)) { 65 fusableOperand = &operand; 66 break; 67 } 68 } 69 if (!fusableOperand) { 70 return rewriter.notifyMatchFailure(genericOp, "no fusable operand found"); 71 } 72 std::optional<linalg::ElementwiseOpFusionResult> fusionResult = 73 linalg::fuseElementwiseOps(rewriter, fusableOperand); 74 if (!fusionResult) 75 return rewriter.notifyMatchFailure(genericOp, "fusion failed"); 76 for (auto [origValue, replacement] : fusionResult->replacements) { 77 rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) { 78 return use.getOwner() != genericOp.getOperation(); 79 }); 80 } 81 rewriter.eraseOp(genericOp); 82 return success(); 83 } 84 }; 85 86 struct TestLinalgElementwiseFusion 87 : public PassWrapper<TestLinalgElementwiseFusion, 88 OperationPass<func::FuncOp>> { 89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) 90 91 TestLinalgElementwiseFusion() = default; 92 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 93 : PassWrapper(pass) {} 94 void getDependentDialects(DialectRegistry ®istry) const override { 95 registry.insert<affine::AffineDialect, linalg::LinalgDialect, 96 memref::MemRefDialect, tensor::TensorDialect>(); 97 } 98 StringRef getArgument() const final { 99 return "test-linalg-elementwise-fusion-patterns"; 100 } 101 StringRef getDescription() const final { 102 return "Test Linalg element wise operation fusion patterns"; 103 } 104 105 Option<bool> fuseGenericOps{ 106 *this, "fuse-generic-ops", 107 llvm::cl::desc("Test fusion of generic operations."), 108 llvm::cl::init(false)}; 109 110 Option<bool> fuseGenericOpsControl{ 111 *this, "fuse-generic-ops-control", 112 llvm::cl::desc( 113 "Test fusion of generic operations with a control function."), 114 llvm::cl::init(false)}; 115 116 Option<bool> fuseWithReshapeByExpansion{ 117 *this, "fuse-with-reshape-by-expansion", 118 llvm::cl::desc( 119 "Test fusion of generic operations with reshape by expansion"), 120 llvm::cl::init(false)}; 121 122 Option<bool> controlFuseByExpansion{ 123 *this, "control-fusion-by-expansion", 124 llvm::cl::desc( 125 "Test controlling fusion of reshape with generic op by expansion"), 126 llvm::cl::init(false)}; 127 128 Option<bool> fuseWithReshapeByCollapsing{ 129 *this, "fuse-with-reshape-by-collapsing", 130 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 131 "collapse the iteration space of the consumer"), 132 llvm::cl::init(false)}; 133 134 Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 135 *this, "fuse-with-reshape-by-collapsing-control", 136 llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 137 "fusion patterns that " 138 "collapse the iteration space of the consumer"), 139 llvm::cl::init(false)}; 140 141 Option<bool> fuseMultiUseProducer{ 142 *this, "fuse-multiuse-producer", 143 llvm::cl::desc("Test fusion of producer ops with multiple uses"), 144 llvm::cl::init(false)}; 145 146 ListOption<int64_t> collapseDimensions{ 147 *this, "collapse-dimensions-control", 148 llvm::cl::desc("Test controlling dimension collapse pattern")}; 149 150 void runOnOperation() override { 151 MLIRContext *context = &this->getContext(); 152 func::FuncOp funcOp = this->getOperation(); 153 154 if (fuseGenericOps) { 155 RewritePatternSet fusionPatterns(context); 156 auto controlFn = [](OpOperand *operand) { return true; }; 157 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); 158 if (failed(applyPatternsGreedily(funcOp.getBody(), 159 std::move(fusionPatterns)))) 160 return signalPassFailure(); 161 return; 162 } 163 164 if (fuseGenericOpsControl) { 165 RewritePatternSet fusionPatterns(context); 166 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, 167 setFusedOpOperandLimit<4>); 168 169 if (failed(applyPatternsGreedily(funcOp.getBody(), 170 std::move(fusionPatterns)))) 171 return signalPassFailure(); 172 return; 173 } 174 175 if (fuseWithReshapeByExpansion) { 176 RewritePatternSet fusionPatterns(context); 177 linalg::populateFoldReshapeOpsByExpansionPatterns( 178 fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; }); 179 if (failed(applyPatternsGreedily(funcOp.getBody(), 180 std::move(fusionPatterns)))) 181 return signalPassFailure(); 182 return; 183 } 184 185 if (controlFuseByExpansion) { 186 RewritePatternSet fusionPatterns(context); 187 188 linalg::ControlFusionFn controlReshapeFusionFn = 189 [](OpOperand *fusedOperand) { 190 Operation *producer = fusedOperand->get().getDefiningOp(); 191 if (!producer) 192 return false; 193 194 if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) { 195 if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) { 196 return false; 197 } 198 } 199 200 Operation *consumer = fusedOperand->getOwner(); 201 if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) { 202 if (expandOp->hasOneUse()) { 203 OpOperand &use = *expandOp->getUses().begin(); 204 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 205 if (linalgOp && linalgOp.isDpsInit(&use)) 206 return true; 207 } 208 return false; 209 } 210 return true; 211 }; 212 213 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 214 controlReshapeFusionFn); 215 if (failed(applyPatternsGreedily(funcOp.getBody(), 216 std::move(fusionPatterns)))) 217 return signalPassFailure(); 218 return; 219 } 220 221 if (fuseWithReshapeByCollapsing) { 222 RewritePatternSet patterns(context); 223 linalg::populateFoldReshapeOpsByCollapsingPatterns( 224 patterns, [](OpOperand * /*fusedOperand */) { return true; }); 225 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 226 return signalPassFailure(); 227 return; 228 } 229 230 if (fuseWithReshapeByCollapsingWithControlFn) { 231 RewritePatternSet patterns(context); 232 linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool { 233 Operation *producer = fusedOperand->get().getDefiningOp(); 234 if (isa<tensor::ExpandShapeOp>(producer)) { 235 // Skip fusing the first operand. 236 return fusedOperand->getOperandNumber(); 237 } 238 return true; 239 }; 240 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 241 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 242 return signalPassFailure(); 243 return; 244 } 245 246 if (fuseMultiUseProducer) { 247 RewritePatternSet patterns(context); 248 patterns.insert<TestMultiUseProducerFusion>(context); 249 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 250 return signalPassFailure(); 251 return; 252 } 253 254 if (!collapseDimensions.empty()) { 255 SmallVector<int64_t, 2> dims(collapseDimensions.begin(), 256 collapseDimensions.end()); 257 linalg::GetCollapsableDimensionsFn collapseFn = 258 [&dims](linalg::LinalgOp op) { 259 SmallVector<ReassociationIndices> reassociations; 260 reassociations.emplace_back(dims); 261 return reassociations; 262 }; 263 RewritePatternSet patterns(context); 264 linalg::populateCollapseDimensions(patterns, collapseFn); 265 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) 266 return signalPassFailure(); 267 return; 268 } 269 } 270 }; 271 272 } // namespace 273 274 namespace mlir { 275 namespace test { 276 void registerTestLinalgElementwiseFusion() { 277 PassRegistration<TestLinalgElementwiseFusion>(); 278 } 279 } // namespace test 280 } // namespace mlir 281