1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <vector> 14 #include <optional> 15 16 #include "mlir-c/Bindings/Python/Interop.h" 17 18 namespace py = pybind11; 19 using namespace mlir; 20 using namespace mlir::python; 21 22 // ----------------------------------------------------------------------------- 23 // PyGlobals 24 // ----------------------------------------------------------------------------- 25 26 PyGlobals *PyGlobals::instance = nullptr; 27 28 PyGlobals::PyGlobals() { 29 assert(!instance && "PyGlobals already constructed"); 30 instance = this; 31 // The default search path include {mlir.}dialects, where {mlir.} is the 32 // package prefix configured at compile time. 33 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 34 } 35 36 PyGlobals::~PyGlobals() { instance = nullptr; } 37 38 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 39 if (loadedDialectModulesCache.contains(dialectNamespace)) 40 return; 41 // Since re-entrancy is possible, make a copy of the search prefixes. 42 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 43 py::object loaded; 44 for (std::string moduleName : localSearchPrefixes) { 45 moduleName.push_back('.'); 46 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 47 48 try { 49 loaded = py::module::import(moduleName.c_str()); 50 } catch (py::error_already_set &e) { 51 if (e.matches(PyExc_ModuleNotFoundError)) { 52 continue; 53 } 54 throw; 55 } 56 break; 57 } 58 59 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 60 // may have occurred, which may do anything. 61 loadedDialectModulesCache.insert(dialectNamespace); 62 } 63 64 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 65 py::function pyFunc) { 66 py::object &found = attributeBuilderMap[attributeKind]; 67 if (found) { 68 throw std::runtime_error((llvm::Twine("Attribute builder for '") + 69 attributeKind + "' is already registered") 70 .str()); 71 } 72 found = std::move(pyFunc); 73 } 74 75 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 76 py::object pyClass) { 77 py::object &found = dialectClassMap[dialectNamespace]; 78 if (found) { 79 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 80 dialectNamespace + 81 "' is already registered."); 82 } 83 found = std::move(pyClass); 84 } 85 86 void PyGlobals::registerOperationImpl(const std::string &operationName, 87 py::object pyClass, 88 py::object rawOpViewClass) { 89 py::object &found = operationClassMap[operationName]; 90 if (found) { 91 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 92 operationName + 93 "' is already registered."); 94 } 95 found = std::move(pyClass); 96 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 97 } 98 99 std::optional<py::function> 100 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 101 // Fast match against the class map first (common case). 102 const auto foundIt = attributeBuilderMap.find(attributeKind); 103 if (foundIt != attributeBuilderMap.end()) { 104 if (foundIt->second.is_none()) 105 return std::nullopt; 106 assert(foundIt->second && "py::function is defined"); 107 return foundIt->second; 108 } 109 110 // Not found and loading did not yield a registration. Negative cache. 111 attributeBuilderMap[attributeKind] = py::none(); 112 return std::nullopt; 113 } 114 115 std::optional<py::object> 116 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 117 loadDialectModule(dialectNamespace); 118 // Fast match against the class map first (common case). 119 const auto foundIt = dialectClassMap.find(dialectNamespace); 120 if (foundIt != dialectClassMap.end()) { 121 if (foundIt->second.is_none()) 122 return std::nullopt; 123 assert(foundIt->second && "py::object is defined"); 124 return foundIt->second; 125 } 126 127 // Not found and loading did not yield a registration. Negative cache. 128 dialectClassMap[dialectNamespace] = py::none(); 129 return std::nullopt; 130 } 131 132 std::optional<pybind11::object> 133 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 134 { 135 auto foundIt = rawOpViewClassMapCache.find(operationName); 136 if (foundIt != rawOpViewClassMapCache.end()) { 137 if (foundIt->second.is_none()) 138 return std::nullopt; 139 assert(foundIt->second && "py::object is defined"); 140 return foundIt->second; 141 } 142 } 143 144 // Not found. Load the dialect namespace. 145 auto split = operationName.split('.'); 146 llvm::StringRef dialectNamespace = split.first; 147 loadDialectModule(dialectNamespace); 148 149 // Attempt to find from the canonical map and cache. 150 { 151 auto foundIt = rawOpViewClassMap.find(operationName); 152 if (foundIt != rawOpViewClassMap.end()) { 153 if (foundIt->second.is_none()) 154 return std::nullopt; 155 assert(foundIt->second && "py::object is defined"); 156 // Positive cache. 157 rawOpViewClassMapCache[operationName] = foundIt->second; 158 return foundIt->second; 159 } 160 // Negative cache. 161 rawOpViewClassMap[operationName] = py::none(); 162 return std::nullopt; 163 } 164 } 165 166 void PyGlobals::clearImportCache() { 167 loadedDialectModulesCache.clear(); 168 rawOpViewClassMapCache.clear(); 169 } 170