1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Rewrite.h" 10 11 #include "IRModule.h" 12 #include "mlir-c/Rewrite.h" 13 #include "mlir/Bindings/Python/Nanobind.h" 14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 15 #include "mlir/Config/mlir-config.h" 16 17 namespace nb = nanobind; 18 using namespace mlir; 19 using namespace nb::literals; 20 using namespace mlir::python; 21 22 namespace { 23 24 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 25 /// Owning Wrapper around a PDLPatternModule. 26 class PyPDLPatternModule { 27 public: 28 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} 29 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept 30 : module(other.module) { 31 other.module.ptr = nullptr; 32 } 33 ~PyPDLPatternModule() { 34 if (module.ptr != nullptr) 35 mlirPDLPatternModuleDestroy(module); 36 } 37 MlirPDLPatternModule get() { return module; } 38 39 private: 40 MlirPDLPatternModule module; 41 }; 42 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 43 44 /// Owning Wrapper around a FrozenRewritePatternSet. 45 class PyFrozenRewritePatternSet { 46 public: 47 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} 48 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept 49 : set(other.set) { 50 other.set.ptr = nullptr; 51 } 52 ~PyFrozenRewritePatternSet() { 53 if (set.ptr != nullptr) 54 mlirFrozenRewritePatternSetDestroy(set); 55 } 56 MlirFrozenRewritePatternSet get() { return set; } 57 58 nb::object getCapsule() { 59 return nb::steal<nb::object>( 60 mlirPythonFrozenRewritePatternSetToCapsule(get())); 61 } 62 63 static nb::object createFromCapsule(nb::object capsule) { 64 MlirFrozenRewritePatternSet rawPm = 65 mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 66 if (rawPm.ptr == nullptr) 67 throw nb::python_error(); 68 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); 69 } 70 71 private: 72 MlirFrozenRewritePatternSet set; 73 }; 74 75 } // namespace 76 77 /// Create the `mlir.rewrite` here. 78 void mlir::python::populateRewriteSubmodule(nb::module_ &m) { 79 //---------------------------------------------------------------------------- 80 // Mapping of the top-level PassManager 81 //---------------------------------------------------------------------------- 82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 83 nb::class_<PyPDLPatternModule>(m, "PDLModule") 84 .def( 85 "__init__", 86 [](PyPDLPatternModule &self, MlirModule module) { 87 new (&self) 88 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); 89 }, 90 "module"_a, "Create a PDL module from the given module.") 91 .def("freeze", [](PyPDLPatternModule &self) { 92 return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( 93 mlirRewritePatternSetFromPDLPatternModule(self.get()))); 94 }); 95 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 96 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet") 97 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, 98 &PyFrozenRewritePatternSet::getCapsule) 99 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, 100 &PyFrozenRewritePatternSet::createFromCapsule); 101 m.def( 102 "apply_patterns_and_fold_greedily", 103 [](MlirModule module, MlirFrozenRewritePatternSet set) { 104 auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); 105 if (mlirLogicalResultIsFailure(status)) 106 // FIXME: Not sure this is the right error to throw here. 107 throw nb::value_error("pattern application failed to converge"); 108 }, 109 "module"_a, "set"_a, 110 "Applys the given patterns to the given module greedily while folding " 111 "results."); 112 } 113