1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Pass.h" 14 15 namespace py = pybind11; 16 using namespace py::literals; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 namespace { 21 22 /// Owning Wrapper around a PassManager. 23 class PyPassManager { 24 public: 25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 26 PyPassManager(PyPassManager &&other) noexcept 27 : passManager(other.passManager) { 28 other.passManager.ptr = nullptr; 29 } 30 ~PyPassManager() { 31 if (!mlirPassManagerIsNull(passManager)) 32 mlirPassManagerDestroy(passManager); 33 } 34 MlirPassManager get() { return passManager; } 35 36 void release() { passManager.ptr = nullptr; } 37 pybind11::object getCapsule() { 38 return py::reinterpret_steal<py::object>( 39 mlirPythonPassManagerToCapsule(get())); 40 } 41 42 static pybind11::object createFromCapsule(pybind11::object capsule) { 43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 44 if (mlirPassManagerIsNull(rawPm)) 45 throw py::error_already_set(); 46 return py::cast(PyPassManager(rawPm), py::return_value_policy::move); 47 } 48 49 private: 50 MlirPassManager passManager; 51 }; 52 53 } // namespace 54 55 /// Create the `mlir.passmanager` here. 56 void mlir::python::populatePassManagerSubmodule(py::module &m) { 57 //---------------------------------------------------------------------------- 58 // Mapping of the top-level PassManager 59 //---------------------------------------------------------------------------- 60 py::class_<PyPassManager>(m, "PassManager", py::module_local()) 61 .def(py::init<>([](const std::string &anchorOp, 62 DefaultingPyMlirContext context) { 63 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 64 context->get(), 65 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 66 return new PyPassManager(passManager); 67 }), 68 "anchor_op"_a = py::str("any"), "context"_a = py::none(), 69 "Create a new PassManager for the current (or provided) Context.") 70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 71 &PyPassManager::getCapsule) 72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 73 .def("_testing_release", &PyPassManager::release, 74 "Releases (leaks) the backing pass manager (testing)") 75 .def( 76 "enable_ir_printing", 77 [](PyPassManager &passManager, bool printBeforeAll, 78 bool printAfterAll, bool printModuleScope, bool printAfterChange, 79 bool printAfterFailure, std::optional<int64_t> largeElementsLimit, 80 bool enableDebugInfo, bool printGenericOpForm, 81 std::optional<std::string> optionalTreePrintingPath) { 82 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 83 if (largeElementsLimit) 84 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 85 *largeElementsLimit); 86 if (enableDebugInfo) 87 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 88 /*prettyForm=*/false); 89 if (printGenericOpForm) 90 mlirOpPrintingFlagsPrintGenericOpForm(flags); 91 std::string treePrintingPath = ""; 92 if (optionalTreePrintingPath.has_value()) 93 treePrintingPath = optionalTreePrintingPath.value(); 94 mlirPassManagerEnableIRPrinting( 95 passManager.get(), printBeforeAll, printAfterAll, 96 printModuleScope, printAfterChange, printAfterFailure, flags, 97 mlirStringRefCreate(treePrintingPath.data(), 98 treePrintingPath.size())); 99 mlirOpPrintingFlagsDestroy(flags); 100 }, 101 "print_before_all"_a = false, "print_after_all"_a = true, 102 "print_module_scope"_a = false, "print_after_change"_a = false, 103 "print_after_failure"_a = false, 104 "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, 105 "print_generic_op_form"_a = false, 106 "tree_printing_dir_path"_a = py::none(), 107 "Enable IR printing, default as mlir-print-ir-after-all.") 108 .def( 109 "enable_verifier", 110 [](PyPassManager &passManager, bool enable) { 111 mlirPassManagerEnableVerifier(passManager.get(), enable); 112 }, 113 "enable"_a, "Enable / disable verify-each.") 114 .def_static( 115 "parse", 116 [](const std::string &pipeline, DefaultingPyMlirContext context) { 117 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 118 PyPrintAccumulator errorMsg; 119 MlirLogicalResult status = mlirParsePassPipeline( 120 mlirPassManagerGetAsOpPassManager(passManager), 121 mlirStringRefCreate(pipeline.data(), pipeline.size()), 122 errorMsg.getCallback(), errorMsg.getUserData()); 123 if (mlirLogicalResultIsFailure(status)) 124 throw py::value_error(std::string(errorMsg.join())); 125 return new PyPassManager(passManager); 126 }, 127 "pipeline"_a, "context"_a = py::none(), 128 "Parse a textual pass-pipeline and return a top-level PassManager " 129 "that can be applied on a Module. Throw a ValueError if the pipeline " 130 "can't be parsed") 131 .def( 132 "add", 133 [](PyPassManager &passManager, const std::string &pipeline) { 134 PyPrintAccumulator errorMsg; 135 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 136 mlirPassManagerGetAsOpPassManager(passManager.get()), 137 mlirStringRefCreate(pipeline.data(), pipeline.size()), 138 errorMsg.getCallback(), errorMsg.getUserData()); 139 if (mlirLogicalResultIsFailure(status)) 140 throw py::value_error(std::string(errorMsg.join())); 141 }, 142 "pipeline"_a, 143 "Add textual pipeline elements to the pass manager. Throws a " 144 "ValueError if the pipeline can't be parsed.") 145 .def( 146 "run", 147 [](PyPassManager &passManager, PyOperationBase &op, 148 bool invalidateOps) { 149 if (invalidateOps) { 150 op.getOperation().getContext()->clearOperationsInside(op); 151 } 152 // Actually run the pass manager. 153 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 154 MlirLogicalResult status = mlirPassManagerRunOnOp( 155 passManager.get(), op.getOperation().get()); 156 if (mlirLogicalResultIsFailure(status)) 157 throw MLIRError("Failure while executing pass pipeline", 158 errors.take()); 159 }, 160 "operation"_a, "invalidate_ops"_a = true, 161 "Run the pass manager on the provided operation, raising an " 162 "MLIRError on failure.") 163 .def( 164 "__str__", 165 [](PyPassManager &self) { 166 MlirPassManager passManager = self.get(); 167 PyPrintAccumulator printAccum; 168 mlirPrintPassPipeline( 169 mlirPassManagerGetAsOpPassManager(passManager), 170 printAccum.getCallback(), printAccum.getUserData()); 171 return printAccum.join(); 172 }, 173 "Print the textual representation for this PassManager, suitable to " 174 "be passed to `parse` for round-tripping."); 175 } 176