xref: /llvm-project/mlir/lib/Transforms/CompositePass.cpp (revision 5b66b6a32ad89562732ad6a81c84783486b6187a)
1 //===- CompositePass.cpp - Composite pass code ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // CompositePass allows running a set of passes until a fixed point is reached.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/Passes.h"
14 
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20 #include "mlir/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 struct CompositeFixedPointPass final
27     : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28   using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29 
CompositeFixedPointPass__anon3ef978240111::CompositeFixedPointPass30   CompositeFixedPointPass(
31       std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32       int maxIterations) {
33     name = std::move(name_);
34     maxIter = maxIterations;
35     populateFunc(dynamicPM);
36 
37     llvm::raw_string_ostream os(pipelineStr);
38     dynamicPM.printAsTextualPipeline(os);
39   }
40 
initializeOptions__anon3ef978240111::CompositeFixedPointPass41   LogicalResult initializeOptions(
42       StringRef options,
43       function_ref<LogicalResult(const Twine &)> errorHandler) override {
44     if (failed(CompositeFixedPointPassBase::initializeOptions(options,
45                                                               errorHandler)))
46       return failure();
47 
48     if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
49       return errorHandler("Failed to parse composite pass pipeline");
50 
51     return success();
52   }
53 
initialize__anon3ef978240111::CompositeFixedPointPass54   LogicalResult initialize(MLIRContext *context) override {
55     if (maxIter <= 0)
56       return emitError(UnknownLoc::get(context))
57              << "Invalid maxIterations value: " << maxIter << "\n";
58 
59     return success();
60   }
61 
getDependentDialects__anon3ef978240111::CompositeFixedPointPass62   void getDependentDialects(DialectRegistry &registry) const override {
63     dynamicPM.getDependentDialects(registry);
64   }
65 
runOnOperation__anon3ef978240111::CompositeFixedPointPass66   void runOnOperation() override {
67     auto op = getOperation();
68     OperationFingerPrint fp(op);
69 
70     int currentIter = 0;
71     int maxIterVal = maxIter;
72     while (true) {
73       if (failed(runPipeline(dynamicPM, op)))
74         return signalPassFailure();
75 
76       if (currentIter++ >= maxIterVal) {
77         op->emitWarning("Composite pass \"" + llvm::Twine(name) +
78                         "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
79                         " iterations");
80         break;
81       }
82 
83       OperationFingerPrint newFp(op);
84       if (newFp == fp)
85         break;
86 
87       fp = newFp;
88     }
89   }
90 
91 protected:
getName__anon3ef978240111::CompositeFixedPointPass92   llvm::StringRef getName() const override { return name; }
93 
94 private:
95   OpPassManager dynamicPM;
96 };
97 } // namespace
98 
createCompositeFixedPointPass(std::string name,llvm::function_ref<void (OpPassManager &)> populateFunc,int maxIterations)99 std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
100     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
101     int maxIterations) {
102 
103   return std::make_unique<CompositeFixedPointPass>(std::move(name),
104                                                    populateFunc, maxIterations);
105 }
106