xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision f136c800b60dbfacdbb645e7e92acba52e2f279f)
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 
10 #include "Globals.h"
11 #include "IRModule.h"
12 #include "NanobindUtils.h"
13 #include "Pass.h"
14 #include "Rewrite.h"
15 #include "mlir/Bindings/Python/Nanobind.h"
16 
17 namespace nb = nanobind;
18 using namespace mlir;
19 using namespace nb::literals;
20 using namespace mlir::python;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 NB_MODULE(_mlir, m) {
27   m.doc() = "MLIR Python Native Extension";
28 
29   nb::class_<PyGlobals>(m, "_Globals")
30       .def_prop_rw("dialect_search_modules",
31                    &PyGlobals::getDialectSearchPrefixes,
32                    &PyGlobals::setDialectSearchPrefixes)
33       .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
34            "module_name"_a)
35       .def(
36           "_check_dialect_module_loaded",
37           [](PyGlobals &self, const std::string &dialectNamespace) {
38             return self.loadDialectModule(dialectNamespace);
39           },
40           "dialect_namespace"_a)
41       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
42            "dialect_namespace"_a, "dialect_class"_a,
43            "Testing hook for directly registering a dialect")
44       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
45            "operation_name"_a, "operation_class"_a, nb::kw_only(),
46            "replace"_a = false,
47            "Testing hook for directly registering an operation");
48 
49   // Aside from making the globals accessible to python, having python manage
50   // it is necessary to make sure it is destroyed (and releases its python
51   // resources) properly.
52   m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
53 
54   // Registration decorators.
55   m.def(
56       "register_dialect",
57       [](nb::type_object pyClass) {
58         std::string dialectNamespace =
59             nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
60         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
61         return pyClass;
62       },
63       "dialect_class"_a,
64       "Class decorator for registering a custom Dialect wrapper");
65   m.def(
66       "register_operation",
67       [](const nb::type_object &dialectClass, bool replace) -> nb::object {
68         return nb::cpp_function(
69             [dialectClass,
70              replace](nb::type_object opClass) -> nb::type_object {
71               std::string operationName =
72                   nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
73               PyGlobals::get().registerOperationImpl(operationName, opClass,
74                                                      replace);
75               // Dict-stuff the new opClass by name onto the dialect class.
76               nb::object opClassName = opClass.attr("__name__");
77               dialectClass.attr(opClassName) = opClass;
78               return opClass;
79             });
80       },
81       "dialect_class"_a, nb::kw_only(), "replace"_a = false,
82       "Produce a class decorator for registering an Operation class as part of "
83       "a dialect");
84   m.def(
85       MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
86       [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
87         return nb::cpp_function([mlirTypeID, replace](
88                                     nb::callable typeCaster) -> nb::object {
89           PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
90           return typeCaster;
91         });
92       },
93       "typeid"_a, nb::kw_only(), "replace"_a = false,
94       "Register a type caster for casting MLIR types to custom user types.");
95   m.def(
96       MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
97       [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
98         return nb::cpp_function(
99             [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
100               PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
101                                                    replace);
102               return valueCaster;
103             });
104       },
105       "typeid"_a, nb::kw_only(), "replace"_a = false,
106       "Register a value caster for casting MLIR values to custom user values.");
107 
108   // Define and populate IR submodule.
109   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
110   populateIRCore(irModule);
111   populateIRAffine(irModule);
112   populateIRAttributes(irModule);
113   populateIRInterfaces(irModule);
114   populateIRTypes(irModule);
115 
116   auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
117   populateRewriteSubmodule(rewriteModule);
118 
119   // Define and populate PassManager submodule.
120   auto passModule =
121       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
122   populatePassManagerSubmodule(passModule);
123 }
124