1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Rewrite.h" 10 11 #include <nanobind/nanobind.h> 12 13 #include "IRModule.h" 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/Rewrite.h" 16 #include "mlir/Config/mlir-config.h" 17 18 namespace nb = nanobind; 19 using namespace mlir; 20 using namespace nb::literals; 21 using namespace mlir::python; 22 23 namespace { 24 25 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 26 /// Owning Wrapper around a PDLPatternModule. 27 class PyPDLPatternModule { 28 public: 29 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} 30 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept 31 : module(other.module) { 32 other.module.ptr = nullptr; 33 } 34 ~PyPDLPatternModule() { 35 if (module.ptr != nullptr) 36 mlirPDLPatternModuleDestroy(module); 37 } 38 MlirPDLPatternModule get() { return module; } 39 40 private: 41 MlirPDLPatternModule module; 42 }; 43 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 44 45 /// Owning Wrapper around a FrozenRewritePatternSet. 46 class PyFrozenRewritePatternSet { 47 public: 48 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} 49 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept 50 : set(other.set) { 51 other.set.ptr = nullptr; 52 } 53 ~PyFrozenRewritePatternSet() { 54 if (set.ptr != nullptr) 55 mlirFrozenRewritePatternSetDestroy(set); 56 } 57 MlirFrozenRewritePatternSet get() { return set; } 58 59 nb::object getCapsule() { 60 return nb::steal<nb::object>( 61 mlirPythonFrozenRewritePatternSetToCapsule(get())); 62 } 63 64 static nb::object createFromCapsule(nb::object capsule) { 65 MlirFrozenRewritePatternSet rawPm = 66 mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 67 if (rawPm.ptr == nullptr) 68 throw nb::python_error(); 69 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); 70 } 71 72 private: 73 MlirFrozenRewritePatternSet set; 74 }; 75 76 } // namespace 77 78 /// Create the `mlir.rewrite` here. 79 void mlir::python::populateRewriteSubmodule(nb::module_ &m) { 80 //---------------------------------------------------------------------------- 81 // Mapping of the top-level PassManager 82 //---------------------------------------------------------------------------- 83 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 84 nb::class_<PyPDLPatternModule>(m, "PDLModule") 85 .def( 86 "__init__", 87 [](PyPDLPatternModule &self, MlirModule module) { 88 new (&self) 89 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); 90 }, 91 "module"_a, "Create a PDL module from the given module.") 92 .def("freeze", [](PyPDLPatternModule &self) { 93 return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( 94 mlirRewritePatternSetFromPDLPatternModule(self.get()))); 95 }); 96 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 97 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet") 98 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, 99 &PyFrozenRewritePatternSet::getCapsule) 100 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, 101 &PyFrozenRewritePatternSet::createFromCapsule); 102 m.def( 103 "apply_patterns_and_fold_greedily", 104 [](MlirModule module, MlirFrozenRewritePatternSet set) { 105 auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); 106 if (mlirLogicalResultIsFailure(status)) 107 // FIXME: Not sure this is the right error to throw here. 108 throw nb::value_error("pattern application failed to converge"); 109 }, 110 "module"_a, "set"_a, 111 "Applys the given patterns to the given module greedily while folding " 112 "results."); 113 } 114