1dc43f785SMehdi Amini //===- Pass.cpp - Pass Management -----------------------------------------===// 2dc43f785SMehdi Amini // 3dc43f785SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4dc43f785SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 5dc43f785SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6dc43f785SMehdi Amini // 7dc43f785SMehdi Amini //===----------------------------------------------------------------------===// 8dc43f785SMehdi Amini 9dc43f785SMehdi Amini #include "Pass.h" 10dc43f785SMehdi Amini 11436c6c9cSStella Laurenzo #include "IRModule.h" 12dc43f785SMehdi Amini #include "mlir-c/Pass.h" 13*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 14*5cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 15dc43f785SMehdi Amini 16b56d1ec6SPeter Hawkins namespace nb = nanobind; 17b56d1ec6SPeter Hawkins using namespace nb::literals; 18dc43f785SMehdi Amini using namespace mlir; 19dc43f785SMehdi Amini using namespace mlir::python; 20dc43f785SMehdi Amini 21dc43f785SMehdi Amini namespace { 22dc43f785SMehdi Amini 23dc43f785SMehdi Amini /// Owning Wrapper around a PassManager. 24dc43f785SMehdi Amini class PyPassManager { 25dc43f785SMehdi Amini public: 26dc43f785SMehdi Amini PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 27ea2e83afSAdrian Kuegel PyPassManager(PyPassManager &&other) noexcept 28ea2e83afSAdrian Kuegel : passManager(other.passManager) { 295fef6ce0SStella Laurenzo other.passManager.ptr = nullptr; 305fef6ce0SStella Laurenzo } 315fef6ce0SStella Laurenzo ~PyPassManager() { 325fef6ce0SStella Laurenzo if (!mlirPassManagerIsNull(passManager)) 335fef6ce0SStella Laurenzo mlirPassManagerDestroy(passManager); 345fef6ce0SStella Laurenzo } 35dc43f785SMehdi Amini MlirPassManager get() { return passManager; } 36dc43f785SMehdi Amini 375fef6ce0SStella Laurenzo void release() { passManager.ptr = nullptr; } 38b56d1ec6SPeter Hawkins nb::object getCapsule() { 39b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get())); 405fef6ce0SStella Laurenzo } 415fef6ce0SStella Laurenzo 42b56d1ec6SPeter Hawkins static nb::object createFromCapsule(nb::object capsule) { 435fef6ce0SStella Laurenzo MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); 445fef6ce0SStella Laurenzo if (mlirPassManagerIsNull(rawPm)) 45b56d1ec6SPeter Hawkins throw nb::python_error(); 46b56d1ec6SPeter Hawkins return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); 475fef6ce0SStella Laurenzo } 485fef6ce0SStella Laurenzo 49dc43f785SMehdi Amini private: 50dc43f785SMehdi Amini MlirPassManager passManager; 51dc43f785SMehdi Amini }; 52dc43f785SMehdi Amini 53be0a7e9fSMehdi Amini } // namespace 54dc43f785SMehdi Amini 55dc43f785SMehdi Amini /// Create the `mlir.passmanager` here. 56b56d1ec6SPeter Hawkins void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { 57dc43f785SMehdi Amini //---------------------------------------------------------------------------- 58dc43f785SMehdi Amini // Mapping of the top-level PassManager 59dc43f785SMehdi Amini //---------------------------------------------------------------------------- 60b56d1ec6SPeter Hawkins nb::class_<PyPassManager>(m, "PassManager") 61b56d1ec6SPeter Hawkins .def( 62b56d1ec6SPeter Hawkins "__init__", 63b56d1ec6SPeter Hawkins [](PyPassManager &self, const std::string &anchorOp, 64d97e8cd4Srkayaith DefaultingPyMlirContext context) { 65d97e8cd4Srkayaith MlirPassManager passManager = mlirPassManagerCreateOnOperation( 66d97e8cd4Srkayaith context->get(), 67d97e8cd4Srkayaith mlirStringRefCreate(anchorOp.data(), anchorOp.size())); 68b56d1ec6SPeter Hawkins new (&self) PyPassManager(passManager); 69b56d1ec6SPeter Hawkins }, 70b56d1ec6SPeter Hawkins "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), 71dc43f785SMehdi Amini "Create a new PassManager for the current (or provided) Context.") 72b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) 735fef6ce0SStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) 745fef6ce0SStella Laurenzo .def("_testing_release", &PyPassManager::release, 755fef6ce0SStella Laurenzo "Releases (leaks) the backing pass manager (testing)") 76caa159f0SNicolas Vasilache .def( 77caa159f0SNicolas Vasilache "enable_ir_printing", 78f8eceb45SBimo [](PyPassManager &passManager, bool printBeforeAll, 79f8eceb45SBimo bool printAfterAll, bool printModuleScope, bool printAfterChange, 802e51e150SYuanqiang Liu bool printAfterFailure, std::optional<int64_t> largeElementsLimit, 812e51e150SYuanqiang Liu bool enableDebugInfo, bool printGenericOpForm, 82c8b837adSMehdi Amini std::optional<std::string> optionalTreePrintingPath) { 832e51e150SYuanqiang Liu MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 842e51e150SYuanqiang Liu if (largeElementsLimit) 852e51e150SYuanqiang Liu mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 862e51e150SYuanqiang Liu *largeElementsLimit); 872e51e150SYuanqiang Liu if (enableDebugInfo) 882e51e150SYuanqiang Liu mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 892e51e150SYuanqiang Liu /*prettyForm=*/false); 902e51e150SYuanqiang Liu if (printGenericOpForm) 912e51e150SYuanqiang Liu mlirOpPrintingFlagsPrintGenericOpForm(flags); 92c8b837adSMehdi Amini std::string treePrintingPath = ""; 93c8b837adSMehdi Amini if (optionalTreePrintingPath.has_value()) 94c8b837adSMehdi Amini treePrintingPath = optionalTreePrintingPath.value(); 95f8eceb45SBimo mlirPassManagerEnableIRPrinting( 96f8eceb45SBimo passManager.get(), printBeforeAll, printAfterAll, 972e51e150SYuanqiang Liu printModuleScope, printAfterChange, printAfterFailure, flags, 98c8b837adSMehdi Amini mlirStringRefCreate(treePrintingPath.data(), 99c8b837adSMehdi Amini treePrintingPath.size())); 1002e51e150SYuanqiang Liu mlirOpPrintingFlagsDestroy(flags); 101caa159f0SNicolas Vasilache }, 102f8eceb45SBimo "print_before_all"_a = false, "print_after_all"_a = true, 103f8eceb45SBimo "print_module_scope"_a = false, "print_after_change"_a = false, 104f8eceb45SBimo "print_after_failure"_a = false, 105b56d1ec6SPeter Hawkins "large_elements_limit"_a.none() = nb::none(), 106b56d1ec6SPeter Hawkins "enable_debug_info"_a = false, "print_generic_op_form"_a = false, 107b56d1ec6SPeter Hawkins "tree_printing_dir_path"_a.none() = nb::none(), 108f8eceb45SBimo "Enable IR printing, default as mlir-print-ir-after-all.") 109caa159f0SNicolas Vasilache .def( 110caa159f0SNicolas Vasilache "enable_verifier", 111caa159f0SNicolas Vasilache [](PyPassManager &passManager, bool enable) { 112caa159f0SNicolas Vasilache mlirPassManagerEnableVerifier(passManager.get(), enable); 113caa159f0SNicolas Vasilache }, 114bdc3e6cbSMaksim Levental "enable"_a, "Enable / disable verify-each.") 115dc43f785SMehdi Amini .def_static( 116dc43f785SMehdi Amini "parse", 117b3c5f6b1Srkayaith [](const std::string &pipeline, DefaultingPyMlirContext context) { 118dc43f785SMehdi Amini MlirPassManager passManager = mlirPassManagerCreate(context->get()); 119b3c5f6b1Srkayaith PyPrintAccumulator errorMsg; 12066645a03Srkayaith MlirLogicalResult status = mlirParsePassPipeline( 121dc43f785SMehdi Amini mlirPassManagerGetAsOpPassManager(passManager), 122b3c5f6b1Srkayaith mlirStringRefCreate(pipeline.data(), pipeline.size()), 123b3c5f6b1Srkayaith errorMsg.getCallback(), errorMsg.getUserData()); 124dc43f785SMehdi Amini if (mlirLogicalResultIsFailure(status)) 125b56d1ec6SPeter Hawkins throw nb::value_error(errorMsg.join().c_str()); 126dc43f785SMehdi Amini return new PyPassManager(passManager); 127dc43f785SMehdi Amini }, 128b56d1ec6SPeter Hawkins "pipeline"_a, "context"_a.none() = nb::none(), 129dc43f785SMehdi Amini "Parse a textual pass-pipeline and return a top-level PassManager " 130dc43f785SMehdi Amini "that can be applied on a Module. Throw a ValueError if the pipeline " 131dc43f785SMehdi Amini "can't be parsed") 132dc43f785SMehdi Amini .def( 133dd1b1d44Srkayaith "add", 134dd1b1d44Srkayaith [](PyPassManager &passManager, const std::string &pipeline) { 135dd1b1d44Srkayaith PyPrintAccumulator errorMsg; 136dd1b1d44Srkayaith MlirLogicalResult status = mlirOpPassManagerAddPipeline( 137dd1b1d44Srkayaith mlirPassManagerGetAsOpPassManager(passManager.get()), 138dd1b1d44Srkayaith mlirStringRefCreate(pipeline.data(), pipeline.size()), 139dd1b1d44Srkayaith errorMsg.getCallback(), errorMsg.getUserData()); 140dd1b1d44Srkayaith if (mlirLogicalResultIsFailure(status)) 141b56d1ec6SPeter Hawkins throw nb::value_error(errorMsg.join().c_str()); 142dd1b1d44Srkayaith }, 143bdc3e6cbSMaksim Levental "pipeline"_a, 144dd1b1d44Srkayaith "Add textual pipeline elements to the pass manager. Throws a " 145dd1b1d44Srkayaith "ValueError if the pipeline can't be parsed.") 146dd1b1d44Srkayaith .def( 1476cb1c0caSMehdi Amini "run", 148bdc3e6cbSMaksim Levental [](PyPassManager &passManager, PyOperationBase &op, 149bdc3e6cbSMaksim Levental bool invalidateOps) { 150bdc3e6cbSMaksim Levental if (invalidateOps) { 151fa19ef7aSIngo Müller op.getOperation().getContext()->clearOperationsInside(op); 152bdc3e6cbSMaksim Levental } 153bdc3e6cbSMaksim Levental // Actually run the pass manager. 1543ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); 1556f5590caSrkayaith MlirLogicalResult status = mlirPassManagerRunOnOp( 156c00f81ccSrkayaith passManager.get(), op.getOperation().get()); 1576cb1c0caSMehdi Amini if (mlirLogicalResultIsFailure(status)) 1583ea4c501SRahul Kayaith throw MLIRError("Failure while executing pass pipeline", 1593ea4c501SRahul Kayaith errors.take()); 1606cb1c0caSMehdi Amini }, 161bdc3e6cbSMaksim Levental "operation"_a, "invalidate_ops"_a = true, 1623ea4c501SRahul Kayaith "Run the pass manager on the provided operation, raising an " 1633ea4c501SRahul Kayaith "MLIRError on failure.") 1646cb1c0caSMehdi Amini .def( 165dc43f785SMehdi Amini "__str__", 166dc43f785SMehdi Amini [](PyPassManager &self) { 167dc43f785SMehdi Amini MlirPassManager passManager = self.get(); 168dc43f785SMehdi Amini PyPrintAccumulator printAccum; 169dc43f785SMehdi Amini mlirPrintPassPipeline( 170dc43f785SMehdi Amini mlirPassManagerGetAsOpPassManager(passManager), 171dc43f785SMehdi Amini printAccum.getCallback(), printAccum.getUserData()); 172dc43f785SMehdi Amini return printAccum.join(); 173dc43f785SMehdi Amini }, 174dc43f785SMehdi Amini "Print the textual representation for this PassManager, suitable to " 175dc43f785SMehdi Amini "be passed to `parse` for round-tripping."); 176dc43f785SMehdi Amini } 177