xref: /llvm-project/mlir/lib/Bindings/Python/Pass.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1dc43f785SMehdi Amini //===- Pass.cpp - Pass Management -----------------------------------------===//
2dc43f785SMehdi Amini //
3dc43f785SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dc43f785SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5dc43f785SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dc43f785SMehdi Amini //
7dc43f785SMehdi Amini //===----------------------------------------------------------------------===//
8dc43f785SMehdi Amini 
9dc43f785SMehdi Amini #include "Pass.h"
10dc43f785SMehdi Amini 
11436c6c9cSStella Laurenzo #include "IRModule.h"
12dc43f785SMehdi Amini #include "mlir-c/Pass.h"
13*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
14*5cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15dc43f785SMehdi Amini 
16b56d1ec6SPeter Hawkins namespace nb = nanobind;
17b56d1ec6SPeter Hawkins using namespace nb::literals;
18dc43f785SMehdi Amini using namespace mlir;
19dc43f785SMehdi Amini using namespace mlir::python;
20dc43f785SMehdi Amini 
21dc43f785SMehdi Amini namespace {
22dc43f785SMehdi Amini 
23dc43f785SMehdi Amini /// Owning Wrapper around a PassManager.
24dc43f785SMehdi Amini class PyPassManager {
25dc43f785SMehdi Amini public:
26dc43f785SMehdi Amini   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
27ea2e83afSAdrian Kuegel   PyPassManager(PyPassManager &&other) noexcept
28ea2e83afSAdrian Kuegel       : passManager(other.passManager) {
295fef6ce0SStella Laurenzo     other.passManager.ptr = nullptr;
305fef6ce0SStella Laurenzo   }
315fef6ce0SStella Laurenzo   ~PyPassManager() {
325fef6ce0SStella Laurenzo     if (!mlirPassManagerIsNull(passManager))
335fef6ce0SStella Laurenzo       mlirPassManagerDestroy(passManager);
345fef6ce0SStella Laurenzo   }
35dc43f785SMehdi Amini   MlirPassManager get() { return passManager; }
36dc43f785SMehdi Amini 
375fef6ce0SStella Laurenzo   void release() { passManager.ptr = nullptr; }
38b56d1ec6SPeter Hawkins   nb::object getCapsule() {
39b56d1ec6SPeter Hawkins     return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
405fef6ce0SStella Laurenzo   }
415fef6ce0SStella Laurenzo 
42b56d1ec6SPeter Hawkins   static nb::object createFromCapsule(nb::object capsule) {
435fef6ce0SStella Laurenzo     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
445fef6ce0SStella Laurenzo     if (mlirPassManagerIsNull(rawPm))
45b56d1ec6SPeter Hawkins       throw nb::python_error();
46b56d1ec6SPeter Hawkins     return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
475fef6ce0SStella Laurenzo   }
485fef6ce0SStella Laurenzo 
49dc43f785SMehdi Amini private:
50dc43f785SMehdi Amini   MlirPassManager passManager;
51dc43f785SMehdi Amini };
52dc43f785SMehdi Amini 
53be0a7e9fSMehdi Amini } // namespace
54dc43f785SMehdi Amini 
55dc43f785SMehdi Amini /// Create the `mlir.passmanager` here.
56b56d1ec6SPeter Hawkins void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
57dc43f785SMehdi Amini   //----------------------------------------------------------------------------
58dc43f785SMehdi Amini   // Mapping of the top-level PassManager
59dc43f785SMehdi Amini   //----------------------------------------------------------------------------
60b56d1ec6SPeter Hawkins   nb::class_<PyPassManager>(m, "PassManager")
61b56d1ec6SPeter Hawkins       .def(
62b56d1ec6SPeter Hawkins           "__init__",
63b56d1ec6SPeter Hawkins           [](PyPassManager &self, const std::string &anchorOp,
64d97e8cd4Srkayaith              DefaultingPyMlirContext context) {
65d97e8cd4Srkayaith             MlirPassManager passManager = mlirPassManagerCreateOnOperation(
66d97e8cd4Srkayaith                 context->get(),
67d97e8cd4Srkayaith                 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
68b56d1ec6SPeter Hawkins             new (&self) PyPassManager(passManager);
69b56d1ec6SPeter Hawkins           },
70b56d1ec6SPeter Hawkins           "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
71dc43f785SMehdi Amini           "Create a new PassManager for the current (or provided) Context.")
72b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
735fef6ce0SStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
745fef6ce0SStella Laurenzo       .def("_testing_release", &PyPassManager::release,
755fef6ce0SStella Laurenzo            "Releases (leaks) the backing pass manager (testing)")
76caa159f0SNicolas Vasilache       .def(
77caa159f0SNicolas Vasilache           "enable_ir_printing",
78f8eceb45SBimo           [](PyPassManager &passManager, bool printBeforeAll,
79f8eceb45SBimo              bool printAfterAll, bool printModuleScope, bool printAfterChange,
802e51e150SYuanqiang Liu              bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
812e51e150SYuanqiang Liu              bool enableDebugInfo, bool printGenericOpForm,
82c8b837adSMehdi Amini              std::optional<std::string> optionalTreePrintingPath) {
832e51e150SYuanqiang Liu             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
842e51e150SYuanqiang Liu             if (largeElementsLimit)
852e51e150SYuanqiang Liu               mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
862e51e150SYuanqiang Liu                                                          *largeElementsLimit);
872e51e150SYuanqiang Liu             if (enableDebugInfo)
882e51e150SYuanqiang Liu               mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
892e51e150SYuanqiang Liu                                                  /*prettyForm=*/false);
902e51e150SYuanqiang Liu             if (printGenericOpForm)
912e51e150SYuanqiang Liu               mlirOpPrintingFlagsPrintGenericOpForm(flags);
92c8b837adSMehdi Amini             std::string treePrintingPath = "";
93c8b837adSMehdi Amini             if (optionalTreePrintingPath.has_value())
94c8b837adSMehdi Amini               treePrintingPath = optionalTreePrintingPath.value();
95f8eceb45SBimo             mlirPassManagerEnableIRPrinting(
96f8eceb45SBimo                 passManager.get(), printBeforeAll, printAfterAll,
972e51e150SYuanqiang Liu                 printModuleScope, printAfterChange, printAfterFailure, flags,
98c8b837adSMehdi Amini                 mlirStringRefCreate(treePrintingPath.data(),
99c8b837adSMehdi Amini                                     treePrintingPath.size()));
1002e51e150SYuanqiang Liu             mlirOpPrintingFlagsDestroy(flags);
101caa159f0SNicolas Vasilache           },
102f8eceb45SBimo           "print_before_all"_a = false, "print_after_all"_a = true,
103f8eceb45SBimo           "print_module_scope"_a = false, "print_after_change"_a = false,
104f8eceb45SBimo           "print_after_failure"_a = false,
105b56d1ec6SPeter Hawkins           "large_elements_limit"_a.none() = nb::none(),
106b56d1ec6SPeter Hawkins           "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
107b56d1ec6SPeter Hawkins           "tree_printing_dir_path"_a.none() = nb::none(),
108f8eceb45SBimo           "Enable IR printing, default as mlir-print-ir-after-all.")
109caa159f0SNicolas Vasilache       .def(
110caa159f0SNicolas Vasilache           "enable_verifier",
111caa159f0SNicolas Vasilache           [](PyPassManager &passManager, bool enable) {
112caa159f0SNicolas Vasilache             mlirPassManagerEnableVerifier(passManager.get(), enable);
113caa159f0SNicolas Vasilache           },
114bdc3e6cbSMaksim Levental           "enable"_a, "Enable / disable verify-each.")
115dc43f785SMehdi Amini       .def_static(
116dc43f785SMehdi Amini           "parse",
117b3c5f6b1Srkayaith           [](const std::string &pipeline, DefaultingPyMlirContext context) {
118dc43f785SMehdi Amini             MlirPassManager passManager = mlirPassManagerCreate(context->get());
119b3c5f6b1Srkayaith             PyPrintAccumulator errorMsg;
12066645a03Srkayaith             MlirLogicalResult status = mlirParsePassPipeline(
121dc43f785SMehdi Amini                 mlirPassManagerGetAsOpPassManager(passManager),
122b3c5f6b1Srkayaith                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
123b3c5f6b1Srkayaith                 errorMsg.getCallback(), errorMsg.getUserData());
124dc43f785SMehdi Amini             if (mlirLogicalResultIsFailure(status))
125b56d1ec6SPeter Hawkins               throw nb::value_error(errorMsg.join().c_str());
126dc43f785SMehdi Amini             return new PyPassManager(passManager);
127dc43f785SMehdi Amini           },
128b56d1ec6SPeter Hawkins           "pipeline"_a, "context"_a.none() = nb::none(),
129dc43f785SMehdi Amini           "Parse a textual pass-pipeline and return a top-level PassManager "
130dc43f785SMehdi Amini           "that can be applied on a Module. Throw a ValueError if the pipeline "
131dc43f785SMehdi Amini           "can't be parsed")
132dc43f785SMehdi Amini       .def(
133dd1b1d44Srkayaith           "add",
134dd1b1d44Srkayaith           [](PyPassManager &passManager, const std::string &pipeline) {
135dd1b1d44Srkayaith             PyPrintAccumulator errorMsg;
136dd1b1d44Srkayaith             MlirLogicalResult status = mlirOpPassManagerAddPipeline(
137dd1b1d44Srkayaith                 mlirPassManagerGetAsOpPassManager(passManager.get()),
138dd1b1d44Srkayaith                 mlirStringRefCreate(pipeline.data(), pipeline.size()),
139dd1b1d44Srkayaith                 errorMsg.getCallback(), errorMsg.getUserData());
140dd1b1d44Srkayaith             if (mlirLogicalResultIsFailure(status))
141b56d1ec6SPeter Hawkins               throw nb::value_error(errorMsg.join().c_str());
142dd1b1d44Srkayaith           },
143bdc3e6cbSMaksim Levental           "pipeline"_a,
144dd1b1d44Srkayaith           "Add textual pipeline elements to the pass manager. Throws a "
145dd1b1d44Srkayaith           "ValueError if the pipeline can't be parsed.")
146dd1b1d44Srkayaith       .def(
1476cb1c0caSMehdi Amini           "run",
148bdc3e6cbSMaksim Levental           [](PyPassManager &passManager, PyOperationBase &op,
149bdc3e6cbSMaksim Levental              bool invalidateOps) {
150bdc3e6cbSMaksim Levental             if (invalidateOps) {
151fa19ef7aSIngo Müller               op.getOperation().getContext()->clearOperationsInside(op);
152bdc3e6cbSMaksim Levental             }
153bdc3e6cbSMaksim Levental             // Actually run the pass manager.
1543ea4c501SRahul Kayaith             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
1556f5590caSrkayaith             MlirLogicalResult status = mlirPassManagerRunOnOp(
156c00f81ccSrkayaith                 passManager.get(), op.getOperation().get());
1576cb1c0caSMehdi Amini             if (mlirLogicalResultIsFailure(status))
1583ea4c501SRahul Kayaith               throw MLIRError("Failure while executing pass pipeline",
1593ea4c501SRahul Kayaith                               errors.take());
1606cb1c0caSMehdi Amini           },
161bdc3e6cbSMaksim Levental           "operation"_a, "invalidate_ops"_a = true,
1623ea4c501SRahul Kayaith           "Run the pass manager on the provided operation, raising an "
1633ea4c501SRahul Kayaith           "MLIRError on failure.")
1646cb1c0caSMehdi Amini       .def(
165dc43f785SMehdi Amini           "__str__",
166dc43f785SMehdi Amini           [](PyPassManager &self) {
167dc43f785SMehdi Amini             MlirPassManager passManager = self.get();
168dc43f785SMehdi Amini             PyPrintAccumulator printAccum;
169dc43f785SMehdi Amini             mlirPrintPassPipeline(
170dc43f785SMehdi Amini                 mlirPassManagerGetAsOpPassManager(passManager),
171dc43f785SMehdi Amini                 printAccum.getCallback(), printAccum.getUserData());
172dc43f785SMehdi Amini             return printAccum.join();
173dc43f785SMehdi Amini           },
174dc43f785SMehdi Amini           "Print the textual representation for this PassManager, suitable to "
175dc43f785SMehdi Amini           "be passed to `parse` for round-tripping.");
176dc43f785SMehdi Amini }
177