1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <optional> 14 #include <vector> 15 16 #include "mlir-c/Bindings/Python/Interop.h" 17 #include "mlir-c/Support.h" 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 // ----------------------------------------------------------------------------- 24 // PyGlobals 25 // ----------------------------------------------------------------------------- 26 27 PyGlobals *PyGlobals::instance = nullptr; 28 29 PyGlobals::PyGlobals() { 30 assert(!instance && "PyGlobals already constructed"); 31 instance = this; 32 // The default search path include {mlir.}dialects, where {mlir.} is the 33 // package prefix configured at compile time. 34 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 35 } 36 37 PyGlobals::~PyGlobals() { instance = nullptr; } 38 39 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 40 if (loadedDialectModulesCache.contains(dialectNamespace)) 41 return; 42 // Since re-entrancy is possible, make a copy of the search prefixes. 43 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 44 py::object loaded; 45 for (std::string moduleName : localSearchPrefixes) { 46 moduleName.push_back('.'); 47 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 48 49 try { 50 loaded = py::module::import(moduleName.c_str()); 51 } catch (py::error_already_set &e) { 52 if (e.matches(PyExc_ModuleNotFoundError)) { 53 continue; 54 } 55 throw; 56 } 57 break; 58 } 59 60 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 61 // may have occurred, which may do anything. 62 loadedDialectModulesCache.insert(dialectNamespace); 63 } 64 65 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 66 py::function pyFunc, bool replace) { 67 py::object &found = attributeBuilderMap[attributeKind]; 68 if (found && !found.is_none() && !replace) { 69 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 70 attributeKind + 71 "' is already registered with func: " + 72 py::str(found).operator std::string()) 73 .str()); 74 } 75 found = std::move(pyFunc); 76 } 77 78 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 79 pybind11::function typeCaster, 80 bool replace) { 81 pybind11::object &found = typeCasterMap[mlirTypeID]; 82 if (found && !found.is_none() && !replace) 83 throw std::runtime_error("Type caster is already registered"); 84 found = std::move(typeCaster); 85 } 86 87 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 88 py::object pyClass) { 89 py::object &found = dialectClassMap[dialectNamespace]; 90 if (found) { 91 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 92 dialectNamespace + "' is already registered.") 93 .str()); 94 } 95 found = std::move(pyClass); 96 } 97 98 void PyGlobals::registerOperationImpl(const std::string &operationName, 99 py::object pyClass) { 100 py::object &found = operationClassMap[operationName]; 101 if (found) { 102 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 103 "' is already registered.") 104 .str()); 105 } 106 found = std::move(pyClass); 107 } 108 109 std::optional<py::function> 110 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 111 // Fast match against the class map first (common case). 112 const auto foundIt = attributeBuilderMap.find(attributeKind); 113 if (foundIt != attributeBuilderMap.end()) { 114 if (foundIt->second.is_none()) 115 return std::nullopt; 116 assert(foundIt->second && "py::function is defined"); 117 return foundIt->second; 118 } 119 120 // Not found and loading did not yield a registration. Negative cache. 121 attributeBuilderMap[attributeKind] = py::none(); 122 return std::nullopt; 123 } 124 125 std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 126 MlirDialect dialect) { 127 { 128 // Fast match against the class map first (common case). 129 const auto foundIt = typeCasterMapCache.find(mlirTypeID); 130 if (foundIt != typeCasterMapCache.end()) { 131 if (foundIt->second.is_none()) 132 return std::nullopt; 133 assert(foundIt->second && "py::function is defined"); 134 return foundIt->second; 135 } 136 } 137 138 // Not found. Load the dialect namespace. 139 loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 140 141 // Attempt to find from the canonical map and cache. 142 { 143 const auto foundIt = typeCasterMap.find(mlirTypeID); 144 if (foundIt != typeCasterMap.end()) { 145 if (foundIt->second.is_none()) 146 return std::nullopt; 147 assert(foundIt->second && "py::object is defined"); 148 // Positive cache. 149 typeCasterMapCache[mlirTypeID] = foundIt->second; 150 return foundIt->second; 151 } 152 // Negative cache. 153 typeCasterMap[mlirTypeID] = py::none(); 154 return std::nullopt; 155 } 156 } 157 158 std::optional<py::object> 159 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 160 loadDialectModule(dialectNamespace); 161 // Fast match against the class map first (common case). 162 const auto foundIt = dialectClassMap.find(dialectNamespace); 163 if (foundIt != dialectClassMap.end()) { 164 if (foundIt->second.is_none()) 165 return std::nullopt; 166 assert(foundIt->second && "py::object is defined"); 167 return foundIt->second; 168 } 169 170 // Not found and loading did not yield a registration. Negative cache. 171 dialectClassMap[dialectNamespace] = py::none(); 172 return std::nullopt; 173 } 174 175 std::optional<pybind11::object> 176 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 177 { 178 auto foundIt = operationClassMapCache.find(operationName); 179 if (foundIt != operationClassMapCache.end()) { 180 if (foundIt->second.is_none()) 181 return std::nullopt; 182 assert(foundIt->second && "py::object is defined"); 183 return foundIt->second; 184 } 185 } 186 187 // Not found. Load the dialect namespace. 188 auto split = operationName.split('.'); 189 llvm::StringRef dialectNamespace = split.first; 190 loadDialectModule(dialectNamespace); 191 192 // Attempt to find from the canonical map and cache. 193 { 194 auto foundIt = operationClassMap.find(operationName); 195 if (foundIt != operationClassMap.end()) { 196 if (foundIt->second.is_none()) 197 return std::nullopt; 198 assert(foundIt->second && "py::object is defined"); 199 // Positive cache. 200 operationClassMapCache[operationName] = foundIt->second; 201 return foundIt->second; 202 } 203 // Negative cache. 204 operationClassMap[operationName] = py::none(); 205 return std::nullopt; 206 } 207 } 208 209 void PyGlobals::clearImportCache() { 210 loadedDialectModulesCache.clear(); 211 operationClassMapCache.clear(); 212 typeCasterMapCache.clear(); 213 } 214