xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26   PyPassManager(PyPassManager &&other) noexcept
27       : passManager(other.passManager) {
28     other.passManager.ptr = nullptr;
29   }
30   ~PyPassManager() {
31     if (!mlirPassManagerIsNull(passManager))
32       mlirPassManagerDestroy(passManager);
33   }
34   MlirPassManager get() { return passManager; }
35 
36   void release() { passManager.ptr = nullptr; }
37   pybind11::object getCapsule() {
38     return py::reinterpret_steal<py::object>(
39         mlirPythonPassManagerToCapsule(get()));
40   }
41 
42   static pybind11::object createFromCapsule(pybind11::object capsule) {
43     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44     if (mlirPassManagerIsNull(rawPm))
45       throw py::error_already_set();
46     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47   }
48 
49 private:
50   MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(py::module &m) {
57   //----------------------------------------------------------------------------
58   // Mapping of the top-level PassManager
59   //----------------------------------------------------------------------------
60   py::class_<PyPassManager>(m, "PassManager", py::module_local())
61       .def(py::init<>([](const std::string &anchorOp,
62                          DefaultingPyMlirContext context) {
63              MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64                  context->get(),
65                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66              return new PyPassManager(passManager);
67            }),
68            "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69            "Create a new PassManager for the current (or provided) Context.")
70       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71                              &PyPassManager::getCapsule)
72       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73       .def("_testing_release", &PyPassManager::release,
74            "Releases (leaks) the backing pass manager (testing)")
75       .def(
76           "enable_ir_printing",
77           [](PyPassManager &passManager, bool printBeforeAll,
78              bool printAfterAll, bool printModuleScope, bool printAfterChange,
79              bool printAfterFailure,
80              std::optional<std::string> optionalTreePrintingPath) {
81             std::string treePrintingPath = "";
82             if (optionalTreePrintingPath.has_value())
83               treePrintingPath = optionalTreePrintingPath.value();
84             mlirPassManagerEnableIRPrinting(
85                 passManager.get(), printBeforeAll, printAfterAll,
86                 printModuleScope, printAfterChange, printAfterFailure,
87                 mlirStringRefCreate(treePrintingPath.data(),
88                                     treePrintingPath.size()));
89           },
90           "print_before_all"_a = false, "print_after_all"_a = true,
91           "print_module_scope"_a = false, "print_after_change"_a = false,
92           "print_after_failure"_a = false,
93           "tree_printing_dir_path"_a = py::none(),
94           "Enable IR printing, default as mlir-print-ir-after-all.")
95       .def(
96           "enable_verifier",
97           [](PyPassManager &passManager, bool enable) {
98             mlirPassManagerEnableVerifier(passManager.get(), enable);
99           },
100           "enable"_a, "Enable / disable verify-each.")
101       .def_static(
102           "parse",
103           [](const std::string &pipeline, DefaultingPyMlirContext context) {
104             MlirPassManager passManager = mlirPassManagerCreate(context->get());
105             PyPrintAccumulator errorMsg;
106             MlirLogicalResult status = mlirParsePassPipeline(
107                 mlirPassManagerGetAsOpPassManager(passManager),
108                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
109                 errorMsg.getCallback(), errorMsg.getUserData());
110             if (mlirLogicalResultIsFailure(status))
111               throw py::value_error(std::string(errorMsg.join()));
112             return new PyPassManager(passManager);
113           },
114           "pipeline"_a, "context"_a = py::none(),
115           "Parse a textual pass-pipeline and return a top-level PassManager "
116           "that can be applied on a Module. Throw a ValueError if the pipeline "
117           "can't be parsed")
118       .def(
119           "add",
120           [](PyPassManager &passManager, const std::string &pipeline) {
121             PyPrintAccumulator errorMsg;
122             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
123                 mlirPassManagerGetAsOpPassManager(passManager.get()),
124                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
125                 errorMsg.getCallback(), errorMsg.getUserData());
126             if (mlirLogicalResultIsFailure(status))
127               throw py::value_error(std::string(errorMsg.join()));
128           },
129           "pipeline"_a,
130           "Add textual pipeline elements to the pass manager. Throws a "
131           "ValueError if the pipeline can't be parsed.")
132       .def(
133           "run",
134           [](PyPassManager &passManager, PyOperationBase &op,
135              bool invalidateOps) {
136             if (invalidateOps) {
137               op.getOperation().getContext()->clearOperationsInside(op);
138             }
139             // Actually run the pass manager.
140             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
141             MlirLogicalResult status = mlirPassManagerRunOnOp(
142                 passManager.get(), op.getOperation().get());
143             if (mlirLogicalResultIsFailure(status))
144               throw MLIRError("Failure while executing pass pipeline",
145                               errors.take());
146           },
147           "operation"_a, "invalidate_ops"_a = true,
148           "Run the pass manager on the provided operation, raising an "
149           "MLIRError on failure.")
150       .def(
151           "__str__",
152           [](PyPassManager &self) {
153             MlirPassManager passManager = self.get();
154             PyPrintAccumulator printAccum;
155             mlirPrintPassPipeline(
156                 mlirPassManagerGetAsOpPassManager(passManager),
157                 printAccum.getCallback(), printAccum.getUserData());
158             return printAccum.join();
159           },
160           "Print the textual representation for this PassManager, suitable to "
161           "be passed to `parse` for round-tripping.");
162 }
163