xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision 6e8b3a3e0cad614954fc387df22d59d941f081c3)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26   PyPassManager(PyPassManager &&other) noexcept
27       : passManager(other.passManager) {
28     other.passManager.ptr = nullptr;
29   }
30   ~PyPassManager() {
31     if (!mlirPassManagerIsNull(passManager))
32       mlirPassManagerDestroy(passManager);
33   }
34   MlirPassManager get() { return passManager; }
35 
36   void release() { passManager.ptr = nullptr; }
37   pybind11::object getCapsule() {
38     return py::reinterpret_steal<py::object>(
39         mlirPythonPassManagerToCapsule(get()));
40   }
41 
42   static pybind11::object createFromCapsule(pybind11::object capsule) {
43     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44     if (mlirPassManagerIsNull(rawPm))
45       throw py::error_already_set();
46     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47   }
48 
49 private:
50   MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(py::module &m) {
57   //----------------------------------------------------------------------------
58   // Mapping of the top-level PassManager
59   //----------------------------------------------------------------------------
60   py::class_<PyPassManager>(m, "PassManager", py::module_local())
61       .def(py::init<>([](const std::string &anchorOp,
62                          DefaultingPyMlirContext context) {
63              MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64                  context->get(),
65                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66              return new PyPassManager(passManager);
67            }),
68            "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69            "Create a new PassManager for the current (or provided) Context.")
70       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71                              &PyPassManager::getCapsule)
72       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73       .def("_testing_release", &PyPassManager::release,
74            "Releases (leaks) the backing pass manager (testing)")
75       .def(
76           "enable_ir_printing",
77           [](PyPassManager &passManager, bool printBeforeAll,
78              bool printAfterAll, bool printModuleScope, bool printAfterChange,
79              bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
80              bool enableDebugInfo, bool printGenericOpForm,
81              std::optional<std::string> optionalTreePrintingPath) {
82             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
83             if (largeElementsLimit)
84               mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
85                                                          *largeElementsLimit);
86             if (enableDebugInfo)
87               mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
88                                                  /*prettyForm=*/false);
89             if (printGenericOpForm)
90               mlirOpPrintingFlagsPrintGenericOpForm(flags);
91             std::string treePrintingPath = "";
92             if (optionalTreePrintingPath.has_value())
93               treePrintingPath = optionalTreePrintingPath.value();
94             mlirPassManagerEnableIRPrinting(
95                 passManager.get(), printBeforeAll, printAfterAll,
96                 printModuleScope, printAfterChange, printAfterFailure, flags,
97                 mlirStringRefCreate(treePrintingPath.data(),
98                                     treePrintingPath.size()));
99             mlirOpPrintingFlagsDestroy(flags);
100           },
101           "print_before_all"_a = false, "print_after_all"_a = true,
102           "print_module_scope"_a = false, "print_after_change"_a = false,
103           "print_after_failure"_a = false,
104           "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
105           "print_generic_op_form"_a = false,
106           "tree_printing_dir_path"_a = py::none(),
107           "Enable IR printing, default as mlir-print-ir-after-all.")
108       .def(
109           "enable_verifier",
110           [](PyPassManager &passManager, bool enable) {
111             mlirPassManagerEnableVerifier(passManager.get(), enable);
112           },
113           "enable"_a, "Enable / disable verify-each.")
114       .def_static(
115           "parse",
116           [](const std::string &pipeline, DefaultingPyMlirContext context) {
117             MlirPassManager passManager = mlirPassManagerCreate(context->get());
118             PyPrintAccumulator errorMsg;
119             MlirLogicalResult status = mlirParsePassPipeline(
120                 mlirPassManagerGetAsOpPassManager(passManager),
121                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
122                 errorMsg.getCallback(), errorMsg.getUserData());
123             if (mlirLogicalResultIsFailure(status))
124               throw py::value_error(std::string(errorMsg.join()));
125             return new PyPassManager(passManager);
126           },
127           "pipeline"_a, "context"_a = py::none(),
128           "Parse a textual pass-pipeline and return a top-level PassManager "
129           "that can be applied on a Module. Throw a ValueError if the pipeline "
130           "can't be parsed")
131       .def(
132           "add",
133           [](PyPassManager &passManager, const std::string &pipeline) {
134             PyPrintAccumulator errorMsg;
135             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
136                 mlirPassManagerGetAsOpPassManager(passManager.get()),
137                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
138                 errorMsg.getCallback(), errorMsg.getUserData());
139             if (mlirLogicalResultIsFailure(status))
140               throw py::value_error(std::string(errorMsg.join()));
141           },
142           "pipeline"_a,
143           "Add textual pipeline elements to the pass manager. Throws a "
144           "ValueError if the pipeline can't be parsed.")
145       .def(
146           "run",
147           [](PyPassManager &passManager, PyOperationBase &op,
148              bool invalidateOps) {
149             if (invalidateOps) {
150               op.getOperation().getContext()->clearOperationsInside(op);
151             }
152             // Actually run the pass manager.
153             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
154             MlirLogicalResult status = mlirPassManagerRunOnOp(
155                 passManager.get(), op.getOperation().get());
156             if (mlirLogicalResultIsFailure(status))
157               throw MLIRError("Failure while executing pass pipeline",
158                               errors.take());
159           },
160           "operation"_a, "invalidate_ops"_a = true,
161           "Run the pass manager on the provided operation, raising an "
162           "MLIRError on failure.")
163       .def(
164           "__str__",
165           [](PyPassManager &self) {
166             MlirPassManager passManager = self.get();
167             PyPrintAccumulator printAccum;
168             mlirPrintPassPipeline(
169                 mlirPassManagerGetAsOpPassManager(passManager),
170                 printAccum.getCallback(), printAccum.getUserData());
171             return printAccum.join();
172           },
173           "Print the textual representation for this PassManager, suitable to "
174           "be passed to `parse` for round-tripping.");
175 }
176