1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "mlir/Transforms/Passes.h" 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 25 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { 26 OpBuilder b(f); 27 28 // Save original Linalg ops, we only want to make a pass over those. 29 SmallVector<LinalgOp, 8> linalgOps; 30 f.walk([&](LinalgOp op) { 31 // TODO: support multi-results. 32 if (op->getNumResults() <= 1) 33 linalgOps.push_back(op); 34 }); 35 36 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 37 bool changed = false; 38 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 39 for (OpOperand &opOperand : linalgOp->getOpOperands()) { 40 if (isa<MemRefType>(opOperand.get().getType())) 41 continue; 42 if (isa<RankedTensorType>(opOperand.get().getType())) { 43 // Tile and Fuse tensor input. 44 if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs()) 45 continue; 46 auto info = fuseProducerOfTensor(b, opOperand); 47 if (failed(info)) 48 continue; 49 auto *originalOp = info->originalProducer.getOperation(); 50 auto *originalOpInLinalgOpsVector = 51 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 52 *originalOpInLinalgOpsVector = info->fusedProducer; 53 // Don't mark for erasure in the tensor case, let DCE handle this. 54 changed = true; 55 } 56 } 57 } 58 59 return changed ? success() : failure(); 60 } 61 62 namespace { 63 struct TestLinalgGreedyFusion 64 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> { 65 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion) 66 67 void getDependentDialects(DialectRegistry ®istry) const override { 68 registry.insert<affine::AffineDialect, linalg::LinalgDialect, 69 memref::MemRefDialect, scf::SCFDialect>(); 70 } 71 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 72 StringRef getDescription() const final { 73 return "Test Linalg fusion by applying a greedy test transformation."; 74 } 75 void runOnOperation() override { 76 MLIRContext *context = &getContext(); 77 RewritePatternSet patterns = 78 linalg::getLinalgTilingCanonicalizationPatterns(context); 79 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 80 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 81 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 82 OpPassManager pm(func::FuncOp::getOperationName()); 83 pm.addPass(createLoopInvariantCodeMotionPass()); 84 pm.addPass(createCanonicalizerPass()); 85 pm.addPass(createCSEPass()); 86 do { 87 (void)applyPatternsGreedily(getOperation(), frozenPatterns); 88 if (failed(runPipeline(pm, getOperation()))) 89 this->signalPassFailure(); 90 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 91 } 92 }; 93 } // namespace 94 95 namespace mlir { 96 namespace test { 97 void registerTestLinalgGreedyFusion() { 98 PassRegistration<TestLinalgGreedyFusion>(); 99 } 100 101 } // namespace test 102 } // namespace mlir 103