xref: /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp (revision 18cf1cd92b554ba0b870c6a2223ea4d0d3c6dd21)
1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Rewrite.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Rewrite.h"
14 #include "mlir/Config/mlir-config.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace py::literals;
19 using namespace mlir::python;
20 
21 namespace {
22 
23 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
24 /// Owning Wrapper around a PDLPatternModule.
25 class PyPDLPatternModule {
26 public:
27   PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
28   PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
29       : module(other.module) {
30     other.module.ptr = nullptr;
31   }
32   ~PyPDLPatternModule() {
33     if (module.ptr != nullptr)
34       mlirPDLPatternModuleDestroy(module);
35   }
36   MlirPDLPatternModule get() { return module; }
37 
38 private:
39   MlirPDLPatternModule module;
40 };
41 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
42 
43 /// Owning Wrapper around a FrozenRewritePatternSet.
44 class PyFrozenRewritePatternSet {
45 public:
46   PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47   PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
48       : set(other.set) {
49     other.set.ptr = nullptr;
50   }
51   ~PyFrozenRewritePatternSet() {
52     if (set.ptr != nullptr)
53       mlirFrozenRewritePatternSetDestroy(set);
54   }
55   MlirFrozenRewritePatternSet get() { return set; }
56 
57   pybind11::object getCapsule() {
58     return py::reinterpret_steal<py::object>(
59         mlirPythonFrozenRewritePatternSetToCapsule(get()));
60   }
61 
62   static pybind11::object createFromCapsule(pybind11::object capsule) {
63     MlirFrozenRewritePatternSet rawPm =
64         mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
65     if (rawPm.ptr == nullptr)
66       throw py::error_already_set();
67     return py::cast(PyFrozenRewritePatternSet(rawPm),
68                     py::return_value_policy::move);
69   }
70 
71 private:
72   MlirFrozenRewritePatternSet set;
73 };
74 
75 } // namespace
76 
77 /// Create the `mlir.rewrite` here.
78 void mlir::python::populateRewriteSubmodule(py::module &m) {
79   //----------------------------------------------------------------------------
80   // Mapping of the top-level PassManager
81   //----------------------------------------------------------------------------
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83   py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
84       .def(py::init<>([](MlirModule module) {
85              return mlirPDLPatternModuleFromModule(module);
86            }),
87            "module"_a, "Create a PDL module from the given module.")
88       .def("freeze", [](PyPDLPatternModule &self) {
89         return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
90             mlirRewritePatternSetFromPDLPatternModule(self.get())));
91       });
92 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
93   py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
94                                         py::module_local())
95       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
96                              &PyFrozenRewritePatternSet::getCapsule)
97       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
98            &PyFrozenRewritePatternSet::createFromCapsule);
99   m.def(
100       "apply_patterns_and_fold_greedily",
101       [](MlirModule module, MlirFrozenRewritePatternSet set) {
102         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
103         if (mlirLogicalResultIsFailure(status))
104           // FIXME: Not sure this is the right error to throw here.
105           throw py::value_error("pattern application failed to converge");
106       },
107       "module"_a, "set"_a,
108       "Applys the given patterns to the given module greedily while folding "
109       "results.");
110 }
111