xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision b56d1ec6cb8b5cb3ff46cba39a1049ecf3831afb)
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <nanobind/nanobind.h>
10 #include <nanobind/stl/string.h>
11 
12 #include "Globals.h"
13 #include "IRModule.h"
14 #include "NanobindUtils.h"
15 #include "Pass.h"
16 #include "Rewrite.h"
17 
18 namespace nb = nanobind;
19 using namespace mlir;
20 using namespace nb::literals;
21 using namespace mlir::python;
22 
23 // -----------------------------------------------------------------------------
24 // Module initialization.
25 // -----------------------------------------------------------------------------
26 
27 NB_MODULE(_mlir, m) {
28   m.doc() = "MLIR Python Native Extension";
29 
30   nb::class_<PyGlobals>(m, "_Globals")
31       .def_prop_rw("dialect_search_modules",
32                    &PyGlobals::getDialectSearchPrefixes,
33                    &PyGlobals::setDialectSearchPrefixes)
34       .def(
35           "append_dialect_search_prefix",
36           [](PyGlobals &self, std::string moduleName) {
37             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
38           },
39           "module_name"_a)
40       .def(
41           "_check_dialect_module_loaded",
42           [](PyGlobals &self, const std::string &dialectNamespace) {
43             return self.loadDialectModule(dialectNamespace);
44           },
45           "dialect_namespace"_a)
46       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
47            "dialect_namespace"_a, "dialect_class"_a,
48            "Testing hook for directly registering a dialect")
49       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
50            "operation_name"_a, "operation_class"_a, nb::kw_only(),
51            "replace"_a = false,
52            "Testing hook for directly registering an operation");
53 
54   // Aside from making the globals accessible to python, having python manage
55   // it is necessary to make sure it is destroyed (and releases its python
56   // resources) properly.
57   m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
58 
59   // Registration decorators.
60   m.def(
61       "register_dialect",
62       [](nb::type_object pyClass) {
63         std::string dialectNamespace =
64             nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
65         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
66         return pyClass;
67       },
68       "dialect_class"_a,
69       "Class decorator for registering a custom Dialect wrapper");
70   m.def(
71       "register_operation",
72       [](const nb::type_object &dialectClass, bool replace) -> nb::object {
73         return nb::cpp_function(
74             [dialectClass,
75              replace](nb::type_object opClass) -> nb::type_object {
76               std::string operationName =
77                   nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
78               PyGlobals::get().registerOperationImpl(operationName, opClass,
79                                                      replace);
80 
81               // Dict-stuff the new opClass by name onto the dialect class.
82               nb::object opClassName = opClass.attr("__name__");
83               dialectClass.attr(opClassName) = opClass;
84               return opClass;
85             });
86       },
87       "dialect_class"_a, nb::kw_only(), "replace"_a = false,
88       "Produce a class decorator for registering an Operation class as part of "
89       "a dialect");
90   m.def(
91       MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
92       [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
93         return nb::cpp_function([mlirTypeID, replace](
94                                     nb::callable typeCaster) -> nb::object {
95           PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
96           return typeCaster;
97         });
98       },
99       "typeid"_a, nb::kw_only(), "replace"_a = false,
100       "Register a type caster for casting MLIR types to custom user types.");
101   m.def(
102       MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
103       [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
104         return nb::cpp_function(
105             [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
106               PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
107                                                    replace);
108               return valueCaster;
109             });
110       },
111       "typeid"_a, nb::kw_only(), "replace"_a = false,
112       "Register a value caster for casting MLIR values to custom user values.");
113 
114   // Define and populate IR submodule.
115   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
116   populateIRCore(irModule);
117   populateIRAffine(irModule);
118   populateIRAttributes(irModule);
119   populateIRInterfaces(irModule);
120   populateIRTypes(irModule);
121 
122   auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
123   populateRewriteSubmodule(rewriteModule);
124 
125   // Define and populate PassManager submodule.
126   auto passModule =
127       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
128   populatePassManagerSubmodule(passModule);
129 }
130