xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.

#include <utility>
15 
16 namespace nb = nanobind;
17 using namespace nb::literals;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 namespace {
22 
23 /// Owning Wrapper around a PassManager.
24 class PyPassManager {
25 public:
26   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
27   PyPassManager(PyPassManager &&other) noexcept
28       : passManager(other.passManager) {
29     other.passManager.ptr = nullptr;
30   }
31   ~PyPassManager() {
32     if (!mlirPassManagerIsNull(passManager))
33       mlirPassManagerDestroy(passManager);
34   }
35   MlirPassManager get() { return passManager; }
36 
37   void release() { passManager.ptr = nullptr; }
38   nb::object getCapsule() {
39     return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
40   }
41 
42   static nb::object createFromCapsule(nb::object capsule) {
43     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44     if (mlirPassManagerIsNull(rawPm))
45       throw nb::python_error();
46     return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
47   }
48 
49 private:
50   MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
57   //----------------------------------------------------------------------------
58   // Mapping of the top-level PassManager
59   //----------------------------------------------------------------------------
60   nb::class_<PyPassManager>(m, "PassManager")
61       .def(
62           "__init__",
63           [](PyPassManager &self, const std::string &anchorOp,
64              DefaultingPyMlirContext context) {
65             MlirPassManager passManager = mlirPassManagerCreateOnOperation(
66                 context->get(),
67                 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
68             new (&self) PyPassManager(passManager);
69           },
70           "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
71           "Create a new PassManager for the current (or provided) Context.")
72       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
73       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
74       .def("_testing_release", &PyPassManager::release,
75            "Releases (leaks) the backing pass manager (testing)")
76       .def(
77           "enable_ir_printing",
78           [](PyPassManager &passManager, bool printBeforeAll,
79              bool printAfterAll, bool printModuleScope, bool printAfterChange,
80              bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
81              bool enableDebugInfo, bool printGenericOpForm,
82              std::optional<std::string> optionalTreePrintingPath) {
83             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
84             if (largeElementsLimit)
85               mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
86                                                          *largeElementsLimit);
87             if (enableDebugInfo)
88               mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
89                                                  /*prettyForm=*/false);
90             if (printGenericOpForm)
91               mlirOpPrintingFlagsPrintGenericOpForm(flags);
92             std::string treePrintingPath = "";
93             if (optionalTreePrintingPath.has_value())
94               treePrintingPath = optionalTreePrintingPath.value();
95             mlirPassManagerEnableIRPrinting(
96                 passManager.get(), printBeforeAll, printAfterAll,
97                 printModuleScope, printAfterChange, printAfterFailure, flags,
98                 mlirStringRefCreate(treePrintingPath.data(),
99                                     treePrintingPath.size()));
100             mlirOpPrintingFlagsDestroy(flags);
101           },
102           "print_before_all"_a = false, "print_after_all"_a = true,
103           "print_module_scope"_a = false, "print_after_change"_a = false,
104           "print_after_failure"_a = false,
105           "large_elements_limit"_a.none() = nb::none(),
106           "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
107           "tree_printing_dir_path"_a.none() = nb::none(),
108           "Enable IR printing, default as mlir-print-ir-after-all.")
109       .def(
110           "enable_verifier",
111           [](PyPassManager &passManager, bool enable) {
112             mlirPassManagerEnableVerifier(passManager.get(), enable);
113           },
114           "enable"_a, "Enable / disable verify-each.")
115       .def_static(
116           "parse",
117           [](const std::string &pipeline, DefaultingPyMlirContext context) {
118             MlirPassManager passManager = mlirPassManagerCreate(context->get());
119             PyPrintAccumulator errorMsg;
120             MlirLogicalResult status = mlirParsePassPipeline(
121                 mlirPassManagerGetAsOpPassManager(passManager),
122                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
123                 errorMsg.getCallback(), errorMsg.getUserData());
124             if (mlirLogicalResultIsFailure(status))
125               throw nb::value_error(errorMsg.join().c_str());
126             return new PyPassManager(passManager);
127           },
128           "pipeline"_a, "context"_a.none() = nb::none(),
129           "Parse a textual pass-pipeline and return a top-level PassManager "
130           "that can be applied on a Module. Throw a ValueError if the pipeline "
131           "can't be parsed")
132       .def(
133           "add",
134           [](PyPassManager &passManager, const std::string &pipeline) {
135             PyPrintAccumulator errorMsg;
136             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
137                 mlirPassManagerGetAsOpPassManager(passManager.get()),
138                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
139                 errorMsg.getCallback(), errorMsg.getUserData());
140             if (mlirLogicalResultIsFailure(status))
141               throw nb::value_error(errorMsg.join().c_str());
142           },
143           "pipeline"_a,
144           "Add textual pipeline elements to the pass manager. Throws a "
145           "ValueError if the pipeline can't be parsed.")
146       .def(
147           "run",
148           [](PyPassManager &passManager, PyOperationBase &op,
149              bool invalidateOps) {
150             if (invalidateOps) {
151               op.getOperation().getContext()->clearOperationsInside(op);
152             }
153             // Actually run the pass manager.
154             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
155             MlirLogicalResult status = mlirPassManagerRunOnOp(
156                 passManager.get(), op.getOperation().get());
157             if (mlirLogicalResultIsFailure(status))
158               throw MLIRError("Failure while executing pass pipeline",
159                               errors.take());
160           },
161           "operation"_a, "invalidate_ops"_a = true,
162           "Run the pass manager on the provided operation, raising an "
163           "MLIRError on failure.")
164       .def(
165           "__str__",
166           [](PyPassManager &self) {
167             MlirPassManager passManager = self.get();
168             PyPrintAccumulator printAccum;
169             mlirPrintPassPipeline(
170                 mlirPassManagerGetAsOpPassManager(passManager),
171                 printAccum.getCallback(), printAccum.getUserData());
172             return printAccum.join();
173           },
174           "Print the textual representation for this PassManager, suitable to "
175           "be passed to `parse` for round-tripping.");
176 }
177