xref: /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Rewrite.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Rewrite.h"
13 #include "mlir/Bindings/Python/Nanobind.h"
14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15 #include "mlir/Config/mlir-config.h"
16 
17 namespace nb = nanobind;
18 using namespace mlir;
19 using namespace nb::literals;
20 using namespace mlir::python;
21 
22 namespace {
23 
24 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
25 /// Owning Wrapper around a PDLPatternModule.
26 class PyPDLPatternModule {
27 public:
28   PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
29   PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
30       : module(other.module) {
31     other.module.ptr = nullptr;
32   }
33   ~PyPDLPatternModule() {
34     if (module.ptr != nullptr)
35       mlirPDLPatternModuleDestroy(module);
36   }
37   MlirPDLPatternModule get() { return module; }
38 
39 private:
40   MlirPDLPatternModule module;
41 };
42 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
43 
44 /// Owning Wrapper around a FrozenRewritePatternSet.
45 class PyFrozenRewritePatternSet {
46 public:
47   PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
48   PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
49       : set(other.set) {
50     other.set.ptr = nullptr;
51   }
52   ~PyFrozenRewritePatternSet() {
53     if (set.ptr != nullptr)
54       mlirFrozenRewritePatternSetDestroy(set);
55   }
56   MlirFrozenRewritePatternSet get() { return set; }
57 
58   nb::object getCapsule() {
59     return nb::steal<nb::object>(
60         mlirPythonFrozenRewritePatternSetToCapsule(get()));
61   }
62 
63   static nb::object createFromCapsule(nb::object capsule) {
64     MlirFrozenRewritePatternSet rawPm =
65         mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
66     if (rawPm.ptr == nullptr)
67       throw nb::python_error();
68     return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
69   }
70 
71 private:
72   MlirFrozenRewritePatternSet set;
73 };
74 
75 } // namespace
76 
77 /// Create the `mlir.rewrite` here.
78 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
79   //----------------------------------------------------------------------------
80   // Mapping of the top-level PassManager
81   //----------------------------------------------------------------------------
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83   nb::class_<PyPDLPatternModule>(m, "PDLModule")
84       .def(
85           "__init__",
86           [](PyPDLPatternModule &self, MlirModule module) {
87             new (&self)
88                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
89           },
90           "module"_a, "Create a PDL module from the given module.")
91       .def("freeze", [](PyPDLPatternModule &self) {
92         return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
93             mlirRewritePatternSetFromPDLPatternModule(self.get())));
94       });
95 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
96   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
97       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
98                    &PyFrozenRewritePatternSet::getCapsule)
99       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100            &PyFrozenRewritePatternSet::createFromCapsule);
101   m.def(
102       "apply_patterns_and_fold_greedily",
103       [](MlirModule module, MlirFrozenRewritePatternSet set) {
104         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105         if (mlirLogicalResultIsFailure(status))
106           // FIXME: Not sure this is the right error to throw here.
107           throw nb::value_error("pattern application failed to converge");
108       },
109       "module"_a, "set"_a,
110       "Applys the given patterns to the given module greedily while folding "
111       "results.");
112 }
113