xref: /llvm-project/mlir/lib/Bindings/Python/MainModule.cpp (revision 436c6c9c20cc522c92a923440a5fc509c342a7db)
1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===//
2722475a3SStella Laurenzo //
3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6722475a3SStella Laurenzo //
7722475a3SStella Laurenzo //===----------------------------------------------------------------------===//
8722475a3SStella Laurenzo 
9722475a3SStella Laurenzo #include <tuple>
10722475a3SStella Laurenzo 
11013b9322SStella Laurenzo #include "PybindUtils.h"
12722475a3SStella Laurenzo 
1313cb4317SMehdi Amini #include "ExecutionEngine.h"
14013b9322SStella Laurenzo #include "Globals.h"
15*436c6c9cSStella Laurenzo #include "IRModule.h"
16dc43f785SMehdi Amini #include "Pass.h"
17722475a3SStella Laurenzo 
1895b77f2eSStella Laurenzo namespace py = pybind11;
19722475a3SStella Laurenzo using namespace mlir;
2095b77f2eSStella Laurenzo using namespace mlir::python;
21722475a3SStella Laurenzo 
22013b9322SStella Laurenzo // -----------------------------------------------------------------------------
23013b9322SStella Laurenzo // PyGlobals
24013b9322SStella Laurenzo // -----------------------------------------------------------------------------
25013b9322SStella Laurenzo 
26013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr;
27013b9322SStella Laurenzo 
28013b9322SStella Laurenzo PyGlobals::PyGlobals() {
29013b9322SStella Laurenzo   assert(!instance && "PyGlobals already constructed");
30013b9322SStella Laurenzo   instance = this;
31013b9322SStella Laurenzo }
32013b9322SStella Laurenzo 
33013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; }
34013b9322SStella Laurenzo 
358260db75SStella Laurenzo void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
368260db75SStella Laurenzo   py::gil_scoped_acquire();
378260db75SStella Laurenzo   if (loadedDialectModulesCache.contains(dialectNamespace))
38013b9322SStella Laurenzo     return;
39013b9322SStella Laurenzo   // Since re-entrancy is possible, make a copy of the search prefixes.
40013b9322SStella Laurenzo   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
41013b9322SStella Laurenzo   py::object loaded;
42013b9322SStella Laurenzo   for (std::string moduleName : localSearchPrefixes) {
43013b9322SStella Laurenzo     moduleName.push_back('.');
448260db75SStella Laurenzo     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
45013b9322SStella Laurenzo 
46013b9322SStella Laurenzo     try {
478260db75SStella Laurenzo       py::gil_scoped_release();
48013b9322SStella Laurenzo       loaded = py::module::import(moduleName.c_str());
49013b9322SStella Laurenzo     } catch (py::error_already_set &e) {
50013b9322SStella Laurenzo       if (e.matches(PyExc_ModuleNotFoundError)) {
51013b9322SStella Laurenzo         continue;
52013b9322SStella Laurenzo       } else {
53013b9322SStella Laurenzo         throw;
54013b9322SStella Laurenzo       }
55013b9322SStella Laurenzo     }
56013b9322SStella Laurenzo     break;
57013b9322SStella Laurenzo   }
58013b9322SStella Laurenzo 
59013b9322SStella Laurenzo   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
60013b9322SStella Laurenzo   // may have occurred, which may do anything.
618260db75SStella Laurenzo   loadedDialectModulesCache.insert(dialectNamespace);
62013b9322SStella Laurenzo }
63013b9322SStella Laurenzo 
64013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
65013b9322SStella Laurenzo                                     py::object pyClass) {
668260db75SStella Laurenzo   py::gil_scoped_acquire();
67013b9322SStella Laurenzo   py::object &found = dialectClassMap[dialectNamespace];
68013b9322SStella Laurenzo   if (found) {
69013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
70013b9322SStella Laurenzo                                              dialectNamespace +
71013b9322SStella Laurenzo                                              "' is already registered.");
72013b9322SStella Laurenzo   }
73013b9322SStella Laurenzo   found = std::move(pyClass);
74013b9322SStella Laurenzo }
75013b9322SStella Laurenzo 
76013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName,
778260db75SStella Laurenzo                                       py::object pyClass,
788260db75SStella Laurenzo                                       py::object rawOpViewClass) {
798260db75SStella Laurenzo   py::gil_scoped_acquire();
80013b9322SStella Laurenzo   py::object &found = operationClassMap[operationName];
81013b9322SStella Laurenzo   if (found) {
82013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
83013b9322SStella Laurenzo                                              operationName +
84013b9322SStella Laurenzo                                              "' is already registered.");
85013b9322SStella Laurenzo   }
86013b9322SStella Laurenzo   found = std::move(pyClass);
878260db75SStella Laurenzo   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
88013b9322SStella Laurenzo }
89013b9322SStella Laurenzo 
90013b9322SStella Laurenzo llvm::Optional<py::object>
91013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
928260db75SStella Laurenzo   py::gil_scoped_acquire();
93013b9322SStella Laurenzo   loadDialectModule(dialectNamespace);
94013b9322SStella Laurenzo   // Fast match against the class map first (common case).
95013b9322SStella Laurenzo   const auto foundIt = dialectClassMap.find(dialectNamespace);
96013b9322SStella Laurenzo   if (foundIt != dialectClassMap.end()) {
97013b9322SStella Laurenzo     if (foundIt->second.is_none())
98013b9322SStella Laurenzo       return llvm::None;
99013b9322SStella Laurenzo     assert(foundIt->second && "py::object is defined");
100013b9322SStella Laurenzo     return foundIt->second;
101013b9322SStella Laurenzo   }
102013b9322SStella Laurenzo 
103013b9322SStella Laurenzo   // Not found and loading did not yield a registration. Negative cache.
104013b9322SStella Laurenzo   dialectClassMap[dialectNamespace] = py::none();
105013b9322SStella Laurenzo   return llvm::None;
106013b9322SStella Laurenzo }
107013b9322SStella Laurenzo 
1088260db75SStella Laurenzo llvm::Optional<pybind11::object>
1098260db75SStella Laurenzo PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
1108260db75SStella Laurenzo   {
1118260db75SStella Laurenzo     py::gil_scoped_acquire();
1128260db75SStella Laurenzo     auto foundIt = rawOpViewClassMapCache.find(operationName);
1138260db75SStella Laurenzo     if (foundIt != rawOpViewClassMapCache.end()) {
1148260db75SStella Laurenzo       if (foundIt->second.is_none())
1158260db75SStella Laurenzo         return llvm::None;
1168260db75SStella Laurenzo       assert(foundIt->second && "py::object is defined");
1178260db75SStella Laurenzo       return foundIt->second;
1188260db75SStella Laurenzo     }
1198260db75SStella Laurenzo   }
1208260db75SStella Laurenzo 
1218260db75SStella Laurenzo   // Not found. Load the dialect namespace.
1228260db75SStella Laurenzo   auto split = operationName.split('.');
1238260db75SStella Laurenzo   llvm::StringRef dialectNamespace = split.first;
1248260db75SStella Laurenzo   loadDialectModule(dialectNamespace);
1258260db75SStella Laurenzo 
1268260db75SStella Laurenzo   // Attempt to find from the canonical map and cache.
1278260db75SStella Laurenzo   {
1288260db75SStella Laurenzo     py::gil_scoped_acquire();
1298260db75SStella Laurenzo     auto foundIt = rawOpViewClassMap.find(operationName);
1308260db75SStella Laurenzo     if (foundIt != rawOpViewClassMap.end()) {
1318260db75SStella Laurenzo       if (foundIt->second.is_none())
1328260db75SStella Laurenzo         return llvm::None;
1338260db75SStella Laurenzo       assert(foundIt->second && "py::object is defined");
1348260db75SStella Laurenzo       // Positive cache.
1358260db75SStella Laurenzo       rawOpViewClassMapCache[operationName] = foundIt->second;
1368260db75SStella Laurenzo       return foundIt->second;
1378260db75SStella Laurenzo     } else {
1388260db75SStella Laurenzo       // Negative cache.
1398260db75SStella Laurenzo       rawOpViewClassMap[operationName] = py::none();
1408260db75SStella Laurenzo       return llvm::None;
1418260db75SStella Laurenzo     }
1428260db75SStella Laurenzo   }
1438260db75SStella Laurenzo }
1448260db75SStella Laurenzo 
1458260db75SStella Laurenzo void PyGlobals::clearImportCache() {
1468260db75SStella Laurenzo   py::gil_scoped_acquire();
1478260db75SStella Laurenzo   loadedDialectModulesCache.clear();
1488260db75SStella Laurenzo   rawOpViewClassMapCache.clear();
1498260db75SStella Laurenzo }
1508260db75SStella Laurenzo 
151013b9322SStella Laurenzo // -----------------------------------------------------------------------------
152013b9322SStella Laurenzo // Module initialization.
153013b9322SStella Laurenzo // -----------------------------------------------------------------------------
154013b9322SStella Laurenzo 
155722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) {
156722475a3SStella Laurenzo   m.doc() = "MLIR Python Native Extension";
157722475a3SStella Laurenzo 
158013b9322SStella Laurenzo   py::class_<PyGlobals>(m, "_Globals")
159013b9322SStella Laurenzo       .def_property("dialect_search_modules",
160013b9322SStella Laurenzo                     &PyGlobals::getDialectSearchPrefixes,
161013b9322SStella Laurenzo                     &PyGlobals::setDialectSearchPrefixes)
162013b9322SStella Laurenzo       .def("append_dialect_search_prefix",
163013b9322SStella Laurenzo            [](PyGlobals &self, std::string moduleName) {
164013b9322SStella Laurenzo              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
1658260db75SStella Laurenzo              self.clearImportCache();
166013b9322SStella Laurenzo            })
167013b9322SStella Laurenzo       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
168013b9322SStella Laurenzo            "Testing hook for directly registering a dialect")
169013b9322SStella Laurenzo       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
170013b9322SStella Laurenzo            "Testing hook for directly registering an operation");
171013b9322SStella Laurenzo 
172013b9322SStella Laurenzo   // Aside from making the globals accessible to python, having python manage
173013b9322SStella Laurenzo   // it is necessary to make sure it is destroyed (and releases its python
174013b9322SStella Laurenzo   // resources) properly.
175013b9322SStella Laurenzo   m.attr("globals") =
176013b9322SStella Laurenzo       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
177013b9322SStella Laurenzo 
178013b9322SStella Laurenzo   // Registration decorators.
179013b9322SStella Laurenzo   m.def(
180013b9322SStella Laurenzo       "register_dialect",
181013b9322SStella Laurenzo       [](py::object pyClass) {
182013b9322SStella Laurenzo         std::string dialectNamespace =
183013b9322SStella Laurenzo             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
184013b9322SStella Laurenzo         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
185013b9322SStella Laurenzo         return pyClass;
186013b9322SStella Laurenzo       },
187013b9322SStella Laurenzo       "Class decorator for registering a custom Dialect wrapper");
188013b9322SStella Laurenzo   m.def(
189013b9322SStella Laurenzo       "register_operation",
190013b9322SStella Laurenzo       [](py::object dialectClass) -> py::cpp_function {
191013b9322SStella Laurenzo         return py::cpp_function(
192013b9322SStella Laurenzo             [dialectClass](py::object opClass) -> py::object {
193013b9322SStella Laurenzo               std::string operationName =
194013b9322SStella Laurenzo                   opClass.attr("OPERATION_NAME").cast<std::string>();
195013b9322SStella Laurenzo               auto rawSubclass = PyOpView::createRawSubclass(opClass);
196013b9322SStella Laurenzo               PyGlobals::get().registerOperationImpl(operationName, opClass,
197013b9322SStella Laurenzo                                                      rawSubclass);
198013b9322SStella Laurenzo 
199013b9322SStella Laurenzo               // Dict-stuff the new opClass by name onto the dialect class.
200013b9322SStella Laurenzo               py::object opClassName = opClass.attr("__name__");
201013b9322SStella Laurenzo               dialectClass.attr(opClassName) = opClass;
202013b9322SStella Laurenzo 
203013b9322SStella Laurenzo               // Now create a special "Raw" subclass that passes through
204013b9322SStella Laurenzo               // construction to the OpView parent (bypasses the intermediate
205013b9322SStella Laurenzo               // child's __init__).
206013b9322SStella Laurenzo               opClass.attr("_Raw") = rawSubclass;
207013b9322SStella Laurenzo               return opClass;
208013b9322SStella Laurenzo             });
209013b9322SStella Laurenzo       },
210013b9322SStella Laurenzo       "Class decorator for registering a custom Operation wrapper");
211013b9322SStella Laurenzo 
212fcd2969dSzhanghb97   // Define and populate IR submodule.
213fcd2969dSzhanghb97   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
214*436c6c9cSStella Laurenzo   populateIRCore(irModule);
215*436c6c9cSStella Laurenzo   populateIRAffine(irModule);
216*436c6c9cSStella Laurenzo   populateIRAttributes(irModule);
217*436c6c9cSStella Laurenzo   populateIRTypes(irModule);
218dc43f785SMehdi Amini 
219dc43f785SMehdi Amini   // Define and populate PassManager submodule.
220dc43f785SMehdi Amini   auto passModule =
221dc43f785SMehdi Amini       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
222dc43f785SMehdi Amini   populatePassManagerSubmodule(passModule);
22313cb4317SMehdi Amini 
22413cb4317SMehdi Amini   // Define and populate ExecutionEngine submodule.
22513cb4317SMehdi Amini   auto executionEngineModule =
22613cb4317SMehdi Amini       m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
22713cb4317SMehdi Amini   populateExecutionEngineSubmodule(executionEngineModule);
228722475a3SStella Laurenzo }
229