1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Pass.h" 13 #include "mlir/Bindings/Python/Nanobind.h" 14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 15 16 namespace nb = nanobind; 17 using namespace nb::literals; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 namespace { 22 23 /// Owning Wrapper around a PassManager. 24 class PyPassManager { 25 public: 26 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 27 PyPassManager(PyPassManager &&other) noexcept 28 : passManager(other.passManager) { 29 other.passManager.ptr = nullptr; 30 } 31 ~PyPassManager() { 32 if (!mlirPassManagerIsNull(passManager)) 33 mlirPassManagerDestroy(passManager); 34 } 35 MlirPassManager get() { return passManager; } 36 37 void release() { passManager.ptr = nullptr; } 38 nb::object getCapsule() { 39 return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get())); 40 } 41 42 static nb::object createFromCapsule(nb::object capsule) { 43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 44 if (mlirPassManagerIsNull(rawPm)) 45 throw nb::python_error(); 46 return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); 47 } 48 49 private: 50 MlirPassManager passManager; 51 }; 52 53 } // namespace 54 55 /// Create the `mlir.passmanager` here. 56 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { 57 //---------------------------------------------------------------------------- 58 // Mapping of the top-level PassManager 59 //---------------------------------------------------------------------------- 60 nb::class_<PyPassManager>(m, "PassManager") 61 .def( 62 "__init__", 63 [](PyPassManager &self, const std::string &anchorOp, 64 DefaultingPyMlirContext context) { 65 MlirPassManager passManager = mlirPassManagerCreateOnOperation( 66 context->get(), 67 mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 68 new (&self) PyPassManager(passManager); 69 }, 70 "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), 71 "Create a new PassManager for the current (or provided) Context.") 72 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) 73 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 74 .def("_testing_release", &PyPassManager::release, 75 "Releases (leaks) the backing pass manager (testing)") 76 .def( 77 "enable_ir_printing", 78 [](PyPassManager &passManager, bool printBeforeAll, 79 bool printAfterAll, bool printModuleScope, bool printAfterChange, 80 bool printAfterFailure, std::optional<int64_t> largeElementsLimit, 81 bool enableDebugInfo, bool printGenericOpForm, 82 std::optional<std::string> optionalTreePrintingPath) { 83 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 84 if (largeElementsLimit) 85 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 86 *largeElementsLimit); 87 if (enableDebugInfo) 88 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 89 /*prettyForm=*/false); 90 if (printGenericOpForm) 91 mlirOpPrintingFlagsPrintGenericOpForm(flags); 92 std::string treePrintingPath = ""; 93 if (optionalTreePrintingPath.has_value()) 94 treePrintingPath = optionalTreePrintingPath.value(); 95 mlirPassManagerEnableIRPrinting( 96 passManager.get(), printBeforeAll, printAfterAll, 97 printModuleScope, printAfterChange, printAfterFailure, flags, 98 mlirStringRefCreate(treePrintingPath.data(), 99 treePrintingPath.size())); 100 mlirOpPrintingFlagsDestroy(flags); 101 }, 102 "print_before_all"_a = false, "print_after_all"_a = true, 103 "print_module_scope"_a = false, "print_after_change"_a = false, 104 "print_after_failure"_a = false, 105 "large_elements_limit"_a.none() = nb::none(), 106 "enable_debug_info"_a = false, "print_generic_op_form"_a = false, 107 "tree_printing_dir_path"_a.none() = nb::none(), 108 "Enable IR printing, default as mlir-print-ir-after-all.") 109 .def( 110 "enable_verifier", 111 [](PyPassManager &passManager, bool enable) { 112 mlirPassManagerEnableVerifier(passManager.get(), enable); 113 }, 114 "enable"_a, "Enable / disable verify-each.") 115 .def_static( 116 "parse", 117 [](const std::string &pipeline, DefaultingPyMlirContext context) { 118 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 119 PyPrintAccumulator errorMsg; 120 MlirLogicalResult status = mlirParsePassPipeline( 121 mlirPassManagerGetAsOpPassManager(passManager), 122 mlirStringRefCreate(pipeline.data(), pipeline.size()), 123 errorMsg.getCallback(), errorMsg.getUserData()); 124 if (mlirLogicalResultIsFailure(status)) 125 throw nb::value_error(errorMsg.join().c_str()); 126 return new PyPassManager(passManager); 127 }, 128 "pipeline"_a, "context"_a.none() = nb::none(), 129 "Parse a textual pass-pipeline and return a top-level PassManager " 130 "that can be applied on a Module. Throw a ValueError if the pipeline " 131 "can't be parsed") 132 .def( 133 "add", 134 [](PyPassManager &passManager, const std::string &pipeline) { 135 PyPrintAccumulator errorMsg; 136 MlirLogicalResult status = mlirOpPassManagerAddPipeline( 137 mlirPassManagerGetAsOpPassManager(passManager.get()), 138 mlirStringRefCreate(pipeline.data(), pipeline.size()), 139 errorMsg.getCallback(), errorMsg.getUserData()); 140 if (mlirLogicalResultIsFailure(status)) 141 throw nb::value_error(errorMsg.join().c_str()); 142 }, 143 "pipeline"_a, 144 "Add textual pipeline elements to the pass manager. Throws a " 145 "ValueError if the pipeline can't be parsed.") 146 .def( 147 "run", 148 [](PyPassManager &passManager, PyOperationBase &op, 149 bool invalidateOps) { 150 if (invalidateOps) { 151 op.getOperation().getContext()->clearOperationsInside(op); 152 } 153 // Actually run the pass manager. 154 PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 155 MlirLogicalResult status = mlirPassManagerRunOnOp( 156 passManager.get(), op.getOperation().get()); 157 if (mlirLogicalResultIsFailure(status)) 158 throw MLIRError("Failure while executing pass pipeline", 159 errors.take()); 160 }, 161 "operation"_a, "invalidate_ops"_a = true, 162 "Run the pass manager on the provided operation, raising an " 163 "MLIRError on failure.") 164 .def( 165 "__str__", 166 [](PyPassManager &self) { 167 MlirPassManager passManager = self.get(); 168 PyPrintAccumulator printAccum; 169 mlirPrintPassPipeline( 170 mlirPassManagerGetAsOpPassManager(passManager), 171 printAccum.getCallback(), printAccum.getUserData()); 172 return printAccum.join(); 173 }, 174 "Print the textual representation for this PassManager, suitable to " 175 "be passed to `parse` for round-tripping."); 176 } 177