xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision 0cdf4915019a8ebc6570229cf140ad879dfaef56)
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <tuple>
10 
11 #include "PybindUtils.h"
12 
13 #include "Dialects.h"
14 #include "Globals.h"
15 #include "IRModule.h"
16 #include "Pass.h"
17 
18 namespace py = pybind11;
19 using namespace mlir;
20 using namespace mlir::python;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 PYBIND11_MODULE(_mlir, m) {
27   m.doc() = "MLIR Python Native Extension";
28 
29   py::class_<PyGlobals>(m, "_Globals")
30       .def_property("dialect_search_modules",
31                     &PyGlobals::getDialectSearchPrefixes,
32                     &PyGlobals::setDialectSearchPrefixes)
33       .def("append_dialect_search_prefix",
34            [](PyGlobals &self, std::string moduleName) {
35              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
36              self.clearImportCache();
37            })
38       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
39            "Testing hook for directly registering a dialect")
40       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
41            "Testing hook for directly registering an operation");
42 
43   // Aside from making the globals accessible to python, having python manage
44   // it is necessary to make sure it is destroyed (and releases its python
45   // resources) properly.
46   m.attr("globals") =
47       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
48 
49   // Registration decorators.
50   m.def(
51       "register_dialect",
52       [](py::object pyClass) {
53         std::string dialectNamespace =
54             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
55         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
56         return pyClass;
57       },
58       "Class decorator for registering a custom Dialect wrapper");
59   m.def(
60       "register_operation",
61       [](py::object dialectClass) -> py::cpp_function {
62         return py::cpp_function(
63             [dialectClass](py::object opClass) -> py::object {
64               std::string operationName =
65                   opClass.attr("OPERATION_NAME").cast<std::string>();
66               auto rawSubclass = PyOpView::createRawSubclass(opClass);
67               PyGlobals::get().registerOperationImpl(operationName, opClass,
68                                                      rawSubclass);
69 
70               // Dict-stuff the new opClass by name onto the dialect class.
71               py::object opClassName = opClass.attr("__name__");
72               dialectClass.attr(opClassName) = opClass;
73 
74               // Now create a special "Raw" subclass that passes through
75               // construction to the OpView parent (bypasses the intermediate
76               // child's __init__).
77               opClass.attr("_Raw") = rawSubclass;
78               return opClass;
79             });
80       },
81       "Class decorator for registering a custom Operation wrapper");
82 
83   // Define and populate IR submodule.
84   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
85   populateIRCore(irModule);
86   populateIRAffine(irModule);
87   populateIRAttributes(irModule);
88   populateIRTypes(irModule);
89 
90   // Define and populate PassManager submodule.
91   auto passModule =
92       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
93   populatePassManagerSubmodule(passModule);
94 
95   // Define and populate dialect submodules.
96   auto dialectsModule = m.def_submodule("dialects");
97   auto linalgModule = dialectsModule.def_submodule("linalg");
98   populateDialectLinalgSubmodule(linalgModule);
99   populateDialectSparseTensorSubmodule(
100       dialectsModule.def_submodule("sparse_tensor"), irModule);
101 }
102