xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision 41bd35b58bb482fd466aa4b13aa44a810ad6470f)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include <nanobind/nanobind.h>
12 #include <nanobind/stl/optional.h>
13 #include <nanobind/stl/string.h>
14 
15 #include "IRModule.h"
16 #include "mlir-c/Bindings/Python/Interop.h"
17 #include "mlir-c/Pass.h"
18 
19 namespace nb = nanobind;
20 using namespace nb::literals;
21 using namespace mlir;
22 using namespace mlir::python;
23 
24 namespace {
25 
26 /// Owning Wrapper around a PassManager.
27 class PyPassManager {
28 public:
29   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
30   PyPassManager(PyPassManager &&other) noexcept
31       : passManager(other.passManager) {
32     other.passManager.ptr = nullptr;
33   }
34   ~PyPassManager() {
35     if (!mlirPassManagerIsNull(passManager))
36       mlirPassManagerDestroy(passManager);
37   }
38   MlirPassManager get() { return passManager; }
39 
40   void release() { passManager.ptr = nullptr; }
41   nb::object getCapsule() {
42     return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
43   }
44 
45   static nb::object createFromCapsule(nb::object capsule) {
46     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
47     if (mlirPassManagerIsNull(rawPm))
48       throw nb::python_error();
49     return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
50   }
51 
52 private:
53   MlirPassManager passManager;
54 };
55 
56 } // namespace
57 
58 /// Create the `mlir.passmanager` here.
59 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
60   //----------------------------------------------------------------------------
61   // Mapping of the top-level PassManager
62   //----------------------------------------------------------------------------
63   nb::class_<PyPassManager>(m, "PassManager")
64       .def(
65           "__init__",
66           [](PyPassManager &self, const std::string &anchorOp,
67              DefaultingPyMlirContext context) {
68             MlirPassManager passManager = mlirPassManagerCreateOnOperation(
69                 context->get(),
70                 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
71             new (&self) PyPassManager(passManager);
72           },
73           "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
74           "Create a new PassManager for the current (or provided) Context.")
75       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
76       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
77       .def("_testing_release", &PyPassManager::release,
78            "Releases (leaks) the backing pass manager (testing)")
79       .def(
80           "enable_ir_printing",
81           [](PyPassManager &passManager, bool printBeforeAll,
82              bool printAfterAll, bool printModuleScope, bool printAfterChange,
83              bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
84              bool enableDebugInfo, bool printGenericOpForm,
85              std::optional<std::string> optionalTreePrintingPath) {
86             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
87             if (largeElementsLimit)
88               mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
89                                                          *largeElementsLimit);
90             if (enableDebugInfo)
91               mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
92                                                  /*prettyForm=*/false);
93             if (printGenericOpForm)
94               mlirOpPrintingFlagsPrintGenericOpForm(flags);
95             std::string treePrintingPath = "";
96             if (optionalTreePrintingPath.has_value())
97               treePrintingPath = optionalTreePrintingPath.value();
98             mlirPassManagerEnableIRPrinting(
99                 passManager.get(), printBeforeAll, printAfterAll,
100                 printModuleScope, printAfterChange, printAfterFailure, flags,
101                 mlirStringRefCreate(treePrintingPath.data(),
102                                     treePrintingPath.size()));
103             mlirOpPrintingFlagsDestroy(flags);
104           },
105           "print_before_all"_a = false, "print_after_all"_a = true,
106           "print_module_scope"_a = false, "print_after_change"_a = false,
107           "print_after_failure"_a = false,
108           "large_elements_limit"_a.none() = nb::none(),
109           "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
110           "tree_printing_dir_path"_a.none() = nb::none(),
111           "Enable IR printing, default as mlir-print-ir-after-all.")
112       .def(
113           "enable_verifier",
114           [](PyPassManager &passManager, bool enable) {
115             mlirPassManagerEnableVerifier(passManager.get(), enable);
116           },
117           "enable"_a, "Enable / disable verify-each.")
118       .def_static(
119           "parse",
120           [](const std::string &pipeline, DefaultingPyMlirContext context) {
121             MlirPassManager passManager = mlirPassManagerCreate(context->get());
122             PyPrintAccumulator errorMsg;
123             MlirLogicalResult status = mlirParsePassPipeline(
124                 mlirPassManagerGetAsOpPassManager(passManager),
125                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
126                 errorMsg.getCallback(), errorMsg.getUserData());
127             if (mlirLogicalResultIsFailure(status))
128               throw nb::value_error(errorMsg.join().c_str());
129             return new PyPassManager(passManager);
130           },
131           "pipeline"_a, "context"_a.none() = nb::none(),
132           "Parse a textual pass-pipeline and return a top-level PassManager "
133           "that can be applied on a Module. Throw a ValueError if the pipeline "
134           "can't be parsed")
135       .def(
136           "add",
137           [](PyPassManager &passManager, const std::string &pipeline) {
138             PyPrintAccumulator errorMsg;
139             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
140                 mlirPassManagerGetAsOpPassManager(passManager.get()),
141                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
142                 errorMsg.getCallback(), errorMsg.getUserData());
143             if (mlirLogicalResultIsFailure(status))
144               throw nb::value_error(errorMsg.join().c_str());
145           },
146           "pipeline"_a,
147           "Add textual pipeline elements to the pass manager. Throws a "
148           "ValueError if the pipeline can't be parsed.")
149       .def(
150           "run",
151           [](PyPassManager &passManager, PyOperationBase &op,
152              bool invalidateOps) {
153             if (invalidateOps) {
154               op.getOperation().getContext()->clearOperationsInside(op);
155             }
156             // Actually run the pass manager.
157             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
158             MlirLogicalResult status = mlirPassManagerRunOnOp(
159                 passManager.get(), op.getOperation().get());
160             if (mlirLogicalResultIsFailure(status))
161               throw MLIRError("Failure while executing pass pipeline",
162                               errors.take());
163           },
164           "operation"_a, "invalidate_ops"_a = true,
165           "Run the pass manager on the provided operation, raising an "
166           "MLIRError on failure.")
167       .def(
168           "__str__",
169           [](PyPassManager &self) {
170             MlirPassManager passManager = self.get();
171             PyPrintAccumulator printAccum;
172             mlirPrintPassPipeline(
173                 mlirPassManagerGetAsOpPassManager(passManager),
174                 printAccum.getCallback(), printAccum.getUserData());
175             return printAccum.join();
176           },
177           "Print the textual representation for this PassManager, suitable to "
178           "be passed to `parse` for round-tripping.");
179 }
180