1 //===- CompositePass.cpp - Composite pass code ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// CompositePass runs a set of passes repeatedly until a fixed point is reached.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/Passes.h"
14
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17
18 namespace mlir {
19 #define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20 #include "mlir/Transforms/Passes.h.inc"
21 } // namespace mlir
22
23 using namespace mlir;
24
25 namespace {
26 struct CompositeFixedPointPass final
27 : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28 using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29
CompositeFixedPointPass__anon3ef978240111::CompositeFixedPointPass30 CompositeFixedPointPass(
31 std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32 int maxIterations) {
33 name = std::move(name_);
34 maxIter = maxIterations;
35 populateFunc(dynamicPM);
36
37 llvm::raw_string_ostream os(pipelineStr);
38 dynamicPM.printAsTextualPipeline(os);
39 }
40
initializeOptions__anon3ef978240111::CompositeFixedPointPass41 LogicalResult initializeOptions(
42 StringRef options,
43 function_ref<LogicalResult(const Twine &)> errorHandler) override {
44 if (failed(CompositeFixedPointPassBase::initializeOptions(options,
45 errorHandler)))
46 return failure();
47
48 if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
49 return errorHandler("Failed to parse composite pass pipeline");
50
51 return success();
52 }
53
initialize__anon3ef978240111::CompositeFixedPointPass54 LogicalResult initialize(MLIRContext *context) override {
55 if (maxIter <= 0)
56 return emitError(UnknownLoc::get(context))
57 << "Invalid maxIterations value: " << maxIter << "\n";
58
59 return success();
60 }
61
getDependentDialects__anon3ef978240111::CompositeFixedPointPass62 void getDependentDialects(DialectRegistry ®istry) const override {
63 dynamicPM.getDependentDialects(registry);
64 }
65
runOnOperation__anon3ef978240111::CompositeFixedPointPass66 void runOnOperation() override {
67 auto op = getOperation();
68 OperationFingerPrint fp(op);
69
70 int currentIter = 0;
71 int maxIterVal = maxIter;
72 while (true) {
73 if (failed(runPipeline(dynamicPM, op)))
74 return signalPassFailure();
75
76 if (currentIter++ >= maxIterVal) {
77 op->emitWarning("Composite pass \"" + llvm::Twine(name) +
78 "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
79 " iterations");
80 break;
81 }
82
83 OperationFingerPrint newFp(op);
84 if (newFp == fp)
85 break;
86
87 fp = newFp;
88 }
89 }
90
91 protected:
getName__anon3ef978240111::CompositeFixedPointPass92 llvm::StringRef getName() const override { return name; }
93
94 private:
95 OpPassManager dynamicPM;
96 };
97 } // namespace
98
createCompositeFixedPointPass(std::string name,llvm::function_ref<void (OpPassManager &)> populateFunc,int maxIterations)99 std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
100 std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
101 int maxIterations) {
102
103 return std::make_unique<CompositeFixedPointPass>(std::move(name),
104 populateFunc, maxIterations);
105 }
106