//===- MainModule.cpp - Main pybind module --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//


#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <string>

namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python;

// -----------------------------------------------------------------------------
// Module initialization.
24013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 25013b9322SStella Laurenzo 26b56d1ec6SPeter Hawkins NB_MODULE(_mlir, m) { 27722475a3SStella Laurenzo m.doc() = "MLIR Python Native Extension"; 28722475a3SStella Laurenzo 29b56d1ec6SPeter Hawkins nb::class_<PyGlobals>(m, "_Globals") 30b56d1ec6SPeter Hawkins .def_prop_rw("dialect_search_modules", 31013b9322SStella Laurenzo &PyGlobals::getDialectSearchPrefixes, 32013b9322SStella Laurenzo &PyGlobals::setDialectSearchPrefixes) 33*f136c800Svfdev .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, 34bfb1ba75Smax "module_name"_a) 355192e299SMaksim Levental .def( 365192e299SMaksim Levental "_check_dialect_module_loaded", 375192e299SMaksim Levental [](PyGlobals &self, const std::string &dialectNamespace) { 385192e299SMaksim Levental return self.loadDialectModule(dialectNamespace); 395192e299SMaksim Levental }, 405192e299SMaksim Levental "dialect_namespace"_a) 41013b9322SStella Laurenzo .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 42bfb1ba75Smax "dialect_namespace"_a, "dialect_class"_a, 43013b9322SStella Laurenzo "Testing hook for directly registering a dialect") 44013b9322SStella Laurenzo .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 45b56d1ec6SPeter Hawkins "operation_name"_a, "operation_class"_a, nb::kw_only(), 467c850867SMaksim Levental "replace"_a = false, 47013b9322SStella Laurenzo "Testing hook for directly registering an operation"); 48013b9322SStella Laurenzo 49013b9322SStella Laurenzo // Aside from making the globals accessible to python, having python manage 50013b9322SStella Laurenzo // it is necessary to make sure it is destroyed (and releases its python 51013b9322SStella Laurenzo // resources) properly. 52b56d1ec6SPeter Hawkins m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); 53013b9322SStella Laurenzo 54013b9322SStella Laurenzo // Registration decorators. 
55013b9322SStella Laurenzo m.def( 56013b9322SStella Laurenzo "register_dialect", 57b56d1ec6SPeter Hawkins [](nb::type_object pyClass) { 58013b9322SStella Laurenzo std::string dialectNamespace = 59b56d1ec6SPeter Hawkins nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE")); 60013b9322SStella Laurenzo PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 61013b9322SStella Laurenzo return pyClass; 62013b9322SStella Laurenzo }, 63bfb1ba75Smax "dialect_class"_a, 64013b9322SStella Laurenzo "Class decorator for registering a custom Dialect wrapper"); 65013b9322SStella Laurenzo m.def( 66013b9322SStella Laurenzo "register_operation", 67b56d1ec6SPeter Hawkins [](const nb::type_object &dialectClass, bool replace) -> nb::object { 68b56d1ec6SPeter Hawkins return nb::cpp_function( 69b56d1ec6SPeter Hawkins [dialectClass, 70b56d1ec6SPeter Hawkins replace](nb::type_object opClass) -> nb::type_object { 71013b9322SStella Laurenzo std::string operationName = 72b56d1ec6SPeter Hawkins nanobind::cast<std::string>(opClass.attr("OPERATION_NAME")); 73a2288a89SMaksim Levental PyGlobals::get().registerOperationImpl(operationName, opClass, 74a2288a89SMaksim Levental replace); 75013b9322SStella Laurenzo // Dict-stuff the new opClass by name onto the dialect class. 
76b56d1ec6SPeter Hawkins nb::object opClassName = opClass.attr("__name__"); 77013b9322SStella Laurenzo dialectClass.attr(opClassName) = opClass; 78013b9322SStella Laurenzo return opClass; 79013b9322SStella Laurenzo }); 80013b9322SStella Laurenzo }, 81b56d1ec6SPeter Hawkins "dialect_class"_a, nb::kw_only(), "replace"_a = false, 82a6e7d024SStella Laurenzo "Produce a class decorator for registering an Operation class as part of " 83a6e7d024SStella Laurenzo "a dialect"); 84bfb1ba75Smax m.def( 85bfb1ba75Smax MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, 86b56d1ec6SPeter Hawkins [](MlirTypeID mlirTypeID, bool replace) -> nb::object { 87b56d1ec6SPeter Hawkins return nb::cpp_function([mlirTypeID, replace]( 88b56d1ec6SPeter Hawkins nb::callable typeCaster) -> nb::object { 897c850867SMaksim Levental PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); 907c850867SMaksim Levental return typeCaster; 917c850867SMaksim Levental }); 92bfb1ba75Smax }, 93b56d1ec6SPeter Hawkins "typeid"_a, nb::kw_only(), "replace"_a = false, 94bfb1ba75Smax "Register a type caster for casting MLIR types to custom user types."); 957c850867SMaksim Levental m.def( 967c850867SMaksim Levental MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, 97b56d1ec6SPeter Hawkins [](MlirTypeID mlirTypeID, bool replace) -> nb::object { 98b56d1ec6SPeter Hawkins return nb::cpp_function( 99b56d1ec6SPeter Hawkins [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { 1007c850867SMaksim Levental PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, 1017c850867SMaksim Levental replace); 1027c850867SMaksim Levental return valueCaster; 1037c850867SMaksim Levental }); 1047c850867SMaksim Levental }, 105b56d1ec6SPeter Hawkins "typeid"_a, nb::kw_only(), "replace"_a = false, 1067c850867SMaksim Levental "Register a value caster for casting MLIR values to custom user values."); 107013b9322SStella Laurenzo 108fcd2969dSzhanghb97 // Define and populate IR submodule. 
109fcd2969dSzhanghb97 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 110436c6c9cSStella Laurenzo populateIRCore(irModule); 111436c6c9cSStella Laurenzo populateIRAffine(irModule); 112436c6c9cSStella Laurenzo populateIRAttributes(irModule); 11314c92070SAlex Zinenko populateIRInterfaces(irModule); 114436c6c9cSStella Laurenzo populateIRTypes(irModule); 115dc43f785SMehdi Amini 11618cf1cd9SJacques Pienaar auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings"); 11718cf1cd9SJacques Pienaar populateRewriteSubmodule(rewriteModule); 11818cf1cd9SJacques Pienaar 119dc43f785SMehdi Amini // Define and populate PassManager submodule. 120dc43f785SMehdi Amini auto passModule = 121dc43f785SMehdi Amini m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 122dc43f785SMehdi Amini populatePassManagerSubmodule(passModule); 123722475a3SStella Laurenzo } 124