1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Pass.h" 14 15 namespace py = pybind11; 16 using namespace mlir; 17 using namespace mlir::python; 18 19 namespace { 20 21 /// Owning Wrapper around a PassManager. 22 class PyPassManager { 23 public: 24 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 25 PyPassManager(PyPassManager &&other) : passManager(other.passManager) { 26 other.passManager.ptr = nullptr; 27 } 28 ~PyPassManager() { 29 if (!mlirPassManagerIsNull(passManager)) 30 mlirPassManagerDestroy(passManager); 31 } 32 MlirPassManager get() { return passManager; } 33 34 void release() { passManager.ptr = nullptr; } 35 pybind11::object getCapsule() { 36 return py::reinterpret_steal<py::object>( 37 mlirPythonPassManagerToCapsule(get())); 38 } 39 40 static pybind11::object createFromCapsule(pybind11::object capsule) { 41 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 42 if (mlirPassManagerIsNull(rawPm)) 43 throw py::error_already_set(); 44 return py::cast(PyPassManager(rawPm), py::return_value_policy::move); 45 } 46 47 private: 48 MlirPassManager passManager; 49 }; 50 51 } // namespace 52 53 /// Create the `mlir.passmanager` here. 54 void mlir::python::populatePassManagerSubmodule(py::module &m) { 55 //---------------------------------------------------------------------------- 56 // Mapping of the top-level PassManager 57 //---------------------------------------------------------------------------- 58 py::class_<PyPassManager>(m, "PassManager", py::module_local()) 59 .def(py::init<>([](const std::string &anchorOp, 60 DefaultingPyMlirContext context) { 61 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 62 context->get(), 63 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 64 return new PyPassManager(passManager); 65 }), 66 py::arg("anchor_op") = py::str("any"), 67 py::arg("context") = py::none(), 68 "Create a new PassManager for the current (or provided) Context.") 69 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 70 &PyPassManager::getCapsule) 71 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 72 .def("_testing_release", &PyPassManager::release, 73 "Releases (leaks) the backing pass manager (testing)") 74 .def( 75 "enable_ir_printing", 76 [](PyPassManager &passManager) { 77 mlirPassManagerEnableIRPrinting(passManager.get()); 78 }, 79 "Enable mlir-print-ir-after-all.") 80 .def( 81 "enable_verifier", 82 [](PyPassManager &passManager, bool enable) { 83 mlirPassManagerEnableVerifier(passManager.get(), enable); 84 }, 85 py::arg("enable"), "Enable / disable verify-each.") 86 .def_static( 87 "parse", 88 [](const std::string &pipeline, DefaultingPyMlirContext context) { 89 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 90 PyPrintAccumulator errorMsg; 91 MlirLogicalResult status = mlirParsePassPipeline( 92 mlirPassManagerGetAsOpPassManager(passManager), 93 mlirStringRefCreate(pipeline.data(), pipeline.size()), 94 errorMsg.getCallback(), errorMsg.getUserData()); 95 if (mlirLogicalResultIsFailure(status)) 96 throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); 97 return new PyPassManager(passManager); 98 }, 99 py::arg("pipeline"), py::arg("context") = py::none(), 100 "Parse a textual pass-pipeline and return a top-level PassManager " 101 "that can be applied on a Module. Throw a ValueError if the pipeline " 102 "can't be parsed") 103 .def( 104 "run", 105 [](PyPassManager &passManager, PyModule &module) { 106 MlirLogicalResult status = 107 mlirPassManagerRun(passManager.get(), module.get()); 108 if (mlirLogicalResultIsFailure(status)) 109 throw SetPyError(PyExc_RuntimeError, 110 "Failure while executing pass pipeline."); 111 }, 112 py::arg("module"), 113 "Run the pass manager on the provided module, throw a RuntimeError " 114 "on failure.") 115 .def( 116 "__str__", 117 [](PyPassManager &self) { 118 MlirPassManager passManager = self.get(); 119 PyPrintAccumulator printAccum; 120 mlirPrintPassPipeline( 121 mlirPassManagerGetAsOpPassManager(passManager), 122 printAccum.getCallback(), printAccum.getUserData()); 123 return printAccum.join(); 124 }, 125 "Print the textual representation for this PassManager, suitable to " 126 "be passed to `parse` for round-tripping."); 127 } 128