xref: /llvm-project/mlir/lib/Bindings/Python/IRModule.cpp (revision f136c800b60dbfacdbb645e7e92acba52e2f279f)
1ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===//
2ac0a70f3SAlex Zinenko //
3ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ac0a70f3SAlex Zinenko //
7ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===//
8ac0a70f3SAlex Zinenko 
9ac0a70f3SAlex Zinenko #include "IRModule.h"
10ac0a70f3SAlex Zinenko 
115192e299SMaksim Levental #include <optional>
125192e299SMaksim Levental #include <vector>
135192e299SMaksim Levental 
14b56d1ec6SPeter Hawkins #include "Globals.h"
15b56d1ec6SPeter Hawkins #include "NanobindUtils.h"
16b56d1ec6SPeter Hawkins #include "mlir-c/Support.h"
175cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
185cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
19b56d1ec6SPeter Hawkins 
20b56d1ec6SPeter Hawkins namespace nb = nanobind;
21ac0a70f3SAlex Zinenko using namespace mlir;
22ac0a70f3SAlex Zinenko using namespace mlir::python;
23ac0a70f3SAlex Zinenko 
24ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
25ac0a70f3SAlex Zinenko // PyGlobals
26ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
27ac0a70f3SAlex Zinenko 
28ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr;
29ac0a70f3SAlex Zinenko 
30ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() {
31ac0a70f3SAlex Zinenko   assert(!instance && "PyGlobals already constructed");
32ac0a70f3SAlex Zinenko   instance = this;
33cb7b0381SStella Laurenzo   // The default search path include {mlir.}dialects, where {mlir.} is the
34cb7b0381SStella Laurenzo   // package prefix configured at compile time.
35e5639b3fSMehdi Amini   dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
36ac0a70f3SAlex Zinenko }
37ac0a70f3SAlex Zinenko 
38ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; }
39ac0a70f3SAlex Zinenko 
405192e299SMaksim Levental bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41*f136c800Svfdev   {
42*f136c800Svfdev     nb::ft_lock_guard lock(mutex);
435192e299SMaksim Levental     if (loadedDialectModules.contains(dialectNamespace))
445192e299SMaksim Levental       return true;
45*f136c800Svfdev   }
46ac0a70f3SAlex Zinenko   // Since re-entrancy is possible, make a copy of the search prefixes.
47ac0a70f3SAlex Zinenko   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
48b56d1ec6SPeter Hawkins   nb::object loaded = nb::none();
49ac0a70f3SAlex Zinenko   for (std::string moduleName : localSearchPrefixes) {
50ac0a70f3SAlex Zinenko     moduleName.push_back('.');
51ac0a70f3SAlex Zinenko     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
52ac0a70f3SAlex Zinenko 
53ac0a70f3SAlex Zinenko     try {
54b56d1ec6SPeter Hawkins       loaded = nb::module_::import_(moduleName.c_str());
55b56d1ec6SPeter Hawkins     } catch (nb::python_error &e) {
56ac0a70f3SAlex Zinenko       if (e.matches(PyExc_ModuleNotFoundError)) {
57ac0a70f3SAlex Zinenko         continue;
58ac0a70f3SAlex Zinenko       }
5902b6fb21SMehdi Amini       throw;
60ac0a70f3SAlex Zinenko     }
61ac0a70f3SAlex Zinenko     break;
62ac0a70f3SAlex Zinenko   }
63ac0a70f3SAlex Zinenko 
645192e299SMaksim Levental   if (loaded.is_none())
655192e299SMaksim Levental     return false;
66ac0a70f3SAlex Zinenko   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
67ac0a70f3SAlex Zinenko   // may have occurred, which may do anything.
68*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
695192e299SMaksim Levental   loadedDialectModules.insert(dialectNamespace);
705192e299SMaksim Levental   return true;
71ac0a70f3SAlex Zinenko }
72ac0a70f3SAlex Zinenko 
73b57acb9aSJacques Pienaar void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
74b56d1ec6SPeter Hawkins                                          nb::callable pyFunc, bool replace) {
75*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
76b56d1ec6SPeter Hawkins   nb::object &found = attributeBuilderMap[attributeKind];
775192e299SMaksim Levental   if (found && !replace) {
78b57acb9aSJacques Pienaar     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
7992233062Smax                               attributeKind +
8092233062Smax                               "' is already registered with func: " +
81b56d1ec6SPeter Hawkins                               nb::cast<std::string>(nb::str(found)))
82b57acb9aSJacques Pienaar                                  .str());
83b57acb9aSJacques Pienaar   }
84b57acb9aSJacques Pienaar   found = std::move(pyFunc);
85b57acb9aSJacques Pienaar }
86b57acb9aSJacques Pienaar 
87bfb1ba75Smax void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
88b56d1ec6SPeter Hawkins                                    nb::callable typeCaster, bool replace) {
89*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
90b56d1ec6SPeter Hawkins   nb::object &found = typeCasterMap[mlirTypeID];
915192e299SMaksim Levental   if (found && !replace)
925192e299SMaksim Levental     throw std::runtime_error("Type caster is already registered with caster: " +
93b56d1ec6SPeter Hawkins                              nb::cast<std::string>(nb::str(found)));
94bfb1ba75Smax   found = std::move(typeCaster);
95bfb1ba75Smax }
96bfb1ba75Smax 
977c850867SMaksim Levental void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
98b56d1ec6SPeter Hawkins                                     nb::callable valueCaster, bool replace) {
99*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
100b56d1ec6SPeter Hawkins   nb::object &found = valueCasterMap[mlirTypeID];
1017c850867SMaksim Levental   if (found && !replace)
1027c850867SMaksim Levental     throw std::runtime_error("Value caster is already registered: " +
103b56d1ec6SPeter Hawkins                              nb::cast<std::string>(nb::repr(found)));
1047c850867SMaksim Levental   found = std::move(valueCaster);
1057c850867SMaksim Levental }
1067c850867SMaksim Levental 
107ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
108b56d1ec6SPeter Hawkins                                     nb::object pyClass) {
109*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
110b56d1ec6SPeter Hawkins   nb::object &found = dialectClassMap[dialectNamespace];
111ac0a70f3SAlex Zinenko   if (found) {
1124811270bSmax     throw std::runtime_error((llvm::Twine("Dialect namespace '") +
1134811270bSmax                               dialectNamespace + "' is already registered.")
1144811270bSmax                                  .str());
115ac0a70f3SAlex Zinenko   }
116ac0a70f3SAlex Zinenko   found = std::move(pyClass);
117ac0a70f3SAlex Zinenko }
118ac0a70f3SAlex Zinenko 
119ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName,
120b56d1ec6SPeter Hawkins                                       nb::object pyClass, bool replace) {
121*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
122b56d1ec6SPeter Hawkins   nb::object &found = operationClassMap[operationName];
123a2288a89SMaksim Levental   if (found && !replace) {
1244811270bSmax     throw std::runtime_error((llvm::Twine("Operation '") + operationName +
1254811270bSmax                               "' is already registered.")
1264811270bSmax                                  .str());
127ac0a70f3SAlex Zinenko   }
128ac0a70f3SAlex Zinenko   found = std::move(pyClass);
129ac0a70f3SAlex Zinenko }
130ac0a70f3SAlex Zinenko 
131b56d1ec6SPeter Hawkins std::optional<nb::callable>
132b57acb9aSJacques Pienaar PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
134b57acb9aSJacques Pienaar   const auto foundIt = attributeBuilderMap.find(attributeKind);
135b57acb9aSJacques Pienaar   if (foundIt != attributeBuilderMap.end()) {
1365192e299SMaksim Levental     assert(foundIt->second && "attribute builder is defined");
137b57acb9aSJacques Pienaar     return foundIt->second;
138b57acb9aSJacques Pienaar   }
139b57acb9aSJacques Pienaar   return std::nullopt;
140b57acb9aSJacques Pienaar }
141b57acb9aSJacques Pienaar 
142b56d1ec6SPeter Hawkins std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
143bfb1ba75Smax                                                         MlirDialect dialect) {
14426dc7650SMaksim Levental   // Try to load dialect module.
14526dc7650SMaksim Levental   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
146*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
147bfb1ba75Smax   const auto foundIt = typeCasterMap.find(mlirTypeID);
148bfb1ba75Smax   if (foundIt != typeCasterMap.end()) {
1495192e299SMaksim Levental     assert(foundIt->second && "type caster is defined");
150bfb1ba75Smax     return foundIt->second;
151bfb1ba75Smax   }
152bfb1ba75Smax   return std::nullopt;
153bfb1ba75Smax }
154bfb1ba75Smax 
155b56d1ec6SPeter Hawkins std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
1567c850867SMaksim Levental                                                          MlirDialect dialect) {
15726dc7650SMaksim Levental   // Try to load dialect module.
15826dc7650SMaksim Levental   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
159*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
1607c850867SMaksim Levental   const auto foundIt = valueCasterMap.find(mlirTypeID);
1617c850867SMaksim Levental   if (foundIt != valueCasterMap.end()) {
1627c850867SMaksim Levental     assert(foundIt->second && "value caster is defined");
1637c850867SMaksim Levental     return foundIt->second;
1647c850867SMaksim Levental   }
1657c850867SMaksim Levental   return std::nullopt;
1667c850867SMaksim Levental }
1677c850867SMaksim Levental 
168b56d1ec6SPeter Hawkins std::optional<nb::object>
169ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
1705192e299SMaksim Levental   // Make sure dialect module is loaded.
1715192e299SMaksim Levental   if (!loadDialectModule(dialectNamespace))
1725192e299SMaksim Levental     return std::nullopt;
173*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
174ac0a70f3SAlex Zinenko   const auto foundIt = dialectClassMap.find(dialectNamespace);
175ac0a70f3SAlex Zinenko   if (foundIt != dialectClassMap.end()) {
1765192e299SMaksim Levental     assert(foundIt->second && "dialect class is defined");
177ac0a70f3SAlex Zinenko     return foundIt->second;
178ac0a70f3SAlex Zinenko   }
1795192e299SMaksim Levental   // Not found and loading did not yield a registration.
180e823ababSKazu Hirata   return std::nullopt;
181ac0a70f3SAlex Zinenko }
182ac0a70f3SAlex Zinenko 
183b56d1ec6SPeter Hawkins std::optional<nb::object>
184a7f8b7cdSRahul Kayaith PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
1855192e299SMaksim Levental   // Make sure dialect module is loaded.
186ac0a70f3SAlex Zinenko   auto split = operationName.split('.');
187ac0a70f3SAlex Zinenko   llvm::StringRef dialectNamespace = split.first;
1885192e299SMaksim Levental   if (!loadDialectModule(dialectNamespace))
1895192e299SMaksim Levental     return std::nullopt;
190ac0a70f3SAlex Zinenko 
191*f136c800Svfdev   nb::ft_lock_guard lock(mutex);
192a7f8b7cdSRahul Kayaith   auto foundIt = operationClassMap.find(operationName);
193a7f8b7cdSRahul Kayaith   if (foundIt != operationClassMap.end()) {
1945192e299SMaksim Levental     assert(foundIt->second && "OpView is defined");
195ac0a70f3SAlex Zinenko     return foundIt->second;
19602b6fb21SMehdi Amini   }
1975192e299SMaksim Levental   // Not found and loading did not yield a registration.
198e823ababSKazu Hirata   return std::nullopt;
199ac0a70f3SAlex Zinenko }
200