xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision fa19ef7a10869bf0f8325681be111f7d97b2544e)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"

#include <utility>
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26   PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
27     other.passManager.ptr = nullptr;
28   }
29   ~PyPassManager() {
30     if (!mlirPassManagerIsNull(passManager))
31       mlirPassManagerDestroy(passManager);
32   }
33   MlirPassManager get() { return passManager; }
34 
35   void release() { passManager.ptr = nullptr; }
36   pybind11::object getCapsule() {
37     return py::reinterpret_steal<py::object>(
38         mlirPythonPassManagerToCapsule(get()));
39   }
40 
41   static pybind11::object createFromCapsule(pybind11::object capsule) {
42     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
43     if (mlirPassManagerIsNull(rawPm))
44       throw py::error_already_set();
45     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
46   }
47 
48 private:
49   MlirPassManager passManager;
50 };
51 
52 } // namespace
53 
54 /// Create the `mlir.passmanager` here.
55 void mlir::python::populatePassManagerSubmodule(py::module &m) {
56   //----------------------------------------------------------------------------
57   // Mapping of the top-level PassManager
58   //----------------------------------------------------------------------------
59   py::class_<PyPassManager>(m, "PassManager", py::module_local())
60       .def(py::init<>([](const std::string &anchorOp,
61                          DefaultingPyMlirContext context) {
62              MlirPassManager passManager = mlirPassManagerCreateOnOperation(
63                  context->get(),
64                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
65              return new PyPassManager(passManager);
66            }),
67            "anchor_op"_a = py::str("any"), "context"_a = py::none(),
68            "Create a new PassManager for the current (or provided) Context.")
69       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
70                              &PyPassManager::getCapsule)
71       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
72       .def("_testing_release", &PyPassManager::release,
73            "Releases (leaks) the backing pass manager (testing)")
74       .def(
75           "enable_ir_printing",
76           [](PyPassManager &passManager) {
77             mlirPassManagerEnableIRPrinting(passManager.get());
78           },
79           "Enable mlir-print-ir-after-all.")
80       .def(
81           "enable_verifier",
82           [](PyPassManager &passManager, bool enable) {
83             mlirPassManagerEnableVerifier(passManager.get(), enable);
84           },
85           "enable"_a, "Enable / disable verify-each.")
86       .def_static(
87           "parse",
88           [](const std::string &pipeline, DefaultingPyMlirContext context) {
89             MlirPassManager passManager = mlirPassManagerCreate(context->get());
90             PyPrintAccumulator errorMsg;
91             MlirLogicalResult status = mlirParsePassPipeline(
92                 mlirPassManagerGetAsOpPassManager(passManager),
93                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
94                 errorMsg.getCallback(), errorMsg.getUserData());
95             if (mlirLogicalResultIsFailure(status))
96               throw py::value_error(std::string(errorMsg.join()));
97             return new PyPassManager(passManager);
98           },
99           "pipeline"_a, "context"_a = py::none(),
100           "Parse a textual pass-pipeline and return a top-level PassManager "
101           "that can be applied on a Module. Throw a ValueError if the pipeline "
102           "can't be parsed")
103       .def(
104           "add",
105           [](PyPassManager &passManager, const std::string &pipeline) {
106             PyPrintAccumulator errorMsg;
107             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
108                 mlirPassManagerGetAsOpPassManager(passManager.get()),
109                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
110                 errorMsg.getCallback(), errorMsg.getUserData());
111             if (mlirLogicalResultIsFailure(status))
112               throw py::value_error(std::string(errorMsg.join()));
113           },
114           "pipeline"_a,
115           "Add textual pipeline elements to the pass manager. Throws a "
116           "ValueError if the pipeline can't be parsed.")
117       .def(
118           "run",
119           [](PyPassManager &passManager, PyOperationBase &op,
120              bool invalidateOps) {
121             if (invalidateOps) {
122               op.getOperation().getContext()->clearOperationsInside(op);
123             }
124             // Actually run the pass manager.
125             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
126             MlirLogicalResult status = mlirPassManagerRunOnOp(
127                 passManager.get(), op.getOperation().get());
128             if (mlirLogicalResultIsFailure(status))
129               throw MLIRError("Failure while executing pass pipeline",
130                               errors.take());
131           },
132           "operation"_a, "invalidate_ops"_a = true,
133           "Run the pass manager on the provided operation, raising an "
134           "MLIRError on failure.")
135       .def(
136           "__str__",
137           [](PyPassManager &self) {
138             MlirPassManager passManager = self.get();
139             PyPrintAccumulator printAccum;
140             mlirPrintPassPipeline(
141                 mlirPassManagerGetAsOpPassManager(passManager),
142                 printAccum.getCallback(), printAccum.getUserData());
143             return printAccum.join();
144           },
145           "Print the textual representation for this PassManager, suitable to "
146           "be passed to `parse` for round-tripping.");
147 }
148