xref: /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
118cf1cd9SJacques Pienaar //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
218cf1cd9SJacques Pienaar //
318cf1cd9SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
418cf1cd9SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
518cf1cd9SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
618cf1cd9SJacques Pienaar //
718cf1cd9SJacques Pienaar //===----------------------------------------------------------------------===//
818cf1cd9SJacques Pienaar 
918cf1cd9SJacques Pienaar #include "Rewrite.h"
1018cf1cd9SJacques Pienaar 
1118cf1cd9SJacques Pienaar #include "IRModule.h"
1218cf1cd9SJacques Pienaar #include "mlir-c/Rewrite.h"
13*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
14*5cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1518cf1cd9SJacques Pienaar #include "mlir/Config/mlir-config.h"
1618cf1cd9SJacques Pienaar 
17b56d1ec6SPeter Hawkins namespace nb = nanobind;
1818cf1cd9SJacques Pienaar using namespace mlir;
19b56d1ec6SPeter Hawkins using namespace nb::literals;
2018cf1cd9SJacques Pienaar using namespace mlir::python;
2118cf1cd9SJacques Pienaar 
2218cf1cd9SJacques Pienaar namespace {
2318cf1cd9SJacques Pienaar 
2418cf1cd9SJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
2518cf1cd9SJacques Pienaar /// Owning Wrapper around a PDLPatternModule.
2618cf1cd9SJacques Pienaar class PyPDLPatternModule {
2718cf1cd9SJacques Pienaar public:
2818cf1cd9SJacques Pienaar   PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
2918cf1cd9SJacques Pienaar   PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
3018cf1cd9SJacques Pienaar       : module(other.module) {
3118cf1cd9SJacques Pienaar     other.module.ptr = nullptr;
3218cf1cd9SJacques Pienaar   }
3318cf1cd9SJacques Pienaar   ~PyPDLPatternModule() {
3418cf1cd9SJacques Pienaar     if (module.ptr != nullptr)
3518cf1cd9SJacques Pienaar       mlirPDLPatternModuleDestroy(module);
3618cf1cd9SJacques Pienaar   }
3718cf1cd9SJacques Pienaar   MlirPDLPatternModule get() { return module; }
3818cf1cd9SJacques Pienaar 
3918cf1cd9SJacques Pienaar private:
4018cf1cd9SJacques Pienaar   MlirPDLPatternModule module;
4118cf1cd9SJacques Pienaar };
4218cf1cd9SJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4318cf1cd9SJacques Pienaar 
4418cf1cd9SJacques Pienaar /// Owning Wrapper around a FrozenRewritePatternSet.
4518cf1cd9SJacques Pienaar class PyFrozenRewritePatternSet {
4618cf1cd9SJacques Pienaar public:
4718cf1cd9SJacques Pienaar   PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
4818cf1cd9SJacques Pienaar   PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
4918cf1cd9SJacques Pienaar       : set(other.set) {
5018cf1cd9SJacques Pienaar     other.set.ptr = nullptr;
5118cf1cd9SJacques Pienaar   }
5218cf1cd9SJacques Pienaar   ~PyFrozenRewritePatternSet() {
5318cf1cd9SJacques Pienaar     if (set.ptr != nullptr)
5418cf1cd9SJacques Pienaar       mlirFrozenRewritePatternSetDestroy(set);
5518cf1cd9SJacques Pienaar   }
5618cf1cd9SJacques Pienaar   MlirFrozenRewritePatternSet get() { return set; }
5718cf1cd9SJacques Pienaar 
58b56d1ec6SPeter Hawkins   nb::object getCapsule() {
59b56d1ec6SPeter Hawkins     return nb::steal<nb::object>(
6018cf1cd9SJacques Pienaar         mlirPythonFrozenRewritePatternSetToCapsule(get()));
6118cf1cd9SJacques Pienaar   }
6218cf1cd9SJacques Pienaar 
63b56d1ec6SPeter Hawkins   static nb::object createFromCapsule(nb::object capsule) {
6418cf1cd9SJacques Pienaar     MlirFrozenRewritePatternSet rawPm =
6518cf1cd9SJacques Pienaar         mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
6618cf1cd9SJacques Pienaar     if (rawPm.ptr == nullptr)
67b56d1ec6SPeter Hawkins       throw nb::python_error();
68b56d1ec6SPeter Hawkins     return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
6918cf1cd9SJacques Pienaar   }
7018cf1cd9SJacques Pienaar 
7118cf1cd9SJacques Pienaar private:
7218cf1cd9SJacques Pienaar   MlirFrozenRewritePatternSet set;
7318cf1cd9SJacques Pienaar };
7418cf1cd9SJacques Pienaar 
7518cf1cd9SJacques Pienaar } // namespace
7618cf1cd9SJacques Pienaar 
7718cf1cd9SJacques Pienaar /// Create the `mlir.rewrite` here.
78b56d1ec6SPeter Hawkins void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
7918cf1cd9SJacques Pienaar   //----------------------------------------------------------------------------
8018cf1cd9SJacques Pienaar   // Mapping of the top-level PassManager
8118cf1cd9SJacques Pienaar   //----------------------------------------------------------------------------
8218cf1cd9SJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83b56d1ec6SPeter Hawkins   nb::class_<PyPDLPatternModule>(m, "PDLModule")
84b56d1ec6SPeter Hawkins       .def(
85b56d1ec6SPeter Hawkins           "__init__",
86b56d1ec6SPeter Hawkins           [](PyPDLPatternModule &self, MlirModule module) {
87b56d1ec6SPeter Hawkins             new (&self)
88b56d1ec6SPeter Hawkins                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
89b56d1ec6SPeter Hawkins           },
9018cf1cd9SJacques Pienaar           "module"_a, "Create a PDL module from the given module.")
9118cf1cd9SJacques Pienaar       .def("freeze", [](PyPDLPatternModule &self) {
9218cf1cd9SJacques Pienaar         return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
9318cf1cd9SJacques Pienaar             mlirRewritePatternSetFromPDLPatternModule(self.get())));
9418cf1cd9SJacques Pienaar       });
95b56d1ec6SPeter Hawkins #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
96b56d1ec6SPeter Hawkins   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
97b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
9818cf1cd9SJacques Pienaar                    &PyFrozenRewritePatternSet::getCapsule)
9918cf1cd9SJacques Pienaar       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
10018cf1cd9SJacques Pienaar            &PyFrozenRewritePatternSet::createFromCapsule);
10118cf1cd9SJacques Pienaar   m.def(
10218cf1cd9SJacques Pienaar       "apply_patterns_and_fold_greedily",
10318cf1cd9SJacques Pienaar       [](MlirModule module, MlirFrozenRewritePatternSet set) {
10418cf1cd9SJacques Pienaar         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
10518cf1cd9SJacques Pienaar         if (mlirLogicalResultIsFailure(status))
10618cf1cd9SJacques Pienaar           // FIXME: Not sure this is the right error to throw here.
107b56d1ec6SPeter Hawkins           throw nb::value_error("pattern application failed to converge");
10818cf1cd9SJacques Pienaar       },
10918cf1cd9SJacques Pienaar       "module"_a, "set"_a,
11018cf1cd9SJacques Pienaar       "Applys the given patterns to the given module greedily while folding "
11118cf1cd9SJacques Pienaar       "results.");
11218cf1cd9SJacques Pienaar }
113