1ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===// 2ac0a70f3SAlex Zinenko // 3ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ac0a70f3SAlex Zinenko // 7ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===// 8ac0a70f3SAlex Zinenko 9ac0a70f3SAlex Zinenko #include "IRModule.h" 10ac0a70f3SAlex Zinenko 115192e299SMaksim Levental #include <optional> 125192e299SMaksim Levental #include <vector> 135192e299SMaksim Levental 14b56d1ec6SPeter Hawkins #include "Globals.h" 15b56d1ec6SPeter Hawkins #include "NanobindUtils.h" 16b56d1ec6SPeter Hawkins #include "mlir-c/Support.h" 175cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 185cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 19b56d1ec6SPeter Hawkins 20b56d1ec6SPeter Hawkins namespace nb = nanobind; 21ac0a70f3SAlex Zinenko using namespace mlir; 22ac0a70f3SAlex Zinenko using namespace mlir::python; 23ac0a70f3SAlex Zinenko 24ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 25ac0a70f3SAlex Zinenko // PyGlobals 26ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 27ac0a70f3SAlex Zinenko 28ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr; 29ac0a70f3SAlex Zinenko 30ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() { 31ac0a70f3SAlex Zinenko assert(!instance && "PyGlobals already constructed"); 32ac0a70f3SAlex Zinenko instance = this; 33cb7b0381SStella Laurenzo // The default search path include {mlir.}dialects, where {mlir.} is the 34cb7b0381SStella Laurenzo // package prefix configured at compile time. 35e5639b3fSMehdi Amini dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 36ac0a70f3SAlex Zinenko } 37ac0a70f3SAlex Zinenko 38ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; } 39ac0a70f3SAlex Zinenko 405192e299SMaksim Levental bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 41*f136c800Svfdev { 42*f136c800Svfdev nb::ft_lock_guard lock(mutex); 435192e299SMaksim Levental if (loadedDialectModules.contains(dialectNamespace)) 445192e299SMaksim Levental return true; 45*f136c800Svfdev } 46ac0a70f3SAlex Zinenko // Since re-entrancy is possible, make a copy of the search prefixes. 47ac0a70f3SAlex Zinenko std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 48b56d1ec6SPeter Hawkins nb::object loaded = nb::none(); 49ac0a70f3SAlex Zinenko for (std::string moduleName : localSearchPrefixes) { 50ac0a70f3SAlex Zinenko moduleName.push_back('.'); 51ac0a70f3SAlex Zinenko moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 52ac0a70f3SAlex Zinenko 53ac0a70f3SAlex Zinenko try { 54b56d1ec6SPeter Hawkins loaded = nb::module_::import_(moduleName.c_str()); 55b56d1ec6SPeter Hawkins } catch (nb::python_error &e) { 56ac0a70f3SAlex Zinenko if (e.matches(PyExc_ModuleNotFoundError)) { 57ac0a70f3SAlex Zinenko continue; 58ac0a70f3SAlex Zinenko } 5902b6fb21SMehdi Amini throw; 60ac0a70f3SAlex Zinenko } 61ac0a70f3SAlex Zinenko break; 62ac0a70f3SAlex Zinenko } 63ac0a70f3SAlex Zinenko 645192e299SMaksim Levental if (loaded.is_none()) 655192e299SMaksim Levental return false; 66ac0a70f3SAlex Zinenko // Note: Iterator cannot be shared from prior to loading, since re-entrancy 67ac0a70f3SAlex Zinenko // may have occurred, which may do anything. 68*f136c800Svfdev nb::ft_lock_guard lock(mutex); 695192e299SMaksim Levental loadedDialectModules.insert(dialectNamespace); 705192e299SMaksim Levental return true; 71ac0a70f3SAlex Zinenko } 72ac0a70f3SAlex Zinenko 73b57acb9aSJacques Pienaar void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, 74b56d1ec6SPeter Hawkins nb::callable pyFunc, bool replace) { 75*f136c800Svfdev nb::ft_lock_guard lock(mutex); 76b56d1ec6SPeter Hawkins nb::object &found = attributeBuilderMap[attributeKind]; 775192e299SMaksim Levental if (found && !replace) { 78b57acb9aSJacques Pienaar throw std::runtime_error((llvm::Twine("Attribute builder for '") + 7992233062Smax attributeKind + 8092233062Smax "' is already registered with func: " + 81b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::str(found))) 82b57acb9aSJacques Pienaar .str()); 83b57acb9aSJacques Pienaar } 84b57acb9aSJacques Pienaar found = std::move(pyFunc); 85b57acb9aSJacques Pienaar } 86b57acb9aSJacques Pienaar 87bfb1ba75Smax void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, 88b56d1ec6SPeter Hawkins nb::callable typeCaster, bool replace) { 89*f136c800Svfdev nb::ft_lock_guard lock(mutex); 90b56d1ec6SPeter Hawkins nb::object &found = typeCasterMap[mlirTypeID]; 915192e299SMaksim Levental if (found && !replace) 925192e299SMaksim Levental throw std::runtime_error("Type caster is already registered with caster: " + 93b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::str(found))); 94bfb1ba75Smax found = std::move(typeCaster); 95bfb1ba75Smax } 96bfb1ba75Smax 977c850867SMaksim Levental void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, 98b56d1ec6SPeter Hawkins nb::callable valueCaster, bool replace) { 99*f136c800Svfdev nb::ft_lock_guard lock(mutex); 100b56d1ec6SPeter Hawkins nb::object &found = valueCasterMap[mlirTypeID]; 1017c850867SMaksim Levental if (found && !replace) 1027c850867SMaksim Levental throw std::runtime_error("Value caster is already registered: " + 103b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::repr(found))); 1047c850867SMaksim Levental found = std::move(valueCaster); 1057c850867SMaksim Levental } 1067c850867SMaksim Levental 107ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 108b56d1ec6SPeter Hawkins nb::object pyClass) { 109*f136c800Svfdev nb::ft_lock_guard lock(mutex); 110b56d1ec6SPeter Hawkins nb::object &found = dialectClassMap[dialectNamespace]; 111ac0a70f3SAlex Zinenko if (found) { 1124811270bSmax throw std::runtime_error((llvm::Twine("Dialect namespace '") + 1134811270bSmax dialectNamespace + "' is already registered.") 1144811270bSmax .str()); 115ac0a70f3SAlex Zinenko } 116ac0a70f3SAlex Zinenko found = std::move(pyClass); 117ac0a70f3SAlex Zinenko } 118ac0a70f3SAlex Zinenko 119ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName, 120b56d1ec6SPeter Hawkins nb::object pyClass, bool replace) { 121*f136c800Svfdev nb::ft_lock_guard lock(mutex); 122b56d1ec6SPeter Hawkins nb::object &found = operationClassMap[operationName]; 123a2288a89SMaksim Levental if (found && !replace) { 1244811270bSmax throw std::runtime_error((llvm::Twine("Operation '") + operationName + 1254811270bSmax "' is already registered.") 1264811270bSmax .str()); 127ac0a70f3SAlex Zinenko } 128ac0a70f3SAlex Zinenko found = std::move(pyClass); 129ac0a70f3SAlex Zinenko } 130ac0a70f3SAlex Zinenko 131b56d1ec6SPeter Hawkins std::optional<nb::callable> 132b57acb9aSJacques Pienaar PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { 133*f136c800Svfdev nb::ft_lock_guard lock(mutex); 134b57acb9aSJacques Pienaar const auto foundIt = attributeBuilderMap.find(attributeKind); 135b57acb9aSJacques Pienaar if (foundIt != attributeBuilderMap.end()) { 1365192e299SMaksim Levental assert(foundIt->second && "attribute builder is defined"); 137b57acb9aSJacques Pienaar return foundIt->second; 138b57acb9aSJacques Pienaar } 139b57acb9aSJacques Pienaar return std::nullopt; 140b57acb9aSJacques Pienaar } 141b57acb9aSJacques Pienaar 142b56d1ec6SPeter Hawkins std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, 143bfb1ba75Smax MlirDialect dialect) { 14426dc7650SMaksim Levental // Try to load dialect module. 14526dc7650SMaksim Levental (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 146*f136c800Svfdev nb::ft_lock_guard lock(mutex); 147bfb1ba75Smax const auto foundIt = typeCasterMap.find(mlirTypeID); 148bfb1ba75Smax if (foundIt != typeCasterMap.end()) { 1495192e299SMaksim Levental assert(foundIt->second && "type caster is defined"); 150bfb1ba75Smax return foundIt->second; 151bfb1ba75Smax } 152bfb1ba75Smax return std::nullopt; 153bfb1ba75Smax } 154bfb1ba75Smax 155b56d1ec6SPeter Hawkins std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, 1567c850867SMaksim Levental MlirDialect dialect) { 15726dc7650SMaksim Levental // Try to load dialect module. 15826dc7650SMaksim Levental (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); 159*f136c800Svfdev nb::ft_lock_guard lock(mutex); 1607c850867SMaksim Levental const auto foundIt = valueCasterMap.find(mlirTypeID); 1617c850867SMaksim Levental if (foundIt != valueCasterMap.end()) { 1627c850867SMaksim Levental assert(foundIt->second && "value caster is defined"); 1637c850867SMaksim Levental return foundIt->second; 1647c850867SMaksim Levental } 1657c850867SMaksim Levental return std::nullopt; 1667c850867SMaksim Levental } 1677c850867SMaksim Levental 168b56d1ec6SPeter Hawkins std::optional<nb::object> 169ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 1705192e299SMaksim Levental // Make sure dialect module is loaded. 1715192e299SMaksim Levental if (!loadDialectModule(dialectNamespace)) 1725192e299SMaksim Levental return std::nullopt; 173*f136c800Svfdev nb::ft_lock_guard lock(mutex); 174ac0a70f3SAlex Zinenko const auto foundIt = dialectClassMap.find(dialectNamespace); 175ac0a70f3SAlex Zinenko if (foundIt != dialectClassMap.end()) { 1765192e299SMaksim Levental assert(foundIt->second && "dialect class is defined"); 177ac0a70f3SAlex Zinenko return foundIt->second; 178ac0a70f3SAlex Zinenko } 1795192e299SMaksim Levental // Not found and loading did not yield a registration. 180e823ababSKazu Hirata return std::nullopt; 181ac0a70f3SAlex Zinenko } 182ac0a70f3SAlex Zinenko 183b56d1ec6SPeter Hawkins std::optional<nb::object> 184a7f8b7cdSRahul Kayaith PyGlobals::lookupOperationClass(llvm::StringRef operationName) { 1855192e299SMaksim Levental // Make sure dialect module is loaded. 186ac0a70f3SAlex Zinenko auto split = operationName.split('.'); 187ac0a70f3SAlex Zinenko llvm::StringRef dialectNamespace = split.first; 1885192e299SMaksim Levental if (!loadDialectModule(dialectNamespace)) 1895192e299SMaksim Levental return std::nullopt; 190ac0a70f3SAlex Zinenko 191*f136c800Svfdev nb::ft_lock_guard lock(mutex); 192a7f8b7cdSRahul Kayaith auto foundIt = operationClassMap.find(operationName); 193a7f8b7cdSRahul Kayaith if (foundIt != operationClassMap.end()) { 1945192e299SMaksim Levental assert(foundIt->second && "OpView is defined"); 195ac0a70f3SAlex Zinenko return foundIt->second; 19602b6fb21SMehdi Amini } 1975192e299SMaksim Levental // Not found and loading did not yield a registration. 198e823ababSKazu Hirata return std::nullopt; 199ac0a70f3SAlex Zinenko } 200