1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Pass.h" 14 15 namespace py = pybind11; 16 using namespace py::literals; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 namespace { 21 22 /// Owning Wrapper around a PassManager. 23 class PyPassManager { 24 public: 25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 26 PyPassManager(PyPassManager &&other) noexcept 27 : passManager(other.passManager) { 28 other.passManager.ptr = nullptr; 29 } 30 ~PyPassManager() { 31 if (!mlirPassManagerIsNull(passManager)) 32 mlirPassManagerDestroy(passManager); 33 } 34 MlirPassManager get() { return passManager; } 35 36 void release() { passManager.ptr = nullptr; } 37 pybind11::object getCapsule() { 38 return py::reinterpret_steal<py::object>( 39 mlirPythonPassManagerToCapsule(get())); 40 } 41 42 static pybind11::object createFromCapsule(pybind11::object capsule) { 43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 44 if (mlirPassManagerIsNull(rawPm)) 45 throw py::error_already_set(); 46 return py::cast(PyPassManager(rawPm), py::return_value_policy::move); 47 } 48 49 private: 50 MlirPassManager passManager; 51 }; 52 53 } // namespace 54 55 /// Create the `mlir.passmanager` here. 56 void mlir::python::populatePassManagerSubmodule(py::module &m) { 57 //---------------------------------------------------------------------------- 58 // Mapping of the top-level PassManager 59 //---------------------------------------------------------------------------- 60 py::class_<PyPassManager>(m, "PassManager", py::module_local()) 61 .def(py::init<>([](const std::string &anchorOp, 62 DefaultingPyMlirContext context) { 63 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 64 context->get(), 65 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 66 return new PyPassManager(passManager); 67 }), 68 "anchor_op"_a = py::str("any"), "context"_a = py::none(), 69 "Create a new PassManager for the current (or provided) Context.") 70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 71 &PyPassManager::getCapsule) 72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 73 .def("_testing_release", &PyPassManager::release, 74 "Releases (leaks) the backing pass manager (testing)") 75 .def( 76 "enable_ir_printing", 77 [](PyPassManager &passManager, bool printBeforeAll, 78 bool printAfterAll, bool printModuleScope, bool printAfterChange, 79 bool printAfterFailure) { 80 mlirPassManagerEnableIRPrinting( 81 passManager.get(), printBeforeAll, printAfterAll, 82 printModuleScope, printAfterChange, printAfterFailure); 83 }, 84 "print_before_all"_a = false, "print_after_all"_a = true, 85 "print_module_scope"_a = false, "print_after_change"_a = false, 86 "print_after_failure"_a = false, 87 "Enable IR printing, default as mlir-print-ir-after-all.") 88 .def( 89 "enable_verifier", 90 [](PyPassManager &passManager, bool enable) { 91 mlirPassManagerEnableVerifier(passManager.get(), enable); 92 }, 93 "enable"_a, "Enable / disable verify-each.") 94 .def_static( 95 "parse", 96 [](const std::string &pipeline, DefaultingPyMlirContext context) { 97 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 98 PyPrintAccumulator errorMsg; 99 MlirLogicalResult status = mlirParsePassPipeline( 100 mlirPassManagerGetAsOpPassManager(passManager), 101 mlirStringRefCreate(pipeline.data(), pipeline.size()), 102 errorMsg.getCallback(), errorMsg.getUserData()); 103 if (mlirLogicalResultIsFailure(status)) 104 throw py::value_error(std::string(errorMsg.join())); 105 return new PyPassManager(passManager); 106 }, 107 "pipeline"_a, "context"_a = py::none(), 108 "Parse a textual pass-pipeline and return a top-level PassManager " 109 "that can be applied on a Module. Throw a ValueError if the pipeline " 110 "can't be parsed") 111 .def( 112 "add", 113 [](PyPassManager &passManager, const std::string &pipeline) { 114 PyPrintAccumulator errorMsg; 115 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 116 mlirPassManagerGetAsOpPassManager(passManager.get()), 117 mlirStringRefCreate(pipeline.data(), pipeline.size()), 118 errorMsg.getCallback(), errorMsg.getUserData()); 119 if (mlirLogicalResultIsFailure(status)) 120 throw py::value_error(std::string(errorMsg.join())); 121 }, 122 "pipeline"_a, 123 "Add textual pipeline elements to the pass manager. Throws a " 124 "ValueError if the pipeline can't be parsed.") 125 .def( 126 "run", 127 [](PyPassManager &passManager, PyOperationBase &op, 128 bool invalidateOps) { 129 if (invalidateOps) { 130 op.getOperation().getContext()->clearOperationsInside(op); 131 } 132 // Actually run the pass manager. 133 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 134 MlirLogicalResult status = mlirPassManagerRunOnOp( 135 passManager.get(), op.getOperation().get()); 136 if (mlirLogicalResultIsFailure(status)) 137 throw MLIRError("Failure while executing pass pipeline", 138 errors.take()); 139 }, 140 "operation"_a, "invalidate_ops"_a = true, 141 "Run the pass manager on the provided operation, raising an " 142 "MLIRError on failure.") 143 .def( 144 "__str__", 145 [](PyPassManager &self) { 146 MlirPassManager passManager = self.get(); 147 PyPrintAccumulator printAccum; 148 mlirPrintPassPipeline( 149 mlirPassManagerGetAsOpPassManager(passManager), 150 printAccum.getCallback(), printAccum.getUserData()); 151 return printAccum.join(); 152 }, 153 "Print the textual representation for this PassManager, suitable to " 154 "be passed to `parse` for round-tripping."); 155 } 156