xref: /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp (revision b56d1ec6cb8b5cb3ff46cba39a1049ecf3831afb)
1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Rewrite.h"
10 
11 #include <nanobind/nanobind.h>
12 
13 #include "IRModule.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/Rewrite.h"
16 #include "mlir/Config/mlir-config.h"
17 
18 namespace nb = nanobind;
19 using namespace mlir;
20 using namespace nb::literals;
21 using namespace mlir::python;
22 
23 namespace {
24 
25 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
26 /// Owning Wrapper around a PDLPatternModule.
27 class PyPDLPatternModule {
28 public:
29   PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
30   PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
31       : module(other.module) {
32     other.module.ptr = nullptr;
33   }
34   ~PyPDLPatternModule() {
35     if (module.ptr != nullptr)
36       mlirPDLPatternModuleDestroy(module);
37   }
38   MlirPDLPatternModule get() { return module; }
39 
40 private:
41   MlirPDLPatternModule module;
42 };
43 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
44 
45 /// Owning Wrapper around a FrozenRewritePatternSet.
46 class PyFrozenRewritePatternSet {
47 public:
48   PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
49   PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
50       : set(other.set) {
51     other.set.ptr = nullptr;
52   }
53   ~PyFrozenRewritePatternSet() {
54     if (set.ptr != nullptr)
55       mlirFrozenRewritePatternSetDestroy(set);
56   }
57   MlirFrozenRewritePatternSet get() { return set; }
58 
59   nb::object getCapsule() {
60     return nb::steal<nb::object>(
61         mlirPythonFrozenRewritePatternSetToCapsule(get()));
62   }
63 
64   static nb::object createFromCapsule(nb::object capsule) {
65     MlirFrozenRewritePatternSet rawPm =
66         mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
67     if (rawPm.ptr == nullptr)
68       throw nb::python_error();
69     return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
70   }
71 
72 private:
73   MlirFrozenRewritePatternSet set;
74 };
75 
76 } // namespace
77 
78 /// Create the `mlir.rewrite` here.
79 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
80   //----------------------------------------------------------------------------
81   // Mapping of the top-level PassManager
82   //----------------------------------------------------------------------------
83 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
84   nb::class_<PyPDLPatternModule>(m, "PDLModule")
85       .def(
86           "__init__",
87           [](PyPDLPatternModule &self, MlirModule module) {
88             new (&self)
89                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
90           },
91           "module"_a, "Create a PDL module from the given module.")
92       .def("freeze", [](PyPDLPatternModule &self) {
93         return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
94             mlirRewritePatternSetFromPDLPatternModule(self.get())));
95       });
96 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
97   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
98       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
99                    &PyFrozenRewritePatternSet::getCapsule)
100       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
101            &PyFrozenRewritePatternSet::createFromCapsule);
102   m.def(
103       "apply_patterns_and_fold_greedily",
104       [](MlirModule module, MlirFrozenRewritePatternSet set) {
105         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
106         if (mlirLogicalResultIsFailure(status))
107           // FIXME: Not sure this is the right error to throw here.
108           throw nb::value_error("pattern application failed to converge");
109       },
110       "module"_a, "set"_a,
111       "Applys the given patterns to the given module greedily while folding "
112       "results.");
113 }
114