118cf1cd9SJacques Pienaar //===- Rewrite.cpp - Rewrite ----------------------------------------------===// 218cf1cd9SJacques Pienaar // 318cf1cd9SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 418cf1cd9SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information. 518cf1cd9SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 618cf1cd9SJacques Pienaar // 718cf1cd9SJacques Pienaar //===----------------------------------------------------------------------===// 818cf1cd9SJacques Pienaar 918cf1cd9SJacques Pienaar #include "Rewrite.h" 1018cf1cd9SJacques Pienaar 1118cf1cd9SJacques Pienaar #include "IRModule.h" 1218cf1cd9SJacques Pienaar #include "mlir-c/Rewrite.h" 13*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 14*5cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 1518cf1cd9SJacques Pienaar #include "mlir/Config/mlir-config.h" 1618cf1cd9SJacques Pienaar 17b56d1ec6SPeter Hawkins namespace nb = nanobind; 1818cf1cd9SJacques Pienaar using namespace mlir; 19b56d1ec6SPeter Hawkins using namespace nb::literals; 2018cf1cd9SJacques Pienaar using namespace mlir::python; 2118cf1cd9SJacques Pienaar 2218cf1cd9SJacques Pienaar namespace { 2318cf1cd9SJacques Pienaar 2418cf1cd9SJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 2518cf1cd9SJacques Pienaar /// Owning Wrapper around a PDLPatternModule. 2618cf1cd9SJacques Pienaar class PyPDLPatternModule { 2718cf1cd9SJacques Pienaar public: 2818cf1cd9SJacques Pienaar PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} 2918cf1cd9SJacques Pienaar PyPDLPatternModule(PyPDLPatternModule &&other) noexcept 3018cf1cd9SJacques Pienaar : module(other.module) { 3118cf1cd9SJacques Pienaar other.module.ptr = nullptr; 3218cf1cd9SJacques Pienaar } 3318cf1cd9SJacques Pienaar ~PyPDLPatternModule() { 3418cf1cd9SJacques Pienaar if (module.ptr != nullptr) 3518cf1cd9SJacques Pienaar mlirPDLPatternModuleDestroy(module); 3618cf1cd9SJacques Pienaar } 3718cf1cd9SJacques Pienaar MlirPDLPatternModule get() { return module; } 3818cf1cd9SJacques Pienaar 3918cf1cd9SJacques Pienaar private: 4018cf1cd9SJacques Pienaar MlirPDLPatternModule module; 4118cf1cd9SJacques Pienaar }; 4218cf1cd9SJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 4318cf1cd9SJacques Pienaar 4418cf1cd9SJacques Pienaar /// Owning Wrapper around a FrozenRewritePatternSet. 4518cf1cd9SJacques Pienaar class PyFrozenRewritePatternSet { 4618cf1cd9SJacques Pienaar public: 4718cf1cd9SJacques Pienaar PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} 4818cf1cd9SJacques Pienaar PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept 4918cf1cd9SJacques Pienaar : set(other.set) { 5018cf1cd9SJacques Pienaar other.set.ptr = nullptr; 5118cf1cd9SJacques Pienaar } 5218cf1cd9SJacques Pienaar ~PyFrozenRewritePatternSet() { 5318cf1cd9SJacques Pienaar if (set.ptr != nullptr) 5418cf1cd9SJacques Pienaar mlirFrozenRewritePatternSetDestroy(set); 5518cf1cd9SJacques Pienaar } 5618cf1cd9SJacques Pienaar MlirFrozenRewritePatternSet get() { return set; } 5718cf1cd9SJacques Pienaar 58b56d1ec6SPeter Hawkins nb::object getCapsule() { 59b56d1ec6SPeter Hawkins return nb::steal<nb::object>( 6018cf1cd9SJacques Pienaar mlirPythonFrozenRewritePatternSetToCapsule(get())); 6118cf1cd9SJacques Pienaar } 6218cf1cd9SJacques Pienaar 63b56d1ec6SPeter Hawkins static nb::object createFromCapsule(nb::object capsule) { 6418cf1cd9SJacques Pienaar MlirFrozenRewritePatternSet rawPm = 6518cf1cd9SJacques Pienaar mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 6618cf1cd9SJacques Pienaar if (rawPm.ptr == nullptr) 67b56d1ec6SPeter Hawkins throw nb::python_error(); 68b56d1ec6SPeter Hawkins return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); 6918cf1cd9SJacques Pienaar } 7018cf1cd9SJacques Pienaar 7118cf1cd9SJacques Pienaar private: 7218cf1cd9SJacques Pienaar MlirFrozenRewritePatternSet set; 7318cf1cd9SJacques Pienaar }; 7418cf1cd9SJacques Pienaar 7518cf1cd9SJacques Pienaar } // namespace 7618cf1cd9SJacques Pienaar 7718cf1cd9SJacques Pienaar /// Create the `mlir.rewrite` here. 78b56d1ec6SPeter Hawkins void mlir::python::populateRewriteSubmodule(nb::module_ &m) { 7918cf1cd9SJacques Pienaar //---------------------------------------------------------------------------- 8018cf1cd9SJacques Pienaar // Mapping of the top-level PassManager 8118cf1cd9SJacques Pienaar //---------------------------------------------------------------------------- 8218cf1cd9SJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 83b56d1ec6SPeter Hawkins nb::class_<PyPDLPatternModule>(m, "PDLModule") 84b56d1ec6SPeter Hawkins .def( 85b56d1ec6SPeter Hawkins "__init__", 86b56d1ec6SPeter Hawkins [](PyPDLPatternModule &self, MlirModule module) { 87b56d1ec6SPeter Hawkins new (&self) 88b56d1ec6SPeter Hawkins PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); 89b56d1ec6SPeter Hawkins }, 9018cf1cd9SJacques Pienaar "module"_a, "Create a PDL module from the given module.") 9118cf1cd9SJacques Pienaar .def("freeze", [](PyPDLPatternModule &self) { 9218cf1cd9SJacques Pienaar return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( 9318cf1cd9SJacques Pienaar mlirRewritePatternSetFromPDLPatternModule(self.get()))); 9418cf1cd9SJacques Pienaar }); 95b56d1ec6SPeter Hawkins #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 96b56d1ec6SPeter Hawkins nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet") 97b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, 9818cf1cd9SJacques Pienaar &PyFrozenRewritePatternSet::getCapsule) 9918cf1cd9SJacques Pienaar .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, 10018cf1cd9SJacques Pienaar &PyFrozenRewritePatternSet::createFromCapsule); 10118cf1cd9SJacques Pienaar m.def( 10218cf1cd9SJacques Pienaar "apply_patterns_and_fold_greedily", 10318cf1cd9SJacques Pienaar [](MlirModule module, MlirFrozenRewritePatternSet set) { 10418cf1cd9SJacques Pienaar auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); 10518cf1cd9SJacques Pienaar if (mlirLogicalResultIsFailure(status)) 10618cf1cd9SJacques Pienaar // FIXME: Not sure this is the right error to throw here. 107b56d1ec6SPeter Hawkins throw nb::value_error("pattern application failed to converge"); 10818cf1cd9SJacques Pienaar }, 10918cf1cd9SJacques Pienaar "module"_a, "set"_a, 11018cf1cd9SJacques Pienaar "Applys the given patterns to the given module greedily while folding " 11118cf1cd9SJacques Pienaar "results."); 11218cf1cd9SJacques Pienaar } 113