1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "ExecutionEngine.h" 14 #include "Globals.h" 15 #include "IRModules.h" 16 #include "Pass.h" 17 18 namespace py = pybind11; 19 using namespace mlir; 20 using namespace mlir::python; 21 22 // ----------------------------------------------------------------------------- 23 // PyGlobals 24 // ----------------------------------------------------------------------------- 25 26 PyGlobals *PyGlobals::instance = nullptr; 27 28 PyGlobals::PyGlobals() { 29 assert(!instance && "PyGlobals already constructed"); 30 instance = this; 31 } 32 33 PyGlobals::~PyGlobals() { instance = nullptr; } 34 35 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 36 py::gil_scoped_acquire(); 37 if (loadedDialectModulesCache.contains(dialectNamespace)) 38 return; 39 // Since re-entrancy is possible, make a copy of the search prefixes. 40 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 41 py::object loaded; 42 for (std::string moduleName : localSearchPrefixes) { 43 moduleName.push_back('.'); 44 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 45 46 try { 47 py::gil_scoped_release(); 48 loaded = py::module::import(moduleName.c_str()); 49 } catch (py::error_already_set &e) { 50 if (e.matches(PyExc_ModuleNotFoundError)) { 51 continue; 52 } else { 53 throw; 54 } 55 } 56 break; 57 } 58 59 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 60 // may have occurred, which may do anything. 61 loadedDialectModulesCache.insert(dialectNamespace); 62 } 63 64 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 65 py::object pyClass) { 66 py::gil_scoped_acquire(); 67 py::object &found = dialectClassMap[dialectNamespace]; 68 if (found) { 69 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 70 dialectNamespace + 71 "' is already registered."); 72 } 73 found = std::move(pyClass); 74 } 75 76 void PyGlobals::registerOperationImpl(const std::string &operationName, 77 py::object pyClass, 78 py::object rawOpViewClass) { 79 py::gil_scoped_acquire(); 80 py::object &found = operationClassMap[operationName]; 81 if (found) { 82 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 83 operationName + 84 "' is already registered."); 85 } 86 found = std::move(pyClass); 87 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 88 } 89 90 llvm::Optional<py::object> 91 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 92 py::gil_scoped_acquire(); 93 loadDialectModule(dialectNamespace); 94 // Fast match against the class map first (common case). 95 const auto foundIt = dialectClassMap.find(dialectNamespace); 96 if (foundIt != dialectClassMap.end()) { 97 if (foundIt->second.is_none()) 98 return llvm::None; 99 assert(foundIt->second && "py::object is defined"); 100 return foundIt->second; 101 } 102 103 // Not found and loading did not yield a registration. Negative cache. 104 dialectClassMap[dialectNamespace] = py::none(); 105 return llvm::None; 106 } 107 108 llvm::Optional<pybind11::object> 109 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 110 { 111 py::gil_scoped_acquire(); 112 auto foundIt = rawOpViewClassMapCache.find(operationName); 113 if (foundIt != rawOpViewClassMapCache.end()) { 114 if (foundIt->second.is_none()) 115 return llvm::None; 116 assert(foundIt->second && "py::object is defined"); 117 return foundIt->second; 118 } 119 } 120 121 // Not found. Load the dialect namespace. 122 auto split = operationName.split('.'); 123 llvm::StringRef dialectNamespace = split.first; 124 loadDialectModule(dialectNamespace); 125 126 // Attempt to find from the canonical map and cache. 127 { 128 py::gil_scoped_acquire(); 129 auto foundIt = rawOpViewClassMap.find(operationName); 130 if (foundIt != rawOpViewClassMap.end()) { 131 if (foundIt->second.is_none()) 132 return llvm::None; 133 assert(foundIt->second && "py::object is defined"); 134 // Positive cache. 135 rawOpViewClassMapCache[operationName] = foundIt->second; 136 return foundIt->second; 137 } else { 138 // Negative cache. 139 rawOpViewClassMap[operationName] = py::none(); 140 return llvm::None; 141 } 142 } 143 } 144 145 void PyGlobals::clearImportCache() { 146 py::gil_scoped_acquire(); 147 loadedDialectModulesCache.clear(); 148 rawOpViewClassMapCache.clear(); 149 } 150 151 // ----------------------------------------------------------------------------- 152 // Module initialization. 153 // ----------------------------------------------------------------------------- 154 155 PYBIND11_MODULE(_mlir, m) { 156 m.doc() = "MLIR Python Native Extension"; 157 158 py::class_<PyGlobals>(m, "_Globals") 159 .def_property("dialect_search_modules", 160 &PyGlobals::getDialectSearchPrefixes, 161 &PyGlobals::setDialectSearchPrefixes) 162 .def("append_dialect_search_prefix", 163 [](PyGlobals &self, std::string moduleName) { 164 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 165 self.clearImportCache(); 166 }) 167 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 168 "Testing hook for directly registering a dialect") 169 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 170 "Testing hook for directly registering an operation"); 171 172 // Aside from making the globals accessible to python, having python manage 173 // it is necessary to make sure it is destroyed (and releases its python 174 // resources) properly. 175 m.attr("globals") = 176 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 177 178 // Registration decorators. 179 m.def( 180 "register_dialect", 181 [](py::object pyClass) { 182 std::string dialectNamespace = 183 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 184 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 185 return pyClass; 186 }, 187 "Class decorator for registering a custom Dialect wrapper"); 188 m.def( 189 "register_operation", 190 [](py::object dialectClass) -> py::cpp_function { 191 return py::cpp_function( 192 [dialectClass](py::object opClass) -> py::object { 193 std::string operationName = 194 opClass.attr("OPERATION_NAME").cast<std::string>(); 195 auto rawSubclass = PyOpView::createRawSubclass(opClass); 196 PyGlobals::get().registerOperationImpl(operationName, opClass, 197 rawSubclass); 198 199 // Dict-stuff the new opClass by name onto the dialect class. 200 py::object opClassName = opClass.attr("__name__"); 201 dialectClass.attr(opClassName) = opClass; 202 203 // Now create a special "Raw" subclass that passes through 204 // construction to the OpView parent (bypasses the intermediate 205 // child's __init__). 206 opClass.attr("_Raw") = rawSubclass; 207 return opClass; 208 }); 209 }, 210 "Class decorator for registering a custom Operation wrapper"); 211 212 // Define and populate IR submodule. 213 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 214 populateIRSubmodule(irModule); 215 216 // Define and populate PassManager submodule. 217 auto passModule = 218 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 219 populatePassManagerSubmodule(passModule); 220 221 // Define and populate ExecutionEngine submodule. 222 auto executionEngineModule = 223 m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); 224 populateExecutionEngineSubmodule(executionEngineModule); 225 } 226