xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision f8eceb45d0bbca092164efffc92f2e9d66b304a5)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26   PyPassManager(PyPassManager &&other) noexcept
27       : passManager(other.passManager) {
28     other.passManager.ptr = nullptr;
29   }
30   ~PyPassManager() {
31     if (!mlirPassManagerIsNull(passManager))
32       mlirPassManagerDestroy(passManager);
33   }
34   MlirPassManager get() { return passManager; }
35 
36   void release() { passManager.ptr = nullptr; }
37   pybind11::object getCapsule() {
38     return py::reinterpret_steal<py::object>(
39         mlirPythonPassManagerToCapsule(get()));
40   }
41 
42   static pybind11::object createFromCapsule(pybind11::object capsule) {
43     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44     if (mlirPassManagerIsNull(rawPm))
45       throw py::error_already_set();
46     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47   }
48 
49 private:
50   MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(py::module &m) {
57   //----------------------------------------------------------------------------
58   // Mapping of the top-level PassManager
59   //----------------------------------------------------------------------------
60   py::class_<PyPassManager>(m, "PassManager", py::module_local())
61       .def(py::init<>([](const std::string &anchorOp,
62                          DefaultingPyMlirContext context) {
63              MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64                  context->get(),
65                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66              return new PyPassManager(passManager);
67            }),
68            "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69            "Create a new PassManager for the current (or provided) Context.")
70       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71                              &PyPassManager::getCapsule)
72       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73       .def("_testing_release", &PyPassManager::release,
74            "Releases (leaks) the backing pass manager (testing)")
75       .def(
76           "enable_ir_printing",
77           [](PyPassManager &passManager, bool printBeforeAll,
78              bool printAfterAll, bool printModuleScope, bool printAfterChange,
79              bool printAfterFailure) {
80             mlirPassManagerEnableIRPrinting(
81                 passManager.get(), printBeforeAll, printAfterAll,
82                 printModuleScope, printAfterChange, printAfterFailure);
83           },
84           "print_before_all"_a = false, "print_after_all"_a = true,
85           "print_module_scope"_a = false, "print_after_change"_a = false,
86           "print_after_failure"_a = false,
87           "Enable IR printing, default as mlir-print-ir-after-all.")
88       .def(
89           "enable_verifier",
90           [](PyPassManager &passManager, bool enable) {
91             mlirPassManagerEnableVerifier(passManager.get(), enable);
92           },
93           "enable"_a, "Enable / disable verify-each.")
94       .def_static(
95           "parse",
96           [](const std::string &pipeline, DefaultingPyMlirContext context) {
97             MlirPassManager passManager = mlirPassManagerCreate(context->get());
98             PyPrintAccumulator errorMsg;
99             MlirLogicalResult status = mlirParsePassPipeline(
100                 mlirPassManagerGetAsOpPassManager(passManager),
101                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
102                 errorMsg.getCallback(), errorMsg.getUserData());
103             if (mlirLogicalResultIsFailure(status))
104               throw py::value_error(std::string(errorMsg.join()));
105             return new PyPassManager(passManager);
106           },
107           "pipeline"_a, "context"_a = py::none(),
108           "Parse a textual pass-pipeline and return a top-level PassManager "
109           "that can be applied on a Module. Throw a ValueError if the pipeline "
110           "can't be parsed")
111       .def(
112           "add",
113           [](PyPassManager &passManager, const std::string &pipeline) {
114             PyPrintAccumulator errorMsg;
115             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
116                 mlirPassManagerGetAsOpPassManager(passManager.get()),
117                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
118                 errorMsg.getCallback(), errorMsg.getUserData());
119             if (mlirLogicalResultIsFailure(status))
120               throw py::value_error(std::string(errorMsg.join()));
121           },
122           "pipeline"_a,
123           "Add textual pipeline elements to the pass manager. Throws a "
124           "ValueError if the pipeline can't be parsed.")
125       .def(
126           "run",
127           [](PyPassManager &passManager, PyOperationBase &op,
128              bool invalidateOps) {
129             if (invalidateOps) {
130               op.getOperation().getContext()->clearOperationsInside(op);
131             }
132             // Actually run the pass manager.
133             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
134             MlirLogicalResult status = mlirPassManagerRunOnOp(
135                 passManager.get(), op.getOperation().get());
136             if (mlirLogicalResultIsFailure(status))
137               throw MLIRError("Failure while executing pass pipeline",
138                               errors.take());
139           },
140           "operation"_a, "invalidate_ops"_a = true,
141           "Run the pass manager on the provided operation, raising an "
142           "MLIRError on failure.")
143       .def(
144           "__str__",
145           [](PyPassManager &self) {
146             MlirPassManager passManager = self.get();
147             PyPrintAccumulator printAccum;
148             mlirPrintPassPipeline(
149                 mlirPassManagerGetAsOpPassManager(passManager),
150                 printAccum.getCallback(), printAccum.getUserData());
151             return printAccum.join();
152           },
153           "Print the textual representation for this PassManager, suitable to "
154           "be passed to `parse` for round-tripping.");
155 }
156