xref: /llvm-project/mlir/lib/Transforms/CompositePass.cpp (revision 5b66b6a32ad89562732ad6a81c84783486b6187a)
1*5b66b6a3SIvan Butygin //===- CompositePass.cpp - Composite pass code ----------------------------===//
2*5b66b6a3SIvan Butygin //
3*5b66b6a3SIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*5b66b6a3SIvan Butygin // See https://llvm.org/LICENSE.txt for license information.
5*5b66b6a3SIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*5b66b6a3SIvan Butygin //
7*5b66b6a3SIvan Butygin //===----------------------------------------------------------------------===//
8*5b66b6a3SIvan Butygin //
// CompositePass runs a set of passes repeatedly until a fixed point is
// reached.
10*5b66b6a3SIvan Butygin //
11*5b66b6a3SIvan Butygin //===----------------------------------------------------------------------===//
12*5b66b6a3SIvan Butygin 
13*5b66b6a3SIvan Butygin #include "mlir/Transforms/Passes.h"
14*5b66b6a3SIvan Butygin 
15*5b66b6a3SIvan Butygin #include "mlir/Pass/Pass.h"
16*5b66b6a3SIvan Butygin #include "mlir/Pass/PassManager.h"
17*5b66b6a3SIvan Butygin 
18*5b66b6a3SIvan Butygin namespace mlir {
19*5b66b6a3SIvan Butygin #define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20*5b66b6a3SIvan Butygin #include "mlir/Transforms/Passes.h.inc"
21*5b66b6a3SIvan Butygin } // namespace mlir
22*5b66b6a3SIvan Butygin 
23*5b66b6a3SIvan Butygin using namespace mlir;
24*5b66b6a3SIvan Butygin 
25*5b66b6a3SIvan Butygin namespace {
26*5b66b6a3SIvan Butygin struct CompositeFixedPointPass final
27*5b66b6a3SIvan Butygin     : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28*5b66b6a3SIvan Butygin   using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29*5b66b6a3SIvan Butygin 
CompositeFixedPointPass__anon3ef978240111::CompositeFixedPointPass30*5b66b6a3SIvan Butygin   CompositeFixedPointPass(
31*5b66b6a3SIvan Butygin       std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32*5b66b6a3SIvan Butygin       int maxIterations) {
33*5b66b6a3SIvan Butygin     name = std::move(name_);
34*5b66b6a3SIvan Butygin     maxIter = maxIterations;
35*5b66b6a3SIvan Butygin     populateFunc(dynamicPM);
36*5b66b6a3SIvan Butygin 
37*5b66b6a3SIvan Butygin     llvm::raw_string_ostream os(pipelineStr);
38*5b66b6a3SIvan Butygin     dynamicPM.printAsTextualPipeline(os);
39*5b66b6a3SIvan Butygin   }
40*5b66b6a3SIvan Butygin 
initializeOptions__anon3ef978240111::CompositeFixedPointPass41*5b66b6a3SIvan Butygin   LogicalResult initializeOptions(
42*5b66b6a3SIvan Butygin       StringRef options,
43*5b66b6a3SIvan Butygin       function_ref<LogicalResult(const Twine &)> errorHandler) override {
44*5b66b6a3SIvan Butygin     if (failed(CompositeFixedPointPassBase::initializeOptions(options,
45*5b66b6a3SIvan Butygin                                                               errorHandler)))
46*5b66b6a3SIvan Butygin       return failure();
47*5b66b6a3SIvan Butygin 
48*5b66b6a3SIvan Butygin     if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
49*5b66b6a3SIvan Butygin       return errorHandler("Failed to parse composite pass pipeline");
50*5b66b6a3SIvan Butygin 
51*5b66b6a3SIvan Butygin     return success();
52*5b66b6a3SIvan Butygin   }
53*5b66b6a3SIvan Butygin 
initialize__anon3ef978240111::CompositeFixedPointPass54*5b66b6a3SIvan Butygin   LogicalResult initialize(MLIRContext *context) override {
55*5b66b6a3SIvan Butygin     if (maxIter <= 0)
56*5b66b6a3SIvan Butygin       return emitError(UnknownLoc::get(context))
57*5b66b6a3SIvan Butygin              << "Invalid maxIterations value: " << maxIter << "\n";
58*5b66b6a3SIvan Butygin 
59*5b66b6a3SIvan Butygin     return success();
60*5b66b6a3SIvan Butygin   }
61*5b66b6a3SIvan Butygin 
getDependentDialects__anon3ef978240111::CompositeFixedPointPass62*5b66b6a3SIvan Butygin   void getDependentDialects(DialectRegistry &registry) const override {
63*5b66b6a3SIvan Butygin     dynamicPM.getDependentDialects(registry);
64*5b66b6a3SIvan Butygin   }
65*5b66b6a3SIvan Butygin 
runOnOperation__anon3ef978240111::CompositeFixedPointPass66*5b66b6a3SIvan Butygin   void runOnOperation() override {
67*5b66b6a3SIvan Butygin     auto op = getOperation();
68*5b66b6a3SIvan Butygin     OperationFingerPrint fp(op);
69*5b66b6a3SIvan Butygin 
70*5b66b6a3SIvan Butygin     int currentIter = 0;
71*5b66b6a3SIvan Butygin     int maxIterVal = maxIter;
72*5b66b6a3SIvan Butygin     while (true) {
73*5b66b6a3SIvan Butygin       if (failed(runPipeline(dynamicPM, op)))
74*5b66b6a3SIvan Butygin         return signalPassFailure();
75*5b66b6a3SIvan Butygin 
76*5b66b6a3SIvan Butygin       if (currentIter++ >= maxIterVal) {
77*5b66b6a3SIvan Butygin         op->emitWarning("Composite pass \"" + llvm::Twine(name) +
78*5b66b6a3SIvan Butygin                         "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
79*5b66b6a3SIvan Butygin                         " iterations");
80*5b66b6a3SIvan Butygin         break;
81*5b66b6a3SIvan Butygin       }
82*5b66b6a3SIvan Butygin 
83*5b66b6a3SIvan Butygin       OperationFingerPrint newFp(op);
84*5b66b6a3SIvan Butygin       if (newFp == fp)
85*5b66b6a3SIvan Butygin         break;
86*5b66b6a3SIvan Butygin 
87*5b66b6a3SIvan Butygin       fp = newFp;
88*5b66b6a3SIvan Butygin     }
89*5b66b6a3SIvan Butygin   }
90*5b66b6a3SIvan Butygin 
91*5b66b6a3SIvan Butygin protected:
getName__anon3ef978240111::CompositeFixedPointPass92*5b66b6a3SIvan Butygin   llvm::StringRef getName() const override { return name; }
93*5b66b6a3SIvan Butygin 
94*5b66b6a3SIvan Butygin private:
95*5b66b6a3SIvan Butygin   OpPassManager dynamicPM;
96*5b66b6a3SIvan Butygin };
97*5b66b6a3SIvan Butygin } // namespace
98*5b66b6a3SIvan Butygin 
createCompositeFixedPointPass(std::string name,llvm::function_ref<void (OpPassManager &)> populateFunc,int maxIterations)99*5b66b6a3SIvan Butygin std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
100*5b66b6a3SIvan Butygin     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
101*5b66b6a3SIvan Butygin     int maxIterations) {
102*5b66b6a3SIvan Butygin 
103*5b66b6a3SIvan Butygin   return std::make_unique<CompositeFixedPointPass>(std::move(name),
104*5b66b6a3SIvan Butygin                                                    populateFunc, maxIterations);
105*5b66b6a3SIvan Butygin }
106