xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision 43b9fa3ce0ddfa673158af1596c3aac613b258b3)
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <tuple>
10 
11 #include "PybindUtils.h"
12 
13 #include "DialectLinalg.h"
14 #include "ExecutionEngine.h"
15 #include "Globals.h"
16 #include "IRModule.h"
17 #include "Pass.h"
18 
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22 
23 // -----------------------------------------------------------------------------
24 // PyGlobals
25 // -----------------------------------------------------------------------------
26 
27 PyGlobals *PyGlobals::instance = nullptr;
28 
29 PyGlobals::PyGlobals() {
30   assert(!instance && "PyGlobals already constructed");
31   instance = this;
32 }
33 
34 PyGlobals::~PyGlobals() { instance = nullptr; }
35 
36 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
37   py::gil_scoped_acquire();
38   if (loadedDialectModulesCache.contains(dialectNamespace))
39     return;
40   // Since re-entrancy is possible, make a copy of the search prefixes.
41   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
42   py::object loaded;
43   for (std::string moduleName : localSearchPrefixes) {
44     moduleName.push_back('.');
45     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
46 
47     try {
48       py::gil_scoped_release();
49       loaded = py::module::import(moduleName.c_str());
50     } catch (py::error_already_set &e) {
51       if (e.matches(PyExc_ModuleNotFoundError)) {
52         continue;
53       } else {
54         throw;
55       }
56     }
57     break;
58   }
59 
60   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
61   // may have occurred, which may do anything.
62   loadedDialectModulesCache.insert(dialectNamespace);
63 }
64 
65 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
66                                     py::object pyClass) {
67   py::gil_scoped_acquire();
68   py::object &found = dialectClassMap[dialectNamespace];
69   if (found) {
70     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
71                                              dialectNamespace +
72                                              "' is already registered.");
73   }
74   found = std::move(pyClass);
75 }
76 
77 void PyGlobals::registerOperationImpl(const std::string &operationName,
78                                       py::object pyClass,
79                                       py::object rawOpViewClass) {
80   py::gil_scoped_acquire();
81   py::object &found = operationClassMap[operationName];
82   if (found) {
83     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
84                                              operationName +
85                                              "' is already registered.");
86   }
87   found = std::move(pyClass);
88   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
89 }
90 
91 llvm::Optional<py::object>
92 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
93   py::gil_scoped_acquire();
94   loadDialectModule(dialectNamespace);
95   // Fast match against the class map first (common case).
96   const auto foundIt = dialectClassMap.find(dialectNamespace);
97   if (foundIt != dialectClassMap.end()) {
98     if (foundIt->second.is_none())
99       return llvm::None;
100     assert(foundIt->second && "py::object is defined");
101     return foundIt->second;
102   }
103 
104   // Not found and loading did not yield a registration. Negative cache.
105   dialectClassMap[dialectNamespace] = py::none();
106   return llvm::None;
107 }
108 
109 llvm::Optional<pybind11::object>
110 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
111   {
112     py::gil_scoped_acquire();
113     auto foundIt = rawOpViewClassMapCache.find(operationName);
114     if (foundIt != rawOpViewClassMapCache.end()) {
115       if (foundIt->second.is_none())
116         return llvm::None;
117       assert(foundIt->second && "py::object is defined");
118       return foundIt->second;
119     }
120   }
121 
122   // Not found. Load the dialect namespace.
123   auto split = operationName.split('.');
124   llvm::StringRef dialectNamespace = split.first;
125   loadDialectModule(dialectNamespace);
126 
127   // Attempt to find from the canonical map and cache.
128   {
129     py::gil_scoped_acquire();
130     auto foundIt = rawOpViewClassMap.find(operationName);
131     if (foundIt != rawOpViewClassMap.end()) {
132       if (foundIt->second.is_none())
133         return llvm::None;
134       assert(foundIt->second && "py::object is defined");
135       // Positive cache.
136       rawOpViewClassMapCache[operationName] = foundIt->second;
137       return foundIt->second;
138     } else {
139       // Negative cache.
140       rawOpViewClassMap[operationName] = py::none();
141       return llvm::None;
142     }
143   }
144 }
145 
146 void PyGlobals::clearImportCache() {
147   py::gil_scoped_acquire();
148   loadedDialectModulesCache.clear();
149   rawOpViewClassMapCache.clear();
150 }
151 
152 // -----------------------------------------------------------------------------
153 // Module initialization.
154 // -----------------------------------------------------------------------------
155 
156 PYBIND11_MODULE(_mlir, m) {
157   m.doc() = "MLIR Python Native Extension";
158 
159   py::class_<PyGlobals>(m, "_Globals")
160       .def_property("dialect_search_modules",
161                     &PyGlobals::getDialectSearchPrefixes,
162                     &PyGlobals::setDialectSearchPrefixes)
163       .def("append_dialect_search_prefix",
164            [](PyGlobals &self, std::string moduleName) {
165              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
166              self.clearImportCache();
167            })
168       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
169            "Testing hook for directly registering a dialect")
170       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
171            "Testing hook for directly registering an operation");
172 
173   // Aside from making the globals accessible to python, having python manage
174   // it is necessary to make sure it is destroyed (and releases its python
175   // resources) properly.
176   m.attr("globals") =
177       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
178 
179   // Registration decorators.
180   m.def(
181       "register_dialect",
182       [](py::object pyClass) {
183         std::string dialectNamespace =
184             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
185         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
186         return pyClass;
187       },
188       "Class decorator for registering a custom Dialect wrapper");
189   m.def(
190       "register_operation",
191       [](py::object dialectClass) -> py::cpp_function {
192         return py::cpp_function(
193             [dialectClass](py::object opClass) -> py::object {
194               std::string operationName =
195                   opClass.attr("OPERATION_NAME").cast<std::string>();
196               auto rawSubclass = PyOpView::createRawSubclass(opClass);
197               PyGlobals::get().registerOperationImpl(operationName, opClass,
198                                                      rawSubclass);
199 
200               // Dict-stuff the new opClass by name onto the dialect class.
201               py::object opClassName = opClass.attr("__name__");
202               dialectClass.attr(opClassName) = opClass;
203 
204               // Now create a special "Raw" subclass that passes through
205               // construction to the OpView parent (bypasses the intermediate
206               // child's __init__).
207               opClass.attr("_Raw") = rawSubclass;
208               return opClass;
209             });
210       },
211       "Class decorator for registering a custom Operation wrapper");
212 
213   // Define and populate IR submodule.
214   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
215   populateIRCore(irModule);
216   populateIRAffine(irModule);
217   populateIRAttributes(irModule);
218   populateIRTypes(irModule);
219 
220   // Define and populate PassManager submodule.
221   auto passModule =
222       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
223   populatePassManagerSubmodule(passModule);
224 
225   // Define and populate ExecutionEngine submodule.
226   auto executionEngineModule =
227       m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
228   populateExecutionEngineSubmodule(executionEngineModule);
229 
230   // Define and populate Linalg submodule.
231   auto dialectsModule = m.def_submodule("dialects");
232   auto linalgModule = dialectsModule.def_submodule("linalg");
233   populateDialectLinalgSubmodule(linalgModule);
234 }
235