1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Pass.h" 14 15 namespace py = pybind11; 16 using namespace py::literals; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 namespace { 21 22 /// Owning Wrapper around a PassManager. 23 class PyPassManager { 24 public: 25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 26 PyPassManager(PyPassManager &&other) noexcept 27 : passManager(other.passManager) { 28 other.passManager.ptr = nullptr; 29 } 30 ~PyPassManager() { 31 if (!mlirPassManagerIsNull(passManager)) 32 mlirPassManagerDestroy(passManager); 33 } 34 MlirPassManager get() { return passManager; } 35 36 void release() { passManager.ptr = nullptr; } 37 pybind11::object getCapsule() { 38 return py::reinterpret_steal<py::object>( 39 mlirPythonPassManagerToCapsule(get())); 40 } 41 42 static pybind11::object createFromCapsule(pybind11::object capsule) { 43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 44 if (mlirPassManagerIsNull(rawPm)) 45 throw py::error_already_set(); 46 return py::cast(PyPassManager(rawPm), py::return_value_policy::move); 47 } 48 49 private: 50 MlirPassManager passManager; 51 }; 52 53 } // namespace 54 55 /// Create the `mlir.passmanager` here. 56 void mlir::python::populatePassManagerSubmodule(py::module &m) { 57 //---------------------------------------------------------------------------- 58 // Mapping of the top-level PassManager 59 //---------------------------------------------------------------------------- 60 py::class_<PyPassManager>(m, "PassManager", py::module_local()) 61 .def(py::init<>([](const std::string &anchorOp, 62 DefaultingPyMlirContext context) { 63 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 64 context->get(), 65 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 66 return new PyPassManager(passManager); 67 }), 68 "anchor_op"_a = py::str("any"), "context"_a = py::none(), 69 "Create a new PassManager for the current (or provided) Context.") 70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 71 &PyPassManager::getCapsule) 72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 73 .def("_testing_release", &PyPassManager::release, 74 "Releases (leaks) the backing pass manager (testing)") 75 .def( 76 "enable_ir_printing", 77 [](PyPassManager &passManager, bool printBeforeAll, 78 bool printAfterAll, bool printModuleScope, bool printAfterChange, 79 bool printAfterFailure, 80 std::optional<std::string> optionalTreePrintingPath) { 81 std::string treePrintingPath = ""; 82 if (optionalTreePrintingPath.has_value()) 83 treePrintingPath = optionalTreePrintingPath.value(); 84 mlirPassManagerEnableIRPrinting( 85 passManager.get(), printBeforeAll, printAfterAll, 86 printModuleScope, printAfterChange, printAfterFailure, 87 mlirStringRefCreate(treePrintingPath.data(), 88 treePrintingPath.size())); 89 }, 90 "print_before_all"_a = false, "print_after_all"_a = true, 91 "print_module_scope"_a = false, "print_after_change"_a = false, 92 "print_after_failure"_a = false, 93 "tree_printing_dir_path"_a = py::none(), 94 "Enable IR printing, default as mlir-print-ir-after-all.") 95 .def( 96 "enable_verifier", 97 [](PyPassManager &passManager, bool enable) { 98 mlirPassManagerEnableVerifier(passManager.get(), enable); 99 }, 100 "enable"_a, "Enable / disable verify-each.") 101 .def_static( 102 "parse", 103 [](const std::string &pipeline, DefaultingPyMlirContext context) { 104 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 105 PyPrintAccumulator errorMsg; 106 MlirLogicalResult status = mlirParsePassPipeline( 107 mlirPassManagerGetAsOpPassManager(passManager), 108 mlirStringRefCreate(pipeline.data(), pipeline.size()), 109 errorMsg.getCallback(), errorMsg.getUserData()); 110 if (mlirLogicalResultIsFailure(status)) 111 throw py::value_error(std::string(errorMsg.join())); 112 return new PyPassManager(passManager); 113 }, 114 "pipeline"_a, "context"_a = py::none(), 115 "Parse a textual pass-pipeline and return a top-level PassManager " 116 "that can be applied on a Module. Throw a ValueError if the pipeline " 117 "can't be parsed") 118 .def( 119 "add", 120 [](PyPassManager &passManager, const std::string &pipeline) { 121 PyPrintAccumulator errorMsg; 122 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 123 mlirPassManagerGetAsOpPassManager(passManager.get()), 124 mlirStringRefCreate(pipeline.data(), pipeline.size()), 125 errorMsg.getCallback(), errorMsg.getUserData()); 126 if (mlirLogicalResultIsFailure(status)) 127 throw py::value_error(std::string(errorMsg.join())); 128 }, 129 "pipeline"_a, 130 "Add textual pipeline elements to the pass manager. Throws a " 131 "ValueError if the pipeline can't be parsed.") 132 .def( 133 "run", 134 [](PyPassManager &passManager, PyOperationBase &op, 135 bool invalidateOps) { 136 if (invalidateOps) { 137 op.getOperation().getContext()->clearOperationsInside(op); 138 } 139 // Actually run the pass manager. 140 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 141 MlirLogicalResult status = mlirPassManagerRunOnOp( 142 passManager.get(), op.getOperation().get()); 143 if (mlirLogicalResultIsFailure(status)) 144 throw MLIRError("Failure while executing pass pipeline", 145 errors.take()); 146 }, 147 "operation"_a, "invalidate_ops"_a = true, 148 "Run the pass manager on the provided operation, raising an " 149 "MLIRError on failure.") 150 .def( 151 "__str__", 152 [](PyPassManager &self) { 153 MlirPassManager passManager = self.get(); 154 PyPrintAccumulator printAccum; 155 mlirPrintPassPipeline( 156 mlirPassManagerGetAsOpPassManager(passManager), 157 printAccum.getCallback(), printAccum.getUserData()); 158 return printAccum.join(); 159 }, 160 "Print the textual representation for this PassManager, suitable to " 161 "be passed to `parse` for round-tripping."); 162 } 163