1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include <optional> 12 #include <vector> 13 14 #include "Globals.h" 15 #include "NanobindUtils.h" 16 #include "mlir-c/Support.h" 17 #include "mlir/Bindings/Python/Nanobind.h" 18 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 19 20 namespace nb = nanobind; 21 using namespace mlir; 22 using namespace mlir::python; 23 24 // ----------------------------------------------------------------------------- 25 // PyGlobals 26 // ----------------------------------------------------------------------------- 27 28 PyGlobals *PyGlobals::instance = nullptr; 29 30 PyGlobals::PyGlobals() { 31 assert(!instance && "PyGlobals already constructed"); 32 instance = this; 33 // The default search path include {mlir.}dialects, where {mlir.} is the 34 // package prefix configured at compile time. 35 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 36 } 37 38 PyGlobals::~PyGlobals() { instance = nullptr; } 39 40 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 41 if (loadedDialectModules.contains(dialectNamespace)) 42 return true; 43 // Since re-entrancy is possible, make a copy of the search prefixes. 44 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 45 nb::object loaded = nb::none(); 46 for (std::string moduleName : localSearchPrefixes) { 47 moduleName.push_back('.'); 48 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 49 50 try { 51 loaded = nb::module_::import_(moduleName.c_str()); 52 } catch (nb::python_error &e) { 53 if (e.matches(PyExc_ModuleNotFoundError)) { 54 continue; 55 } 56 throw; 57 } 58 break; 59 } 60 61 if (loaded.is_none()) 62 return false; 63 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 64 // may have occurred, which may do anything. 65 loadedDialectModules.insert(dialectNamespace); 66 return true; 67 } 68 69 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 70 nb::callable pyFunc, bool replace) { 71 nb::object &found = attributeBuilderMap[attributeKind]; 72 if (found && !replace) { 73 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 74 attributeKind + 75 "' is already registered with func: " + 76 nb::cast<std::string>(nb::str(found))) 77 .str()); 78 } 79 found = std::move(pyFunc); 80 } 81 82 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 83 nb::callable typeCaster, bool replace) { 84 nb::object &found = typeCasterMap[mlirTypeID]; 85 if (found && !replace) 86 throw std::runtime_error("Type caster is already registered with caster: " + 87 nb::cast<std::string>(nb::str(found))); 88 found = std::move(typeCaster); 89 } 90 91 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, 92 nb::callable valueCaster, bool replace) { 93 nb::object &found = valueCasterMap[mlirTypeID]; 94 if (found && !replace) 95 throw std::runtime_error("Value caster is already registered: " + 96 nb::cast<std::string>(nb::repr(found))); 97 found = std::move(valueCaster); 98 } 99 100 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 101 nb::object pyClass) { 102 nb::object &found = dialectClassMap[dialectNamespace]; 103 if (found) { 104 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 105 dialectNamespace + "' is already registered.") 106 .str()); 107 } 108 found = std::move(pyClass); 109 } 110 111 void PyGlobals::registerOperationImpl(const std::string &operationName, 112 nb::object pyClass, bool replace) { 113 nb::object &found = operationClassMap[operationName]; 114 if (found && !replace) { 115 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 116 "' is already registered.") 117 .str()); 118 } 119 found = std::move(pyClass); 120 } 121 122 std::optional<nb::callable> 123 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 124 const auto foundIt = attributeBuilderMap.find(attributeKind); 125 if (foundIt != attributeBuilderMap.end()) { 126 assert(foundIt->second && "attribute builder is defined"); 127 return foundIt->second; 128 } 129 return std::nullopt; 130 } 131 132 std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 133 MlirDialect dialect) { 134 // Try to load dialect module. 135 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 136 const auto foundIt = typeCasterMap.find(mlirTypeID); 137 if (foundIt != typeCasterMap.end()) { 138 assert(foundIt->second && "type caster is defined"); 139 return foundIt->second; 140 } 141 return std::nullopt; 142 } 143 144 std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, 145 MlirDialect dialect) { 146 // Try to load dialect module. 147 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 148 const auto foundIt = valueCasterMap.find(mlirTypeID); 149 if (foundIt != valueCasterMap.end()) { 150 assert(foundIt->second && "value caster is defined"); 151 return foundIt->second; 152 } 153 return std::nullopt; 154 } 155 156 std::optional<nb::object> 157 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 158 // Make sure dialect module is loaded. 159 if (!loadDialectModule(dialectNamespace)) 160 return std::nullopt; 161 const auto foundIt = dialectClassMap.find(dialectNamespace); 162 if (foundIt != dialectClassMap.end()) { 163 assert(foundIt->second && "dialect class is defined"); 164 return foundIt->second; 165 } 166 // Not found and loading did not yield a registration. 167 return std::nullopt; 168 } 169 170 std::optional<nb::object> 171 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 172 // Make sure dialect module is loaded. 173 auto split = operationName.split('.'); 174 llvm::StringRef dialectNamespace = split.first; 175 if (!loadDialectModule(dialectNamespace)) 176 return std::nullopt; 177 178 auto foundIt = operationClassMap.find(operationName); 179 if (foundIt != operationClassMap.end()) { 180 assert(foundIt->second && "OpView is defined"); 181 return foundIt->second; 182 } 183 // Not found and loading did not yield a registration. 184 return std::nullopt; 185 } 186