1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <vector> 14 15 #include "mlir-c/Bindings/Python/Interop.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 // ----------------------------------------------------------------------------- 22 // PyGlobals 23 // ----------------------------------------------------------------------------- 24 25 PyGlobals *PyGlobals::instance = nullptr; 26 27 PyGlobals::PyGlobals() { 28 assert(!instance && "PyGlobals already constructed"); 29 instance = this; 30 // The default search path include {mlir.}dialects, where {mlir.} is the 31 // package prefix configured at compile time. 32 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 33 } 34 35 PyGlobals::~PyGlobals() { instance = nullptr; } 36 37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 38 if (loadedDialectModulesCache.contains(dialectNamespace)) 39 return; 40 // Since re-entrancy is possible, make a copy of the search prefixes. 41 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 42 py::object loaded; 43 for (std::string moduleName : localSearchPrefixes) { 44 moduleName.push_back('.'); 45 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 46 47 try { 48 loaded = py::module::import(moduleName.c_str()); 49 } catch (py::error_already_set &e) { 50 if (e.matches(PyExc_ModuleNotFoundError)) { 51 continue; 52 } 53 throw; 54 } 55 break; 56 } 57 58 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 59 // may have occurred, which may do anything. 60 loadedDialectModulesCache.insert(dialectNamespace); 61 } 62 63 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 64 py::function pyFunc) { 65 py::function &found = attributeBuilderMap[attributeKind]; 66 if (found) { 67 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 68 attributeKind + "' is already registered") 69 .str()); 70 } 71 found = std::move(pyFunc); 72 } 73 74 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 75 py::object pyClass) { 76 py::object &found = dialectClassMap[dialectNamespace]; 77 if (found) { 78 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 79 dialectNamespace + 80 "' is already registered."); 81 } 82 found = std::move(pyClass); 83 } 84 85 void PyGlobals::registerOperationImpl(const std::string &operationName, 86 py::object pyClass, 87 py::object rawOpViewClass) { 88 py::object &found = operationClassMap[operationName]; 89 if (found) { 90 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 91 operationName + 92 "' is already registered."); 93 } 94 found = std::move(pyClass); 95 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 96 } 97 98 std::optional<py::function> 99 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 100 // Fast match against the class map first (common case). 101 const auto foundIt = attributeBuilderMap.find(attributeKind); 102 if (foundIt != attributeBuilderMap.end()) { 103 if (foundIt->second.is_none()) 104 return std::nullopt; 105 assert(foundIt->second && "py::function is defined"); 106 return foundIt->second; 107 } 108 109 // Not found and loading did not yield a registration. Negative cache. 110 attributeBuilderMap[attributeKind] = py::none(); 111 return std::nullopt; 112 } 113 114 llvm::Optional<py::object> 115 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 116 loadDialectModule(dialectNamespace); 117 // Fast match against the class map first (common case). 118 const auto foundIt = dialectClassMap.find(dialectNamespace); 119 if (foundIt != dialectClassMap.end()) { 120 if (foundIt->second.is_none()) 121 return std::nullopt; 122 assert(foundIt->second && "py::object is defined"); 123 return foundIt->second; 124 } 125 126 // Not found and loading did not yield a registration. Negative cache. 127 dialectClassMap[dialectNamespace] = py::none(); 128 return std::nullopt; 129 } 130 131 llvm::Optional<pybind11::object> 132 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 133 { 134 auto foundIt = rawOpViewClassMapCache.find(operationName); 135 if (foundIt != rawOpViewClassMapCache.end()) { 136 if (foundIt->second.is_none()) 137 return std::nullopt; 138 assert(foundIt->second && "py::object is defined"); 139 return foundIt->second; 140 } 141 } 142 143 // Not found. Load the dialect namespace. 144 auto split = operationName.split('.'); 145 llvm::StringRef dialectNamespace = split.first; 146 loadDialectModule(dialectNamespace); 147 148 // Attempt to find from the canonical map and cache. 149 { 150 auto foundIt = rawOpViewClassMap.find(operationName); 151 if (foundIt != rawOpViewClassMap.end()) { 152 if (foundIt->second.is_none()) 153 return std::nullopt; 154 assert(foundIt->second && "py::object is defined"); 155 // Positive cache. 156 rawOpViewClassMapCache[operationName] = foundIt->second; 157 return foundIt->second; 158 } 159 // Negative cache. 160 rawOpViewClassMap[operationName] = py::none(); 161 return std::nullopt; 162 } 163 } 164 165 void PyGlobals::clearImportCache() { 166 loadedDialectModulesCache.clear(); 167 rawOpViewClassMapCache.clear(); 168 } 169