xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
26   OpBuilder b(f);
27 
28   // Save original Linalg ops, we only want to make a pass over those.
29   SmallVector<LinalgOp, 8> linalgOps;
30   f.walk([&](LinalgOp op) {
31     // TODO: support multi-results.
32     if (op->getNumResults() <= 1)
33       linalgOps.push_back(op);
34   });
35 
36   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
37   bool changed = false;
38   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
39     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
40       if (isa<MemRefType>(opOperand.get().getType()))
41         continue;
42       if (isa<RankedTensorType>(opOperand.get().getType())) {
43         // Tile and Fuse tensor input.
44         if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
45           continue;
46         auto info = fuseProducerOfTensor(b, opOperand);
47         if (failed(info))
48           continue;
49         auto *originalOp = info->originalProducer.getOperation();
50         auto *originalOpInLinalgOpsVector =
51             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
52         *originalOpInLinalgOpsVector = info->fusedProducer;
53         // Don't mark for erasure in the tensor case, let DCE handle this.
54         changed = true;
55       }
56     }
57   }
58 
59   return changed ? success() : failure();
60 }
61 
62 namespace {
63 struct TestLinalgGreedyFusion
64     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> {
65   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
66 
67   void getDependentDialects(DialectRegistry &registry) const override {
68     registry.insert<affine::AffineDialect, linalg::LinalgDialect,
69                     memref::MemRefDialect, scf::SCFDialect>();
70   }
71   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
72   StringRef getDescription() const final {
73     return "Test Linalg fusion by applying a greedy test transformation.";
74   }
75   void runOnOperation() override {
76     MLIRContext *context = &getContext();
77     RewritePatternSet patterns =
78         linalg::getLinalgTilingCanonicalizationPatterns(context);
79     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
80     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
81     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
82     OpPassManager pm(func::FuncOp::getOperationName());
83     pm.addPass(createLoopInvariantCodeMotionPass());
84     pm.addPass(createCanonicalizerPass());
85     pm.addPass(createCSEPass());
86     do {
87       (void)applyPatternsGreedily(getOperation(), frozenPatterns);
88       if (failed(runPipeline(pm, getOperation())))
89         this->signalPassFailure();
90     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
91   }
92 };
93 } // namespace
94 
95 namespace mlir {
96 namespace test {
97 void registerTestLinalgGreedyFusion() {
98   PassRegistration<TestLinalgGreedyFusion>();
99 }
100 
101 } // namespace test
102 } // namespace mlir
103