1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <optional> 14 #include <vector> 15 16 #include "mlir-c/Bindings/Python/Interop.h" 17 #include "mlir-c/Support.h" 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 // ----------------------------------------------------------------------------- 24 // PyGlobals 25 // ----------------------------------------------------------------------------- 26 27 PyGlobals *PyGlobals::instance = nullptr; 28 29 PyGlobals::PyGlobals() { 30 assert(!instance && "PyGlobals already constructed"); 31 instance = this; 32 // The default search path include {mlir.}dialects, where {mlir.} is the 33 // package prefix configured at compile time. 34 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 35 } 36 37 PyGlobals::~PyGlobals() { instance = nullptr; } 38 39 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 40 if (loadedDialectModulesCache.contains(dialectNamespace)) 41 return; 42 // Since re-entrancy is possible, make a copy of the search prefixes. 43 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 44 py::object loaded; 45 for (std::string moduleName : localSearchPrefixes) { 46 moduleName.push_back('.'); 47 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 48 49 try { 50 loaded = py::module::import(moduleName.c_str()); 51 } catch (py::error_already_set &e) { 52 if (e.matches(PyExc_ModuleNotFoundError)) { 53 continue; 54 } 55 throw; 56 } 57 break; 58 } 59 60 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 61 // may have occurred, which may do anything. 62 loadedDialectModulesCache.insert(dialectNamespace); 63 } 64 65 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 66 py::function pyFunc) { 67 py::object &found = attributeBuilderMap[attributeKind]; 68 if (found) { 69 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 70 attributeKind + "' is already registered") 71 .str()); 72 } 73 found = std::move(pyFunc); 74 } 75 76 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 77 pybind11::function typeCaster, 78 bool replace) { 79 pybind11::object &found = typeCasterMap[mlirTypeID]; 80 if (found && !found.is_none() && !replace) 81 throw std::runtime_error("Type caster is already registered"); 82 found = std::move(typeCaster); 83 } 84 85 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 86 py::object pyClass) { 87 py::object &found = dialectClassMap[dialectNamespace]; 88 if (found) { 89 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 90 dialectNamespace + "' is already registered.") 91 .str()); 92 } 93 found = std::move(pyClass); 94 } 95 96 void PyGlobals::registerOperationImpl(const std::string &operationName, 97 py::object pyClass) { 98 py::object &found = operationClassMap[operationName]; 99 if (found) { 100 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 101 "' is already registered.") 102 .str()); 103 } 104 found = std::move(pyClass); 105 } 106 107 std::optional<py::function> 108 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 109 // Fast match against the class map first (common case). 110 const auto foundIt = attributeBuilderMap.find(attributeKind); 111 if (foundIt != attributeBuilderMap.end()) { 112 if (foundIt->second.is_none()) 113 return std::nullopt; 114 assert(foundIt->second && "py::function is defined"); 115 return foundIt->second; 116 } 117 118 // Not found and loading did not yield a registration. Negative cache. 119 attributeBuilderMap[attributeKind] = py::none(); 120 return std::nullopt; 121 } 122 123 std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 124 MlirDialect dialect) { 125 { 126 // Fast match against the class map first (common case). 127 const auto foundIt = typeCasterMapCache.find(mlirTypeID); 128 if (foundIt != typeCasterMapCache.end()) { 129 if (foundIt->second.is_none()) 130 return std::nullopt; 131 assert(foundIt->second && "py::function is defined"); 132 return foundIt->second; 133 } 134 } 135 136 // Not found. Load the dialect namespace. 137 loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 138 139 // Attempt to find from the canonical map and cache. 140 { 141 const auto foundIt = typeCasterMap.find(mlirTypeID); 142 if (foundIt != typeCasterMap.end()) { 143 if (foundIt->second.is_none()) 144 return std::nullopt; 145 assert(foundIt->second && "py::object is defined"); 146 // Positive cache. 147 typeCasterMapCache[mlirTypeID] = foundIt->second; 148 return foundIt->second; 149 } 150 // Negative cache. 151 typeCasterMap[mlirTypeID] = py::none(); 152 return std::nullopt; 153 } 154 } 155 156 std::optional<py::object> 157 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 158 loadDialectModule(dialectNamespace); 159 // Fast match against the class map first (common case). 160 const auto foundIt = dialectClassMap.find(dialectNamespace); 161 if (foundIt != dialectClassMap.end()) { 162 if (foundIt->second.is_none()) 163 return std::nullopt; 164 assert(foundIt->second && "py::object is defined"); 165 return foundIt->second; 166 } 167 168 // Not found and loading did not yield a registration. Negative cache. 169 dialectClassMap[dialectNamespace] = py::none(); 170 return std::nullopt; 171 } 172 173 std::optional<pybind11::object> 174 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 175 { 176 auto foundIt = operationClassMapCache.find(operationName); 177 if (foundIt != operationClassMapCache.end()) { 178 if (foundIt->second.is_none()) 179 return std::nullopt; 180 assert(foundIt->second && "py::object is defined"); 181 return foundIt->second; 182 } 183 } 184 185 // Not found. Load the dialect namespace. 186 auto split = operationName.split('.'); 187 llvm::StringRef dialectNamespace = split.first; 188 loadDialectModule(dialectNamespace); 189 190 // Attempt to find from the canonical map and cache. 191 { 192 auto foundIt = operationClassMap.find(operationName); 193 if (foundIt != operationClassMap.end()) { 194 if (foundIt->second.is_none()) 195 return std::nullopt; 196 assert(foundIt->second && "py::object is defined"); 197 // Positive cache. 198 operationClassMapCache[operationName] = foundIt->second; 199 return foundIt->second; 200 } 201 // Negative cache. 202 operationClassMap[operationName] = py::none(); 203 return std::nullopt; 204 } 205 } 206 207 void PyGlobals::clearImportCache() { 208 loadedDialectModulesCache.clear(); 209 operationClassMapCache.clear(); 210 typeCasterMapCache.clear(); 211 } 212