1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include <nanobind/nanobind.h> 12 #include <nanobind/stl/string.h> 13 14 #include <optional> 15 #include <vector> 16 17 #include "Globals.h" 18 #include "NanobindUtils.h" 19 #include "mlir-c/Bindings/Python/Interop.h" 20 #include "mlir-c/Support.h" 21 22 namespace nb = nanobind; 23 using namespace mlir; 24 using namespace mlir::python; 25 26 // ----------------------------------------------------------------------------- 27 // PyGlobals 28 // ----------------------------------------------------------------------------- 29 30 PyGlobals *PyGlobals::instance = nullptr; 31 32 PyGlobals::PyGlobals() { 33 assert(!instance && "PyGlobals already constructed"); 34 instance = this; 35 // The default search path include {mlir.}dialects, where {mlir.} is the 36 // package prefix configured at compile time. 37 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 38 } 39 40 PyGlobals::~PyGlobals() { instance = nullptr; } 41 42 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 43 if (loadedDialectModules.contains(dialectNamespace)) 44 return true; 45 // Since re-entrancy is possible, make a copy of the search prefixes. 46 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 47 nb::object loaded = nb::none(); 48 for (std::string moduleName : localSearchPrefixes) { 49 moduleName.push_back('.'); 50 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 51 52 try { 53 loaded = nb::module_::import_(moduleName.c_str()); 54 } catch (nb::python_error &e) { 55 if (e.matches(PyExc_ModuleNotFoundError)) { 56 continue; 57 } 58 throw; 59 } 60 break; 61 } 62 63 if (loaded.is_none()) 64 return false; 65 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 66 // may have occurred, which may do anything. 67 loadedDialectModules.insert(dialectNamespace); 68 return true; 69 } 70 71 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 72 nb::callable pyFunc, bool replace) { 73 nb::object &found = attributeBuilderMap[attributeKind]; 74 if (found && !replace) { 75 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 76 attributeKind + 77 "' is already registered with func: " + 78 nb::cast<std::string>(nb::str(found))) 79 .str()); 80 } 81 found = std::move(pyFunc); 82 } 83 84 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 85 nb::callable typeCaster, bool replace) { 86 nb::object &found = typeCasterMap[mlirTypeID]; 87 if (found && !replace) 88 throw std::runtime_error("Type caster is already registered with caster: " + 89 nb::cast<std::string>(nb::str(found))); 90 found = std::move(typeCaster); 91 } 92 93 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, 94 nb::callable valueCaster, bool replace) { 95 nb::object &found = valueCasterMap[mlirTypeID]; 96 if (found && !replace) 97 throw std::runtime_error("Value caster is already registered: " + 98 nb::cast<std::string>(nb::repr(found))); 99 found = std::move(valueCaster); 100 } 101 102 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 103 nb::object pyClass) { 104 nb::object &found = dialectClassMap[dialectNamespace]; 105 if (found) { 106 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 107 dialectNamespace + "' is already registered.") 108 .str()); 109 } 110 found = std::move(pyClass); 111 } 112 113 void PyGlobals::registerOperationImpl(const std::string &operationName, 114 nb::object pyClass, bool replace) { 115 nb::object &found = operationClassMap[operationName]; 116 if (found && !replace) { 117 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 118 "' is already registered.") 119 .str()); 120 } 121 found = std::move(pyClass); 122 } 123 124 std::optional<nb::callable> 125 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 126 const auto foundIt = attributeBuilderMap.find(attributeKind); 127 if (foundIt != attributeBuilderMap.end()) { 128 assert(foundIt->second && "attribute builder is defined"); 129 return foundIt->second; 130 } 131 return std::nullopt; 132 } 133 134 std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 135 MlirDialect dialect) { 136 // Try to load dialect module. 137 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 138 const auto foundIt = typeCasterMap.find(mlirTypeID); 139 if (foundIt != typeCasterMap.end()) { 140 assert(foundIt->second && "type caster is defined"); 141 return foundIt->second; 142 } 143 return std::nullopt; 144 } 145 146 std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, 147 MlirDialect dialect) { 148 // Try to load dialect module. 149 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 150 const auto foundIt = valueCasterMap.find(mlirTypeID); 151 if (foundIt != valueCasterMap.end()) { 152 assert(foundIt->second && "value caster is defined"); 153 return foundIt->second; 154 } 155 return std::nullopt; 156 } 157 158 std::optional<nb::object> 159 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 160 // Make sure dialect module is loaded. 161 if (!loadDialectModule(dialectNamespace)) 162 return std::nullopt; 163 const auto foundIt = dialectClassMap.find(dialectNamespace); 164 if (foundIt != dialectClassMap.end()) { 165 assert(foundIt->second && "dialect class is defined"); 166 return foundIt->second; 167 } 168 // Not found and loading did not yield a registration. 169 return std::nullopt; 170 } 171 172 std::optional<nb::object> 173 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 174 // Make sure dialect module is loaded. 175 auto split = operationName.split('.'); 176 llvm::StringRef dialectNamespace = split.first; 177 if (!loadDialectModule(dialectNamespace)) 178 return std::nullopt; 179 180 auto foundIt = operationClassMap.find(operationName); 181 if (foundIt != operationClassMap.end()) { 182 assert(foundIt->second && "OpView is defined"); 183 return foundIt->second; 184 } 185 // Not found and loading did not yield a registration. 186 return std::nullopt; 187 } 188