xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision b3c5f6b15b1eaa2552ce62329208ece5166356fe)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
14 
15 namespace py = pybind11;
16 using namespace mlir;
17 using namespace mlir::python;
18 
19 namespace {
20 
21 /// Owning Wrapper around a PassManager.
22 class PyPassManager {
23 public:
24   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
25   PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
26     other.passManager.ptr = nullptr;
27   }
28   ~PyPassManager() {
29     if (!mlirPassManagerIsNull(passManager))
30       mlirPassManagerDestroy(passManager);
31   }
32   MlirPassManager get() { return passManager; }
33 
34   void release() { passManager.ptr = nullptr; }
35   pybind11::object getCapsule() {
36     return py::reinterpret_steal<py::object>(
37         mlirPythonPassManagerToCapsule(get()));
38   }
39 
40   static pybind11::object createFromCapsule(pybind11::object capsule) {
41     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
42     if (mlirPassManagerIsNull(rawPm))
43       throw py::error_already_set();
44     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
45   }
46 
47 private:
48   MlirPassManager passManager;
49 };
50 
51 } // namespace
52 
53 /// Create the `mlir.passmanager` here.
54 void mlir::python::populatePassManagerSubmodule(py::module &m) {
55   //----------------------------------------------------------------------------
56   // Mapping of the top-level PassManager
57   //----------------------------------------------------------------------------
58   py::class_<PyPassManager>(m, "PassManager", py::module_local())
59       .def(py::init<>([](DefaultingPyMlirContext context) {
60              MlirPassManager passManager =
61                  mlirPassManagerCreate(context->get());
62              return new PyPassManager(passManager);
63            }),
64            py::arg("context") = py::none(),
65            "Create a new PassManager for the current (or provided) Context.")
66       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
67                              &PyPassManager::getCapsule)
68       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
69       .def("_testing_release", &PyPassManager::release,
70            "Releases (leaks) the backing pass manager (testing)")
71       .def(
72           "enable_ir_printing",
73           [](PyPassManager &passManager) {
74             mlirPassManagerEnableIRPrinting(passManager.get());
75           },
76           "Enable mlir-print-ir-after-all.")
77       .def(
78           "enable_verifier",
79           [](PyPassManager &passManager, bool enable) {
80             mlirPassManagerEnableVerifier(passManager.get(), enable);
81           },
82           py::arg("enable"), "Enable / disable verify-each.")
83       .def_static(
84           "parse",
85           [](const std::string &pipeline, DefaultingPyMlirContext context) {
86             MlirPassManager passManager = mlirPassManagerCreate(context->get());
87             PyPrintAccumulator errorMsg;
88             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
89                 mlirPassManagerGetAsOpPassManager(passManager),
90                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
91                 errorMsg.getCallback(), errorMsg.getUserData());
92             if (mlirLogicalResultIsFailure(status))
93               throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
94             return new PyPassManager(passManager);
95           },
96           py::arg("pipeline"), py::arg("context") = py::none(),
97           "Parse a textual pass-pipeline and return a top-level PassManager "
98           "that can be applied on a Module. Throw a ValueError if the pipeline "
99           "can't be parsed")
100       .def(
101           "run",
102           [](PyPassManager &passManager, PyModule &module) {
103             MlirLogicalResult status =
104                 mlirPassManagerRun(passManager.get(), module.get());
105             if (mlirLogicalResultIsFailure(status))
106               throw SetPyError(PyExc_RuntimeError,
107                                "Failure while executing pass pipeline.");
108           },
109           py::arg("module"),
110           "Run the pass manager on the provided module, throw a RuntimeError "
111           "on failure.")
112       .def(
113           "__str__",
114           [](PyPassManager &self) {
115             MlirPassManager passManager = self.get();
116             PyPrintAccumulator printAccum;
117             mlirPrintPassPipeline(
118                 mlirPassManagerGetAsOpPassManager(passManager),
119                 printAccum.getCallback(), printAccum.getUserData());
120             return printAccum.join();
121           },
122           "Print the textual representation for this PassManager, suitable to "
123           "be passed to `parse` for round-tripping.");
124 }
125