1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include "mlir-c/Bindings/Python/Interop.h" 14 #include "mlir-c/Support.h" 15 16 #include <optional> 17 #include <vector> 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 // ----------------------------------------------------------------------------- 24 // PyGlobals 25 // ----------------------------------------------------------------------------- 26 27 PyGlobals *PyGlobals::instance = nullptr; 28 29 PyGlobals::PyGlobals() { 30 assert(!instance && "PyGlobals already constructed"); 31 instance = this; 32 // The default search path include {mlir.}dialects, where {mlir.} is the 33 // package prefix configured at compile time. 34 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 35 } 36 37 PyGlobals::~PyGlobals() { instance = nullptr; } 38 39 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 40 if (loadedDialectModules.contains(dialectNamespace)) 41 return true; 42 // Since re-entrancy is possible, make a copy of the search prefixes. 43 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 44 py::object loaded = py::none(); 45 for (std::string moduleName : localSearchPrefixes) { 46 moduleName.push_back('.'); 47 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 48 49 try { 50 loaded = py::module::import(moduleName.c_str()); 51 } catch (py::error_already_set &e) { 52 if (e.matches(PyExc_ModuleNotFoundError)) { 53 continue; 54 } 55 throw; 56 } 57 break; 58 } 59 60 if (loaded.is_none()) 61 return false; 62 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 63 // may have occurred, which may do anything. 64 loadedDialectModules.insert(dialectNamespace); 65 return true; 66 } 67 68 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 69 py::function pyFunc, bool replace) { 70 py::object &found = attributeBuilderMap[attributeKind]; 71 if (found && !replace) { 72 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 73 attributeKind + 74 "' is already registered with func: " + 75 py::str(found).operator std::string()) 76 .str()); 77 } 78 found = std::move(pyFunc); 79 } 80 81 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 82 pybind11::function typeCaster, 83 bool replace) { 84 pybind11::object &found = typeCasterMap[mlirTypeID]; 85 if (found && !replace) 86 throw std::runtime_error("Type caster is already registered with caster: " + 87 py::str(found).operator std::string()); 88 found = std::move(typeCaster); 89 } 90 91 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, 92 pybind11::function valueCaster, 93 bool replace) { 94 pybind11::object &found = valueCasterMap[mlirTypeID]; 95 if (found && !replace) 96 throw std::runtime_error("Value caster is already registered: " + 97 py::repr(found).cast<std::string>()); 98 found = std::move(valueCaster); 99 } 100 101 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 102 py::object pyClass) { 103 py::object &found = dialectClassMap[dialectNamespace]; 104 if (found) { 105 throw std::runtime_error((llvm::Twine("Dialect namespace '") + 106 dialectNamespace + "' is already registered.") 107 .str()); 108 } 109 found = std::move(pyClass); 110 } 111 112 void PyGlobals::registerOperationImpl(const std::string &operationName, 113 py::object pyClass, bool replace) { 114 py::object &found = operationClassMap[operationName]; 115 if (found && !replace) { 116 throw std::runtime_error((llvm::Twine("Operation '") + operationName + 117 "' is already registered.") 118 .str()); 119 } 120 found = std::move(pyClass); 121 } 122 123 std::optional<py::function> 124 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 125 const auto foundIt = attributeBuilderMap.find(attributeKind); 126 if (foundIt != attributeBuilderMap.end()) { 127 assert(foundIt->second && "attribute builder is defined"); 128 return foundIt->second; 129 } 130 return std::nullopt; 131 } 132 133 std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 134 MlirDialect dialect) { 135 // Try to load dialect module. 136 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 137 const auto foundIt = typeCasterMap.find(mlirTypeID); 138 if (foundIt != typeCasterMap.end()) { 139 assert(foundIt->second && "type caster is defined"); 140 return foundIt->second; 141 } 142 return std::nullopt; 143 } 144 145 std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, 146 MlirDialect dialect) { 147 // Try to load dialect module. 148 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 149 const auto foundIt = valueCasterMap.find(mlirTypeID); 150 if (foundIt != valueCasterMap.end()) { 151 assert(foundIt->second && "value caster is defined"); 152 return foundIt->second; 153 } 154 return std::nullopt; 155 } 156 157 std::optional<py::object> 158 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 159 // Make sure dialect module is loaded. 160 if (!loadDialectModule(dialectNamespace)) 161 return std::nullopt; 162 const auto foundIt = dialectClassMap.find(dialectNamespace); 163 if (foundIt != dialectClassMap.end()) { 164 assert(foundIt->second && "dialect class is defined"); 165 return foundIt->second; 166 } 167 // Not found and loading did not yield a registration. 168 return std::nullopt; 169 } 170 171 std::optional<pybind11::object> 172 PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 173 // Make sure dialect module is loaded. 174 auto split = operationName.split('.'); 175 llvm::StringRef dialectNamespace = split.first; 176 if (!loadDialectModule(dialectNamespace)) 177 return std::nullopt; 178 179 auto foundIt = operationClassMap.find(operationName); 180 if (foundIt != operationClassMap.end()) { 181 assert(foundIt->second && "OpView is defined"); 182 return foundIt->second; 183 } 184 // Not found and loading did not yield a registration. 185 return std::nullopt; 186 } 187