xref: /llvm-project/mlir/lib/Bindings/Python/IRModule.cpp (revision a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac)
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 #include <optional>
15 
16 #include "mlir-c/Bindings/Python/Interop.h"
17 
18 namespace py = pybind11;
19 using namespace mlir;
20 using namespace mlir::python;
21 
22 // -----------------------------------------------------------------------------
23 // PyGlobals
24 // -----------------------------------------------------------------------------
25 
26 PyGlobals *PyGlobals::instance = nullptr;
27 
28 PyGlobals::PyGlobals() {
29   assert(!instance && "PyGlobals already constructed");
30   instance = this;
31   // The default search path include {mlir.}dialects, where {mlir.} is the
32   // package prefix configured at compile time.
33   dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
34 }
35 
36 PyGlobals::~PyGlobals() { instance = nullptr; }
37 
38 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
39   if (loadedDialectModulesCache.contains(dialectNamespace))
40     return;
41   // Since re-entrancy is possible, make a copy of the search prefixes.
42   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
43   py::object loaded;
44   for (std::string moduleName : localSearchPrefixes) {
45     moduleName.push_back('.');
46     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
47 
48     try {
49       loaded = py::module::import(moduleName.c_str());
50     } catch (py::error_already_set &e) {
51       if (e.matches(PyExc_ModuleNotFoundError)) {
52         continue;
53       }
54       throw;
55     }
56     break;
57   }
58 
59   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
60   // may have occurred, which may do anything.
61   loadedDialectModulesCache.insert(dialectNamespace);
62 }
63 
64 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
65                                          py::function pyFunc) {
66   py::object &found = attributeBuilderMap[attributeKind];
67   if (found) {
68     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
69                               attributeKind + "' is already registered")
70                                  .str());
71   }
72   found = std::move(pyFunc);
73 }
74 
75 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
76                                     py::object pyClass) {
77   py::object &found = dialectClassMap[dialectNamespace];
78   if (found) {
79     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
80                                              dialectNamespace +
81                                              "' is already registered.");
82   }
83   found = std::move(pyClass);
84 }
85 
86 void PyGlobals::registerOperationImpl(const std::string &operationName,
87                                       py::object pyClass) {
88   py::object &found = operationClassMap[operationName];
89   if (found) {
90     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
91                                              operationName +
92                                              "' is already registered.");
93   }
94   found = std::move(pyClass);
95 }
96 
97 std::optional<py::function>
98 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
99   // Fast match against the class map first (common case).
100   const auto foundIt = attributeBuilderMap.find(attributeKind);
101   if (foundIt != attributeBuilderMap.end()) {
102     if (foundIt->second.is_none())
103       return std::nullopt;
104     assert(foundIt->second && "py::function is defined");
105     return foundIt->second;
106   }
107 
108   // Not found and loading did not yield a registration. Negative cache.
109   attributeBuilderMap[attributeKind] = py::none();
110   return std::nullopt;
111 }
112 
113 std::optional<py::object>
114 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
115   loadDialectModule(dialectNamespace);
116   // Fast match against the class map first (common case).
117   const auto foundIt = dialectClassMap.find(dialectNamespace);
118   if (foundIt != dialectClassMap.end()) {
119     if (foundIt->second.is_none())
120       return std::nullopt;
121     assert(foundIt->second && "py::object is defined");
122     return foundIt->second;
123   }
124 
125   // Not found and loading did not yield a registration. Negative cache.
126   dialectClassMap[dialectNamespace] = py::none();
127   return std::nullopt;
128 }
129 
130 std::optional<pybind11::object>
131 PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
132   {
133     auto foundIt = operationClassMapCache.find(operationName);
134     if (foundIt != operationClassMapCache.end()) {
135       if (foundIt->second.is_none())
136         return std::nullopt;
137       assert(foundIt->second && "py::object is defined");
138       return foundIt->second;
139     }
140   }
141 
142   // Not found. Load the dialect namespace.
143   auto split = operationName.split('.');
144   llvm::StringRef dialectNamespace = split.first;
145   loadDialectModule(dialectNamespace);
146 
147   // Attempt to find from the canonical map and cache.
148   {
149     auto foundIt = operationClassMap.find(operationName);
150     if (foundIt != operationClassMap.end()) {
151       if (foundIt->second.is_none())
152         return std::nullopt;
153       assert(foundIt->second && "py::object is defined");
154       // Positive cache.
155       operationClassMapCache[operationName] = foundIt->second;
156       return foundIt->second;
157     }
158     // Negative cache.
159     operationClassMap[operationName] = py::none();
160     return std::nullopt;
161   }
162 }
163 
164 void PyGlobals::clearImportCache() {
165   loadedDialectModulesCache.clear();
166   operationClassMapCache.clear();
167 }
168