//===- IRModule.cpp - IR pybind module ------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "IRModule.h" #include #include #include "Globals.h" #include "NanobindUtils.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. namespace nb = nanobind; using namespace mlir; using namespace mlir::python; // ----------------------------------------------------------------------------- // PyGlobals // ----------------------------------------------------------------------------- PyGlobals *PyGlobals::instance = nullptr; PyGlobals::PyGlobals() { assert(!instance && "PyGlobals already constructed"); instance = this; // The default search path include {mlir.}dialects, where {mlir.} is the // package prefix configured at compile time. dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); } PyGlobals::~PyGlobals() { instance = nullptr; } bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { { nb::ft_lock_guard lock(mutex); if (loadedDialectModules.contains(dialectNamespace)) return true; } // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; nb::object loaded = nb::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { loaded = nb::module_::import_(moduleName.c_str()); } catch (nb::python_error &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; } throw; } break; } if (loaded.is_none()) return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. nb::ft_lock_guard lock(mutex); loadedDialectModules.insert(dialectNamespace); return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, nb::callable pyFunc, bool replace) { nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + nb::cast(nb::str(found))) .str()); } found = std::move(pyFunc); } void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, nb::callable typeCaster, bool replace) { nb::ft_lock_guard lock(mutex); nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + nb::cast(nb::str(found))); found = std::move(typeCaster); } void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, nb::callable valueCaster, bool replace) { nb::ft_lock_guard lock(mutex); nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + nb::cast(nb::repr(found))); found = std::move(valueCaster); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, nb::object pyClass) { nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + dialectNamespace + "' is already registered.") .str()); } found = std::move(pyClass); } void PyGlobals::registerOperationImpl(const std::string &operationName, nb::object pyClass, bool replace) { nb::ft_lock_guard lock(mutex); nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") .str()); } found = std::move(pyClass); } std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { nb::ft_lock_guard lock(mutex); const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { assert(foundIt->second && "attribute builder is defined"); return foundIt->second; } return std::nullopt; } std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); nb::ft_lock_guard lock(mutex); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); return foundIt->second; } return std::nullopt; } std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); nb::ft_lock_guard lock(mutex); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); return foundIt->second; } return std::nullopt; } std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) return std::nullopt; nb::ft_lock_guard lock(mutex); const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { assert(foundIt->second && "dialect class is defined"); return foundIt->second; } // Not found and loading did not yield a registration. return std::nullopt; } std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Make sure dialect module is loaded. auto split = operationName.split('.'); llvm::StringRef dialectNamespace = split.first; if (!loadDialectModule(dialectNamespace)) return std::nullopt; nb::ft_lock_guard lock(mutex); auto foundIt = operationClassMap.find(operationName); if (foundIt != operationClassMap.end()) { assert(foundIt->second && "OpView is defined"); return foundIt->second; } // Not found and loading did not yield a registration. return std::nullopt; }