1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "DialectLinalg.h" 14 #include "ExecutionEngine.h" 15 #include "Globals.h" 16 #include "IRModule.h" 17 #include "Pass.h" 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 // ----------------------------------------------------------------------------- 24 // PyGlobals 25 // ----------------------------------------------------------------------------- 26 27 PyGlobals *PyGlobals::instance = nullptr; 28 29 PyGlobals::PyGlobals() { 30 assert(!instance && "PyGlobals already constructed"); 31 instance = this; 32 } 33 34 PyGlobals::~PyGlobals() { instance = nullptr; } 35 36 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 37 py::gil_scoped_acquire(); 38 if (loadedDialectModulesCache.contains(dialectNamespace)) 39 return; 40 // Since re-entrancy is possible, make a copy of the search prefixes. 41 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 42 py::object loaded; 43 for (std::string moduleName : localSearchPrefixes) { 44 moduleName.push_back('.'); 45 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 46 47 try { 48 py::gil_scoped_release(); 49 loaded = py::module::import(moduleName.c_str()); 50 } catch (py::error_already_set &e) { 51 if (e.matches(PyExc_ModuleNotFoundError)) { 52 continue; 53 } else { 54 throw; 55 } 56 } 57 break; 58 } 59 60 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 61 // may have occurred, which may do anything. 62 loadedDialectModulesCache.insert(dialectNamespace); 63 } 64 65 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 66 py::object pyClass) { 67 py::gil_scoped_acquire(); 68 py::object &found = dialectClassMap[dialectNamespace]; 69 if (found) { 70 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 71 dialectNamespace + 72 "' is already registered."); 73 } 74 found = std::move(pyClass); 75 } 76 77 void PyGlobals::registerOperationImpl(const std::string &operationName, 78 py::object pyClass, 79 py::object rawOpViewClass) { 80 py::gil_scoped_acquire(); 81 py::object &found = operationClassMap[operationName]; 82 if (found) { 83 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 84 operationName + 85 "' is already registered."); 86 } 87 found = std::move(pyClass); 88 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 89 } 90 91 llvm::Optional<py::object> 92 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 93 py::gil_scoped_acquire(); 94 loadDialectModule(dialectNamespace); 95 // Fast match against the class map first (common case). 96 const auto foundIt = dialectClassMap.find(dialectNamespace); 97 if (foundIt != dialectClassMap.end()) { 98 if (foundIt->second.is_none()) 99 return llvm::None; 100 assert(foundIt->second && "py::object is defined"); 101 return foundIt->second; 102 } 103 104 // Not found and loading did not yield a registration. Negative cache. 105 dialectClassMap[dialectNamespace] = py::none(); 106 return llvm::None; 107 } 108 109 llvm::Optional<pybind11::object> 110 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 111 { 112 py::gil_scoped_acquire(); 113 auto foundIt = rawOpViewClassMapCache.find(operationName); 114 if (foundIt != rawOpViewClassMapCache.end()) { 115 if (foundIt->second.is_none()) 116 return llvm::None; 117 assert(foundIt->second && "py::object is defined"); 118 return foundIt->second; 119 } 120 } 121 122 // Not found. Load the dialect namespace. 123 auto split = operationName.split('.'); 124 llvm::StringRef dialectNamespace = split.first; 125 loadDialectModule(dialectNamespace); 126 127 // Attempt to find from the canonical map and cache. 128 { 129 py::gil_scoped_acquire(); 130 auto foundIt = rawOpViewClassMap.find(operationName); 131 if (foundIt != rawOpViewClassMap.end()) { 132 if (foundIt->second.is_none()) 133 return llvm::None; 134 assert(foundIt->second && "py::object is defined"); 135 // Positive cache. 136 rawOpViewClassMapCache[operationName] = foundIt->second; 137 return foundIt->second; 138 } else { 139 // Negative cache. 140 rawOpViewClassMap[operationName] = py::none(); 141 return llvm::None; 142 } 143 } 144 } 145 146 void PyGlobals::clearImportCache() { 147 py::gil_scoped_acquire(); 148 loadedDialectModulesCache.clear(); 149 rawOpViewClassMapCache.clear(); 150 } 151 152 // ----------------------------------------------------------------------------- 153 // Module initialization. 154 // ----------------------------------------------------------------------------- 155 156 PYBIND11_MODULE(_mlir, m) { 157 m.doc() = "MLIR Python Native Extension"; 158 159 py::class_<PyGlobals>(m, "_Globals") 160 .def_property("dialect_search_modules", 161 &PyGlobals::getDialectSearchPrefixes, 162 &PyGlobals::setDialectSearchPrefixes) 163 .def("append_dialect_search_prefix", 164 [](PyGlobals &self, std::string moduleName) { 165 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 166 self.clearImportCache(); 167 }) 168 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 169 "Testing hook for directly registering a dialect") 170 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 171 "Testing hook for directly registering an operation"); 172 173 // Aside from making the globals accessible to python, having python manage 174 // it is necessary to make sure it is destroyed (and releases its python 175 // resources) properly. 176 m.attr("globals") = 177 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 178 179 // Registration decorators. 180 m.def( 181 "register_dialect", 182 [](py::object pyClass) { 183 std::string dialectNamespace = 184 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 185 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 186 return pyClass; 187 }, 188 "Class decorator for registering a custom Dialect wrapper"); 189 m.def( 190 "register_operation", 191 [](py::object dialectClass) -> py::cpp_function { 192 return py::cpp_function( 193 [dialectClass](py::object opClass) -> py::object { 194 std::string operationName = 195 opClass.attr("OPERATION_NAME").cast<std::string>(); 196 auto rawSubclass = PyOpView::createRawSubclass(opClass); 197 PyGlobals::get().registerOperationImpl(operationName, opClass, 198 rawSubclass); 199 200 // Dict-stuff the new opClass by name onto the dialect class. 201 py::object opClassName = opClass.attr("__name__"); 202 dialectClass.attr(opClassName) = opClass; 203 204 // Now create a special "Raw" subclass that passes through 205 // construction to the OpView parent (bypasses the intermediate 206 // child's __init__). 207 opClass.attr("_Raw") = rawSubclass; 208 return opClass; 209 }); 210 }, 211 "Class decorator for registering a custom Operation wrapper"); 212 213 // Define and populate IR submodule. 214 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 215 populateIRCore(irModule); 216 populateIRAffine(irModule); 217 populateIRAttributes(irModule); 218 populateIRTypes(irModule); 219 220 // Define and populate PassManager submodule. 221 auto passModule = 222 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 223 populatePassManagerSubmodule(passModule); 224 225 // Define and populate ExecutionEngine submodule. 226 auto executionEngineModule = 227 m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); 228 populateExecutionEngineSubmodule(executionEngineModule); 229 230 // Define and populate Linalg submodule. 231 auto dialectsModule = m.def_submodule("dialects"); 232 auto linalgModule = dialectsModule.def_submodule("linalg"); 233 populateDialectLinalgSubmodule(linalgModule); 234 } 235