1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include <nanobind/nanobind.h> 12 #include <nanobind/stl/optional.h> 13 #include <nanobind/stl/string.h> 14 15 #include "IRModule.h" 16 #include "mlir-c/Bindings/Python/Interop.h" 17 #include "mlir-c/Pass.h" 18 19 namespace nb = nanobind; 20 using namespace nb::literals; 21 using namespace mlir; 22 using namespace mlir::python; 23 24 namespace { 25 26 /// Owning Wrapper around a PassManager. 27 class PyPassManager { 28 public: 29 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 30 PyPassManager(PyPassManager &&other) noexcept 31 : passManager(other.passManager) { 32 other.passManager.ptr = nullptr; 33 } 34 ~PyPassManager() { 35 if (!mlirPassManagerIsNull(passManager)) 36 mlirPassManagerDestroy(passManager); 37 } 38 MlirPassManager get() { return passManager; } 39 40 void release() { passManager.ptr = nullptr; } 41 nb::object getCapsule() { 42 return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get())); 43 } 44 45 static nb::object createFromCapsule(nb::object capsule) { 46 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 47 if (mlirPassManagerIsNull(rawPm)) 48 throw nb::python_error(); 49 return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); 50 } 51 52 private: 53 MlirPassManager passManager; 54 }; 55 56 } // namespace 57 58 /// Create the `mlir.passmanager` here. 59 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { 60 //---------------------------------------------------------------------------- 61 // Mapping of the top-level PassManager 62 //---------------------------------------------------------------------------- 63 nb::class_<PyPassManager>(m, "PassManager") 64 .def( 65 "__init__", 66 [](PyPassManager &self, const std::string &anchorOp, 67 DefaultingPyMlirContext context) { 68 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 69 context->get(), 70 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 71 new (&self) PyPassManager(passManager); 72 }, 73 "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), 74 "Create a new PassManager for the current (or provided) Context.") 75 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) 76 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 77 .def("_testing_release", &PyPassManager::release, 78 "Releases (leaks) the backing pass manager (testing)") 79 .def( 80 "enable_ir_printing", 81 [](PyPassManager &passManager, bool printBeforeAll, 82 bool printAfterAll, bool printModuleScope, bool printAfterChange, 83 bool printAfterFailure, std::optional<int64_t> largeElementsLimit, 84 bool enableDebugInfo, bool printGenericOpForm, 85 std::optional<std::string> optionalTreePrintingPath) { 86 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 87 if (largeElementsLimit) 88 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 89 *largeElementsLimit); 90 if (enableDebugInfo) 91 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 92 /*prettyForm=*/false); 93 if (printGenericOpForm) 94 mlirOpPrintingFlagsPrintGenericOpForm(flags); 95 std::string treePrintingPath = ""; 96 if (optionalTreePrintingPath.has_value()) 97 treePrintingPath = optionalTreePrintingPath.value(); 98 mlirPassManagerEnableIRPrinting( 99 passManager.get(), printBeforeAll, printAfterAll, 100 printModuleScope, printAfterChange, printAfterFailure, flags, 101 mlirStringRefCreate(treePrintingPath.data(), 102 treePrintingPath.size())); 103 mlirOpPrintingFlagsDestroy(flags); 104 }, 105 "print_before_all"_a = false, "print_after_all"_a = true, 106 "print_module_scope"_a = false, "print_after_change"_a = false, 107 "print_after_failure"_a = false, 108 "large_elements_limit"_a.none() = nb::none(), 109 "enable_debug_info"_a = false, "print_generic_op_form"_a = false, 110 "tree_printing_dir_path"_a.none() = nb::none(), 111 "Enable IR printing, default as mlir-print-ir-after-all.") 112 .def( 113 "enable_verifier", 114 [](PyPassManager &passManager, bool enable) { 115 mlirPassManagerEnableVerifier(passManager.get(), enable); 116 }, 117 "enable"_a, "Enable / disable verify-each.") 118 .def_static( 119 "parse", 120 [](const std::string &pipeline, DefaultingPyMlirContext context) { 121 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 122 PyPrintAccumulator errorMsg; 123 MlirLogicalResult status = mlirParsePassPipeline( 124 mlirPassManagerGetAsOpPassManager(passManager), 125 mlirStringRefCreate(pipeline.data(), pipeline.size()), 126 errorMsg.getCallback(), errorMsg.getUserData()); 127 if (mlirLogicalResultIsFailure(status)) 128 throw nb::value_error(errorMsg.join().c_str()); 129 return new PyPassManager(passManager); 130 }, 131 "pipeline"_a, "context"_a.none() = nb::none(), 132 "Parse a textual pass-pipeline and return a top-level PassManager " 133 "that can be applied on a Module. Throw a ValueError if the pipeline " 134 "can't be parsed") 135 .def( 136 "add", 137 [](PyPassManager &passManager, const std::string &pipeline) { 138 PyPrintAccumulator errorMsg; 139 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 140 mlirPassManagerGetAsOpPassManager(passManager.get()), 141 mlirStringRefCreate(pipeline.data(), pipeline.size()), 142 errorMsg.getCallback(), errorMsg.getUserData()); 143 if (mlirLogicalResultIsFailure(status)) 144 throw nb::value_error(errorMsg.join().c_str()); 145 }, 146 "pipeline"_a, 147 "Add textual pipeline elements to the pass manager. Throws a " 148 "ValueError if the pipeline can't be parsed.") 149 .def( 150 "run", 151 [](PyPassManager &passManager, PyOperationBase &op, 152 bool invalidateOps) { 153 if (invalidateOps) { 154 op.getOperation().getContext()->clearOperationsInside(op); 155 } 156 // Actually run the pass manager. 157 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 158 MlirLogicalResult status = mlirPassManagerRunOnOp( 159 passManager.get(), op.getOperation().get()); 160 if (mlirLogicalResultIsFailure(status)) 161 throw MLIRError("Failure while executing pass pipeline", 162 errors.take()); 163 }, 164 "operation"_a, "invalidate_ops"_a = true, 165 "Run the pass manager on the provided operation, raising an " 166 "MLIRError on failure.") 167 .def( 168 "__str__", 169 [](PyPassManager &self) { 170 MlirPassManager passManager = self.get(); 171 PyPrintAccumulator printAccum; 172 mlirPrintPassPipeline( 173 mlirPassManagerGetAsOpPassManager(passManager), 174 printAccum.getCallback(), printAccum.getUserData()); 175 return printAccum.join(); 176 }, 177 "Print the textual representation for this PassManager, suitable to " 178 "be passed to `parse` for round-tripping."); 179 } 180