xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision 7c850867b9ef4427375da6d83c34d0b9c944fcb8)
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PybindUtils.h"
10 
11 #include "Globals.h"
12 #include "IRModule.h"
13 #include "Pass.h"
14 
15 namespace py = pybind11;
16 using namespace mlir;
17 using namespace py::literals;
18 using namespace mlir::python;
19 
20 // -----------------------------------------------------------------------------
21 // Module initialization.
22 // -----------------------------------------------------------------------------
23 
24 PYBIND11_MODULE(_mlir, m) {
25   m.doc() = "MLIR Python Native Extension";
26 
27   py::class_<PyGlobals>(m, "_Globals", py::module_local())
28       .def_property("dialect_search_modules",
29                     &PyGlobals::getDialectSearchPrefixes,
30                     &PyGlobals::setDialectSearchPrefixes)
31       .def(
32           "append_dialect_search_prefix",
33           [](PyGlobals &self, std::string moduleName) {
34             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
35           },
36           "module_name"_a)
37       .def(
38           "_check_dialect_module_loaded",
39           [](PyGlobals &self, const std::string &dialectNamespace) {
40             return self.loadDialectModule(dialectNamespace);
41           },
42           "dialect_namespace"_a)
43       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
44            "dialect_namespace"_a, "dialect_class"_a,
45            "Testing hook for directly registering a dialect")
46       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
47            "operation_name"_a, "operation_class"_a, py::kw_only(),
48            "replace"_a = false,
49            "Testing hook for directly registering an operation");
50 
51   // Aside from making the globals accessible to python, having python manage
52   // it is necessary to make sure it is destroyed (and releases its python
53   // resources) properly.
54   m.attr("globals") =
55       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
56 
57   // Registration decorators.
58   m.def(
59       "register_dialect",
60       [](py::object pyClass) {
61         std::string dialectNamespace =
62             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
63         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
64         return pyClass;
65       },
66       "dialect_class"_a,
67       "Class decorator for registering a custom Dialect wrapper");
68   m.def(
69       "register_operation",
70       [](const py::object &dialectClass, bool replace) -> py::cpp_function {
71         return py::cpp_function(
72             [dialectClass, replace](py::object opClass) -> py::object {
73               std::string operationName =
74                   opClass.attr("OPERATION_NAME").cast<std::string>();
75               PyGlobals::get().registerOperationImpl(operationName, opClass,
76                                                      replace);
77 
78               // Dict-stuff the new opClass by name onto the dialect class.
79               py::object opClassName = opClass.attr("__name__");
80               dialectClass.attr(opClassName) = opClass;
81               return opClass;
82             });
83       },
84       "dialect_class"_a, py::kw_only(), "replace"_a = false,
85       "Produce a class decorator for registering an Operation class as part of "
86       "a dialect");
87   m.def(
88       MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
89       [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
90         return py::cpp_function([mlirTypeID,
91                                  replace](py::object typeCaster) -> py::object {
92           PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
93           return typeCaster;
94         });
95       },
96       "typeid"_a, py::kw_only(), "replace"_a = false,
97       "Register a type caster for casting MLIR types to custom user types.");
98   m.def(
99       MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
100       [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
101         return py::cpp_function(
102             [mlirTypeID, replace](py::object valueCaster) -> py::object {
103               PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
104                                                    replace);
105               return valueCaster;
106             });
107       },
108       "typeid"_a, py::kw_only(), "replace"_a = false,
109       "Register a value caster for casting MLIR values to custom user values.");
110 
111   // Define and populate IR submodule.
112   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
113   populateIRCore(irModule);
114   populateIRAffine(irModule);
115   populateIRAttributes(irModule);
116   populateIRInterfaces(irModule);
117   populateIRTypes(irModule);
118 
119   // Define and populate PassManager submodule.
120   auto passModule =
121       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
122   populatePassManagerSubmodule(passModule);
123 }
124