1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 10 #include "Globals.h" 11 #include "IRModule.h" 12 #include "NanobindUtils.h" 13 #include "Pass.h" 14 #include "Rewrite.h" 15 #include "mlir/Bindings/Python/Nanobind.h" 16 17 namespace nb = nanobind; 18 using namespace mlir; 19 using namespace nb::literals; 20 using namespace mlir::python; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 NB_MODULE(_mlir, m) { 27 m.doc() = "MLIR Python Native Extension"; 28 29 nb::class_<PyGlobals>(m, "_Globals") 30 .def_prop_rw("dialect_search_modules", 31 &PyGlobals::getDialectSearchPrefixes, 32 &PyGlobals::setDialectSearchPrefixes) 33 .def( 34 "append_dialect_search_prefix", 35 [](PyGlobals &self, std::string moduleName) { 36 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 37 }, 38 "module_name"_a) 39 .def( 40 "_check_dialect_module_loaded", 41 [](PyGlobals &self, const std::string &dialectNamespace) { 42 return self.loadDialectModule(dialectNamespace); 43 }, 44 "dialect_namespace"_a) 45 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 46 "dialect_namespace"_a, "dialect_class"_a, 47 "Testing hook for directly registering a dialect") 48 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 49 "operation_name"_a, "operation_class"_a, nb::kw_only(), 50 "replace"_a = false, 51 "Testing hook for directly registering an operation"); 52 53 // Aside from making the globals accessible to python, having python manage 54 // it is necessary to make sure it is destroyed (and releases its python 55 // resources) properly. 56 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); 57 58 // Registration decorators. 59 m.def( 60 "register_dialect", 61 [](nb::type_object pyClass) { 62 std::string dialectNamespace = 63 nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE")); 64 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 65 return pyClass; 66 }, 67 "dialect_class"_a, 68 "Class decorator for registering a custom Dialect wrapper"); 69 m.def( 70 "register_operation", 71 [](const nb::type_object &dialectClass, bool replace) -> nb::object { 72 return nb::cpp_function( 73 [dialectClass, 74 replace](nb::type_object opClass) -> nb::type_object { 75 std::string operationName = 76 nanobind::cast<std::string>(opClass.attr("OPERATION_NAME")); 77 PyGlobals::get().registerOperationImpl(operationName, opClass, 78 replace); 79 80 // Dict-stuff the new opClass by name onto the dialect class. 81 nb::object opClassName = opClass.attr("__name__"); 82 dialectClass.attr(opClassName) = opClass; 83 return opClass; 84 }); 85 }, 86 "dialect_class"_a, nb::kw_only(), "replace"_a = false, 87 "Produce a class decorator for registering an Operation class as part of " 88 "a dialect"); 89 m.def( 90 MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, 91 [](MlirTypeID mlirTypeID, bool replace) -> nb::object { 92 return nb::cpp_function([mlirTypeID, replace]( 93 nb::callable typeCaster) -> nb::object { 94 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); 95 return typeCaster; 96 }); 97 }, 98 "typeid"_a, nb::kw_only(), "replace"_a = false, 99 "Register a type caster for casting MLIR types to custom user types."); 100 m.def( 101 MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, 102 [](MlirTypeID mlirTypeID, bool replace) -> nb::object { 103 return nb::cpp_function( 104 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { 105 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, 106 replace); 107 return valueCaster; 108 }); 109 }, 110 "typeid"_a, nb::kw_only(), "replace"_a = false, 111 "Register a value caster for casting MLIR values to custom user values."); 112 113 // Define and populate IR submodule. 114 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 115 populateIRCore(irModule); 116 populateIRAffine(irModule); 117 populateIRAttributes(irModule); 118 populateIRInterfaces(irModule); 119 populateIRTypes(irModule); 120 121 auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings"); 122 populateRewriteSubmodule(rewriteModule); 123 124 // Define and populate PassManager submodule. 125 auto passModule = 126 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 127 populatePassManagerSubmodule(passModule); 128 } 129