1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Rewrite.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Bindings/Python/Interop.h" 13 #include "mlir-c/Rewrite.h" 14 #include "mlir/Config/mlir-config.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace py::literals; 19 using namespace mlir::python; 20 21 namespace { 22 23 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 24 /// Owning Wrapper around a PDLPatternModule. 25 class PyPDLPatternModule { 26 public: 27 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} 28 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept 29 : module(other.module) { 30 other.module.ptr = nullptr; 31 } 32 ~PyPDLPatternModule() { 33 if (module.ptr != nullptr) 34 mlirPDLPatternModuleDestroy(module); 35 } 36 MlirPDLPatternModule get() { return module; } 37 38 private: 39 MlirPDLPatternModule module; 40 }; 41 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 42 43 /// Owning Wrapper around a FrozenRewritePatternSet. 44 class PyFrozenRewritePatternSet { 45 public: 46 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} 47 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept 48 : set(other.set) { 49 other.set.ptr = nullptr; 50 } 51 ~PyFrozenRewritePatternSet() { 52 if (set.ptr != nullptr) 53 mlirFrozenRewritePatternSetDestroy(set); 54 } 55 MlirFrozenRewritePatternSet get() { return set; } 56 57 pybind11::object getCapsule() { 58 return py::reinterpret_steal<py::object>( 59 mlirPythonFrozenRewritePatternSetToCapsule(get())); 60 } 61 62 static pybind11::object createFromCapsule(pybind11::object capsule) { 63 MlirFrozenRewritePatternSet rawPm = 64 mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 65 if (rawPm.ptr == nullptr) 66 throw py::error_already_set(); 67 return py::cast(PyFrozenRewritePatternSet(rawPm), 68 py::return_value_policy::move); 69 } 70 71 private: 72 MlirFrozenRewritePatternSet set; 73 }; 74 75 } // namespace 76 77 /// Create the `mlir.rewrite` here. 78 void mlir::python::populateRewriteSubmodule(py::module &m) { 79 //---------------------------------------------------------------------------- 80 // Mapping of the top-level PassManager 81 //---------------------------------------------------------------------------- 82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 83 py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local()) 84 .def(py::init<>([](MlirModule module) { 85 return mlirPDLPatternModuleFromModule(module); 86 }), 87 "module"_a, "Create a PDL module from the given module.") 88 .def("freeze", [](PyPDLPatternModule &self) { 89 return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( 90 mlirRewritePatternSetFromPDLPatternModule(self.get()))); 91 }); 92 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg 93 py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet", 94 py::module_local()) 95 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 96 &PyFrozenRewritePatternSet::getCapsule) 97 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, 98 &PyFrozenRewritePatternSet::createFromCapsule); 99 m.def( 100 "apply_patterns_and_fold_greedily", 101 [](MlirModule module, MlirFrozenRewritePatternSet set) { 102 auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); 103 if (mlirLogicalResultIsFailure(status)) 104 // FIXME: Not sure this is the right error to throw here. 105 throw py::value_error("pattern application failed to converge"); 106 }, 107 "module"_a, "set"_a, 108 "Applys the given patterns to the given module greedily while folding " 109 "results."); 110 } 111