1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include <optional> 12 #include <vector> 13 14 #include "Globals.h" 15 #include "NanobindUtils.h" 16 #include "mlir-c/Support.h" 17 #include "mlir/Bindings/Python/Nanobind.h" 18 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 19 20 namespace nb = nanobind; 21 using namespace mlir; 22 using namespace mlir::python; 23 24 // ----------------------------------------------------------------------------- 25 // PyGlobals 26 // ----------------------------------------------------------------------------- 27 28 PyGlobals *PyGlobals::instance = nullptr; 29 30 PyGlobals::PyGlobals() { 31 assert(!instance && "PyGlobals already constructed"); 32 instance = this; 33 // The default search path include {mlir.}dialects, where {mlir.} is the 34 // package prefix configured at compile time. 35 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 36 } 37 38 PyGlobals::~PyGlobals() { instance = nullptr; } 39 40 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 41 { 42 nb::ft_lock_guard lock(mutex); 43 if (loadedDialectModules.contains(dialectNamespace)) 44 return true; 45 } 46 // Since re-entrancy is possible, make a copy of the search prefixes. 47 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 48 nb::object loaded = nb::none(); 49 for (std::string moduleName : localSearchPrefixes) { 50 moduleName.push_back('.'); 51 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 52 53 try { 54 loaded = nb::module_::import_(moduleName.c_str()); 55 } catch (nb::python_error &e) { 56 if (e.matches(PyExc_ModuleNotFoundError)) { 57 continue; 58 } 59 throw; 60 } 61 break; 62 } 63 64 if (loaded.is_none()) 65 return false; 66 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 67 // may have occurred, which may do anything. 68 nb::ft_lock_guard lock(mutex); 69 loadedDialectModules.insert(dialectNamespace); 70 return true; 71 } 72 73 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 74 nb::callable pyFunc, bool replace) { 75 nb::ft_lock_guard lock(mutex); 76 nb::object &found = attributeBuilderMap[attributeKind]; 77 if (found && !replace) { 78 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 79 attributeKind + 80 "' is already registered with func: " + 81 nb::cast<std::string>(nb::str(found))) 82 .str()); 83 } 84 found = std::move(pyFunc); 85 } 86 87 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 88 nb::callable typeCaster, bool replace) { 89 nb::ft_lock_guard lock(mutex); 90 nb::object &found = typeCasterMap[mlirTypeID]; 91 if (found && !replace) 92 throw std::runtime_error("Type caster is already registered with caster: " + 93 nb::cast<std::string>(nb::str(found))); 94 found = std::move(typeCaster); 95 } 96 97 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, 98 nb::callable valueCaster, bool replace) { 99 nb::ft_lock_guard lock(mutex); 100 nb::object &found = valueCasterMap[mlirTypeID]; 101 if (found && !replace) 102 throw std::runtime_error("Value caster is already registered: " + 103 nb::cast<std::string>(nb::repr(found))); 104 found = std::move(valueCaster); 105 } 106 107 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 108 nb::object pyClass) { 109 nb::ft_lock_guard lock(mutex); 110 nb::object &found = dialectClassMap[dialectNamespace]; 111 if (found) { 112 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 113 dialectNamespace + "' is already registered.") 114 .str()); 115 } 116 found = std::move(pyClass); 117 } 118 119 void PyGlobals::registerOperationImpl(const std::string &operationName, 120 nb::object pyClass, bool replace) { 121 nb::ft_lock_guard lock(mutex); 122 nb::object &found = operationClassMap[operationName]; 123 if (found && !replace) { 124 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 125 "' is already registered.") 126 .str()); 127 } 128 found = std::move(pyClass); 129 } 130 131 std::optional<nb::callable> 132 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 133 nb::ft_lock_guard lock(mutex); 134 const auto foundIt = attributeBuilderMap.find(attributeKind); 135 if (foundIt != attributeBuilderMap.end()) { 136 assert(foundIt->second && "attribute builder is defined"); 137 return foundIt->second; 138 } 139 return std::nullopt; 140 } 141 142 std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 143 MlirDialect dialect) { 144 // Try to load dialect module. 145 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 146 nb::ft_lock_guard lock(mutex); 147 const auto foundIt = typeCasterMap.find(mlirTypeID); 148 if (foundIt != typeCasterMap.end()) { 149 assert(foundIt->second && "type caster is defined"); 150 return foundIt->second; 151 } 152 return std::nullopt; 153 } 154 155 std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, 156 MlirDialect dialect) { 157 // Try to load dialect module. 158 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 159 nb::ft_lock_guard lock(mutex); 160 const auto foundIt = valueCasterMap.find(mlirTypeID); 161 if (foundIt != valueCasterMap.end()) { 162 assert(foundIt->second && "value caster is defined"); 163 return foundIt->second; 164 } 165 return std::nullopt; 166 } 167 168 std::optional<nb::object> 169 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 170 // Make sure dialect module is loaded. 171 if (!loadDialectModule(dialectNamespace)) 172 return std::nullopt; 173 nb::ft_lock_guard lock(mutex); 174 const auto foundIt = dialectClassMap.find(dialectNamespace); 175 if (foundIt != dialectClassMap.end()) { 176 assert(foundIt->second && "dialect class is defined"); 177 return foundIt->second; 178 } 179 // Not found and loading did not yield a registration. 180 return std::nullopt; 181 } 182 183 std::optional<nb::object> 184 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 185 // Make sure dialect module is loaded. 186 auto split = operationName.split('.'); 187 llvm::StringRef dialectNamespace = split.first; 188 if (!loadDialectModule(dialectNamespace)) 189 return std::nullopt; 190 191 nb::ft_lock_guard lock(mutex); 192 auto foundIt = operationClassMap.find(operationName); 193 if (foundIt != operationClassMap.end()) { 194 assert(foundIt->second && "OpView is defined"); 195 return foundIt->second; 196 } 197 // Not found and loading did not yield a registration. 198 return std::nullopt; 199 } 200