1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===// 2722475a3SStella Laurenzo // 3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6722475a3SStella Laurenzo // 7722475a3SStella Laurenzo //===----------------------------------------------------------------------===// 8722475a3SStella Laurenzo 9722475a3SStella Laurenzo #include <tuple> 10722475a3SStella Laurenzo 11013b9322SStella Laurenzo #include "PybindUtils.h" 12722475a3SStella Laurenzo 1313cb4317SMehdi Amini #include "ExecutionEngine.h" 14013b9322SStella Laurenzo #include "Globals.h" 15*436c6c9cSStella Laurenzo #include "IRModule.h" 16dc43f785SMehdi Amini #include "Pass.h" 17722475a3SStella Laurenzo 1895b77f2eSStella Laurenzo namespace py = pybind11; 19722475a3SStella Laurenzo using namespace mlir; 2095b77f2eSStella Laurenzo using namespace mlir::python; 21722475a3SStella Laurenzo 22013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 23013b9322SStella Laurenzo // PyGlobals 24013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 25013b9322SStella Laurenzo 26013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr; 27013b9322SStella Laurenzo 28013b9322SStella Laurenzo PyGlobals::PyGlobals() { 29013b9322SStella Laurenzo assert(!instance && "PyGlobals already constructed"); 30013b9322SStella Laurenzo instance = this; 31013b9322SStella Laurenzo } 32013b9322SStella Laurenzo 33013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; } 34013b9322SStella Laurenzo 358260db75SStella Laurenzo void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 368260db75SStella Laurenzo py::gil_scoped_acquire(); 378260db75SStella Laurenzo if (loadedDialectModulesCache.contains(dialectNamespace)) 38013b9322SStella Laurenzo return; 39013b9322SStella Laurenzo // Since re-entrancy is possible, make a copy of the search prefixes. 40013b9322SStella Laurenzo std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 41013b9322SStella Laurenzo py::object loaded; 42013b9322SStella Laurenzo for (std::string moduleName : localSearchPrefixes) { 43013b9322SStella Laurenzo moduleName.push_back('.'); 448260db75SStella Laurenzo moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 45013b9322SStella Laurenzo 46013b9322SStella Laurenzo try { 478260db75SStella Laurenzo py::gil_scoped_release(); 48013b9322SStella Laurenzo loaded = py::module::import(moduleName.c_str()); 49013b9322SStella Laurenzo } catch (py::error_already_set &e) { 50013b9322SStella Laurenzo if (e.matches(PyExc_ModuleNotFoundError)) { 51013b9322SStella Laurenzo continue; 52013b9322SStella Laurenzo } else { 53013b9322SStella Laurenzo throw; 54013b9322SStella Laurenzo } 55013b9322SStella Laurenzo } 56013b9322SStella Laurenzo break; 57013b9322SStella Laurenzo } 58013b9322SStella Laurenzo 59013b9322SStella Laurenzo // Note: Iterator cannot be shared from prior to loading, since re-entrancy 60013b9322SStella Laurenzo // may have occurred, which may do anything. 618260db75SStella Laurenzo loadedDialectModulesCache.insert(dialectNamespace); 62013b9322SStella Laurenzo } 63013b9322SStella Laurenzo 64013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 65013b9322SStella Laurenzo py::object pyClass) { 668260db75SStella Laurenzo py::gil_scoped_acquire(); 67013b9322SStella Laurenzo py::object &found = dialectClassMap[dialectNamespace]; 68013b9322SStella Laurenzo if (found) { 69013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 70013b9322SStella Laurenzo dialectNamespace + 71013b9322SStella Laurenzo "' is already registered."); 72013b9322SStella Laurenzo } 73013b9322SStella Laurenzo found = std::move(pyClass); 74013b9322SStella Laurenzo } 75013b9322SStella Laurenzo 76013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName, 778260db75SStella Laurenzo py::object pyClass, 788260db75SStella Laurenzo py::object rawOpViewClass) { 798260db75SStella Laurenzo py::gil_scoped_acquire(); 80013b9322SStella Laurenzo py::object &found = operationClassMap[operationName]; 81013b9322SStella Laurenzo if (found) { 82013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 83013b9322SStella Laurenzo operationName + 84013b9322SStella Laurenzo "' is already registered."); 85013b9322SStella Laurenzo } 86013b9322SStella Laurenzo found = std::move(pyClass); 878260db75SStella Laurenzo rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 88013b9322SStella Laurenzo } 89013b9322SStella Laurenzo 90013b9322SStella Laurenzo llvm::Optional<py::object> 91013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 928260db75SStella Laurenzo py::gil_scoped_acquire(); 93013b9322SStella Laurenzo loadDialectModule(dialectNamespace); 94013b9322SStella Laurenzo // Fast match against the class map first (common case). 95013b9322SStella Laurenzo const auto foundIt = dialectClassMap.find(dialectNamespace); 96013b9322SStella Laurenzo if (foundIt != dialectClassMap.end()) { 97013b9322SStella Laurenzo if (foundIt->second.is_none()) 98013b9322SStella Laurenzo return llvm::None; 99013b9322SStella Laurenzo assert(foundIt->second && "py::object is defined"); 100013b9322SStella Laurenzo return foundIt->second; 101013b9322SStella Laurenzo } 102013b9322SStella Laurenzo 103013b9322SStella Laurenzo // Not found and loading did not yield a registration. Negative cache. 104013b9322SStella Laurenzo dialectClassMap[dialectNamespace] = py::none(); 105013b9322SStella Laurenzo return llvm::None; 106013b9322SStella Laurenzo } 107013b9322SStella Laurenzo 1088260db75SStella Laurenzo llvm::Optional<pybind11::object> 1098260db75SStella Laurenzo PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 1108260db75SStella Laurenzo { 1118260db75SStella Laurenzo py::gil_scoped_acquire(); 1128260db75SStella Laurenzo auto foundIt = rawOpViewClassMapCache.find(operationName); 1138260db75SStella Laurenzo if (foundIt != rawOpViewClassMapCache.end()) { 1148260db75SStella Laurenzo if (foundIt->second.is_none()) 1158260db75SStella Laurenzo return llvm::None; 1168260db75SStella Laurenzo assert(foundIt->second && "py::object is defined"); 1178260db75SStella Laurenzo return foundIt->second; 1188260db75SStella Laurenzo } 1198260db75SStella Laurenzo } 1208260db75SStella Laurenzo 1218260db75SStella Laurenzo // Not found. Load the dialect namespace. 1228260db75SStella Laurenzo auto split = operationName.split('.'); 1238260db75SStella Laurenzo llvm::StringRef dialectNamespace = split.first; 1248260db75SStella Laurenzo loadDialectModule(dialectNamespace); 1258260db75SStella Laurenzo 1268260db75SStella Laurenzo // Attempt to find from the canonical map and cache. 1278260db75SStella Laurenzo { 1288260db75SStella Laurenzo py::gil_scoped_acquire(); 1298260db75SStella Laurenzo auto foundIt = rawOpViewClassMap.find(operationName); 1308260db75SStella Laurenzo if (foundIt != rawOpViewClassMap.end()) { 1318260db75SStella Laurenzo if (foundIt->second.is_none()) 1328260db75SStella Laurenzo return llvm::None; 1338260db75SStella Laurenzo assert(foundIt->second && "py::object is defined"); 1348260db75SStella Laurenzo // Positive cache. 1358260db75SStella Laurenzo rawOpViewClassMapCache[operationName] = foundIt->second; 1368260db75SStella Laurenzo return foundIt->second; 1378260db75SStella Laurenzo } else { 1388260db75SStella Laurenzo // Negative cache. 1398260db75SStella Laurenzo rawOpViewClassMap[operationName] = py::none(); 1408260db75SStella Laurenzo return llvm::None; 1418260db75SStella Laurenzo } 1428260db75SStella Laurenzo } 1438260db75SStella Laurenzo } 1448260db75SStella Laurenzo 1458260db75SStella Laurenzo void PyGlobals::clearImportCache() { 1468260db75SStella Laurenzo py::gil_scoped_acquire(); 1478260db75SStella Laurenzo loadedDialectModulesCache.clear(); 1488260db75SStella Laurenzo rawOpViewClassMapCache.clear(); 1498260db75SStella Laurenzo } 1508260db75SStella Laurenzo 151013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 152013b9322SStella Laurenzo // Module initialization. 153013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 154013b9322SStella Laurenzo 155722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) { 156722475a3SStella Laurenzo m.doc() = "MLIR Python Native Extension"; 157722475a3SStella Laurenzo 158013b9322SStella Laurenzo py::class_<PyGlobals>(m, "_Globals") 159013b9322SStella Laurenzo .def_property("dialect_search_modules", 160013b9322SStella Laurenzo &PyGlobals::getDialectSearchPrefixes, 161013b9322SStella Laurenzo &PyGlobals::setDialectSearchPrefixes) 162013b9322SStella Laurenzo .def("append_dialect_search_prefix", 163013b9322SStella Laurenzo [](PyGlobals &self, std::string moduleName) { 164013b9322SStella Laurenzo self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 1658260db75SStella Laurenzo self.clearImportCache(); 166013b9322SStella Laurenzo }) 167013b9322SStella Laurenzo .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 168013b9322SStella Laurenzo "Testing hook for directly registering a dialect") 169013b9322SStella Laurenzo .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 170013b9322SStella Laurenzo "Testing hook for directly registering an operation"); 171013b9322SStella Laurenzo 172013b9322SStella Laurenzo // Aside from making the globals accessible to python, having python manage 173013b9322SStella Laurenzo // it is necessary to make sure it is destroyed (and releases its python 174013b9322SStella Laurenzo // resources) properly. 175013b9322SStella Laurenzo m.attr("globals") = 176013b9322SStella Laurenzo py::cast(new PyGlobals, py::return_value_policy::take_ownership); 177013b9322SStella Laurenzo 178013b9322SStella Laurenzo // Registration decorators. 179013b9322SStella Laurenzo m.def( 180013b9322SStella Laurenzo "register_dialect", 181013b9322SStella Laurenzo [](py::object pyClass) { 182013b9322SStella Laurenzo std::string dialectNamespace = 183013b9322SStella Laurenzo pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 184013b9322SStella Laurenzo PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 185013b9322SStella Laurenzo return pyClass; 186013b9322SStella Laurenzo }, 187013b9322SStella Laurenzo "Class decorator for registering a custom Dialect wrapper"); 188013b9322SStella Laurenzo m.def( 189013b9322SStella Laurenzo "register_operation", 190013b9322SStella Laurenzo [](py::object dialectClass) -> py::cpp_function { 191013b9322SStella Laurenzo return py::cpp_function( 192013b9322SStella Laurenzo [dialectClass](py::object opClass) -> py::object { 193013b9322SStella Laurenzo std::string operationName = 194013b9322SStella Laurenzo opClass.attr("OPERATION_NAME").cast<std::string>(); 195013b9322SStella Laurenzo auto rawSubclass = PyOpView::createRawSubclass(opClass); 196013b9322SStella Laurenzo PyGlobals::get().registerOperationImpl(operationName, opClass, 197013b9322SStella Laurenzo rawSubclass); 198013b9322SStella Laurenzo 199013b9322SStella Laurenzo // Dict-stuff the new opClass by name onto the dialect class. 200013b9322SStella Laurenzo py::object opClassName = opClass.attr("__name__"); 201013b9322SStella Laurenzo dialectClass.attr(opClassName) = opClass; 202013b9322SStella Laurenzo 203013b9322SStella Laurenzo // Now create a special "Raw" subclass that passes through 204013b9322SStella Laurenzo // construction to the OpView parent (bypasses the intermediate 205013b9322SStella Laurenzo // child's __init__). 206013b9322SStella Laurenzo opClass.attr("_Raw") = rawSubclass; 207013b9322SStella Laurenzo return opClass; 208013b9322SStella Laurenzo }); 209013b9322SStella Laurenzo }, 210013b9322SStella Laurenzo "Class decorator for registering a custom Operation wrapper"); 211013b9322SStella Laurenzo 212fcd2969dSzhanghb97 // Define and populate IR submodule. 213fcd2969dSzhanghb97 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 214*436c6c9cSStella Laurenzo populateIRCore(irModule); 215*436c6c9cSStella Laurenzo populateIRAffine(irModule); 216*436c6c9cSStella Laurenzo populateIRAttributes(irModule); 217*436c6c9cSStella Laurenzo populateIRTypes(irModule); 218dc43f785SMehdi Amini 219dc43f785SMehdi Amini // Define and populate PassManager submodule. 220dc43f785SMehdi Amini auto passModule = 221dc43f785SMehdi Amini m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 222dc43f785SMehdi Amini populatePassManagerSubmodule(passModule); 22313cb4317SMehdi Amini 22413cb4317SMehdi Amini // Define and populate ExecutionEngine submodule. 22513cb4317SMehdi Amini auto executionEngineModule = 22613cb4317SMehdi Amini m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); 22713cb4317SMehdi Amini populateExecutionEngineSubmodule(executionEngineModule); 228722475a3SStella Laurenzo } 229