1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Pass.h" 14 15 namespace py = pybind11; 16 using namespace py::literals; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 namespace { 21 22 /// Owning Wrapper around a PassManager. 23 class PyPassManager { 24 public: 25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 26 PyPassManager(PyPassManager &&other) : passManager(other.passManager) { 27 other.passManager.ptr = nullptr; 28 } 29 ~PyPassManager() { 30 if (!mlirPassManagerIsNull(passManager)) 31 mlirPassManagerDestroy(passManager); 32 } 33 MlirPassManager get() { return passManager; } 34 35 void release() { passManager.ptr = nullptr; } 36 pybind11::object getCapsule() { 37 return py::reinterpret_steal<py::object>( 38 mlirPythonPassManagerToCapsule(get())); 39 } 40 41 static pybind11::object createFromCapsule(pybind11::object capsule) { 42 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 43 if (mlirPassManagerIsNull(rawPm)) 44 throw py::error_already_set(); 45 return py::cast(PyPassManager(rawPm), py::return_value_policy::move); 46 } 47 48 private: 49 MlirPassManager passManager; 50 }; 51 52 } // namespace 53 54 /// Create the `mlir.passmanager` here. 55 void mlir::python::populatePassManagerSubmodule(py::module &m) { 56 //---------------------------------------------------------------------------- 57 // Mapping of the top-level PassManager 58 //---------------------------------------------------------------------------- 59 py::class_<PyPassManager>(m, "PassManager", py::module_local()) 60 .def(py::init<>([](const std::string &anchorOp, 61 DefaultingPyMlirContext context) { 62 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 63 context->get(), 64 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 65 return new PyPassManager(passManager); 66 }), 67 "anchor_op"_a = py::str("any"), "context"_a = py::none(), 68 "Create a new PassManager for the current (or provided) Context.") 69 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 70 &PyPassManager::getCapsule) 71 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 72 .def("_testing_release", &PyPassManager::release, 73 "Releases (leaks) the backing pass manager (testing)") 74 .def( 75 "enable_ir_printing", 76 [](PyPassManager &passManager) { 77 mlirPassManagerEnableIRPrinting(passManager.get()); 78 }, 79 "Enable mlir-print-ir-after-all.") 80 .def( 81 "enable_verifier", 82 [](PyPassManager &passManager, bool enable) { 83 mlirPassManagerEnableVerifier(passManager.get(), enable); 84 }, 85 "enable"_a, "Enable / disable verify-each.") 86 .def_static( 87 "parse", 88 [](const std::string &pipeline, DefaultingPyMlirContext context) { 89 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 90 PyPrintAccumulator errorMsg; 91 MlirLogicalResult status = mlirParsePassPipeline( 92 mlirPassManagerGetAsOpPassManager(passManager), 93 mlirStringRefCreate(pipeline.data(), pipeline.size()), 94 errorMsg.getCallback(), errorMsg.getUserData()); 95 if (mlirLogicalResultIsFailure(status)) 96 throw py::value_error(std::string(errorMsg.join())); 97 return new PyPassManager(passManager); 98 }, 99 "pipeline"_a, "context"_a = py::none(), 100 "Parse a textual pass-pipeline and return a top-level PassManager " 101 "that can be applied on a Module. Throw a ValueError if the pipeline " 102 "can't be parsed") 103 .def( 104 "add", 105 [](PyPassManager &passManager, const std::string &pipeline) { 106 PyPrintAccumulator errorMsg; 107 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 108 mlirPassManagerGetAsOpPassManager(passManager.get()), 109 mlirStringRefCreate(pipeline.data(), pipeline.size()), 110 errorMsg.getCallback(), errorMsg.getUserData()); 111 if (mlirLogicalResultIsFailure(status)) 112 throw py::value_error(std::string(errorMsg.join())); 113 }, 114 "pipeline"_a, 115 "Add textual pipeline elements to the pass manager. Throws a " 116 "ValueError if the pipeline can't be parsed.") 117 .def( 118 "run", 119 [](PyPassManager &passManager, PyOperationBase &op, 120 bool invalidateOps) { 121 if (invalidateOps) { 122 op.getOperation().getContext()->clearOperationsInside(op); 123 } 124 // Actually run the pass manager. 125 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 126 MlirLogicalResult status = mlirPassManagerRunOnOp( 127 passManager.get(), op.getOperation().get()); 128 if (mlirLogicalResultIsFailure(status)) 129 throw MLIRError("Failure while executing pass pipeline", 130 errors.take()); 131 }, 132 "operation"_a, "invalidate_ops"_a = true, 133 "Run the pass manager on the provided operation, raising an " 134 "MLIRError on failure.") 135 .def( 136 "__str__", 137 [](PyPassManager &self) { 138 MlirPassManager passManager = self.get(); 139 PyPrintAccumulator printAccum; 140 mlirPrintPassPipeline( 141 mlirPassManagerGetAsOpPassManager(passManager), 142 printAccum.getCallback(), printAccum.getUserData()); 143 return printAccum.join(); 144 }, 145 "Print the textual representation for this PassManager, suitable to " 146 "be passed to `parse` for round-tripping."); 147 } 148