xref: /llvm-project/mlir/lib/Bindings/Python/IRModule.cpp (revision 0a81ace0047a2de93e71c82cdf0977fc989660df)
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 #include <optional>
15 
16 #include "mlir-c/Bindings/Python/Interop.h"
17 
18 namespace py = pybind11;
19 using namespace mlir;
20 using namespace mlir::python;
21 
22 // -----------------------------------------------------------------------------
23 // PyGlobals
24 // -----------------------------------------------------------------------------
25 
26 PyGlobals *PyGlobals::instance = nullptr;
27 
28 PyGlobals::PyGlobals() {
29   assert(!instance && "PyGlobals already constructed");
30   instance = this;
31   // The default search path include {mlir.}dialects, where {mlir.} is the
32   // package prefix configured at compile time.
33   dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
34 }
35 
36 PyGlobals::~PyGlobals() { instance = nullptr; }
37 
38 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
39   if (loadedDialectModulesCache.contains(dialectNamespace))
40     return;
41   // Since re-entrancy is possible, make a copy of the search prefixes.
42   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
43   py::object loaded;
44   for (std::string moduleName : localSearchPrefixes) {
45     moduleName.push_back('.');
46     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
47 
48     try {
49       loaded = py::module::import(moduleName.c_str());
50     } catch (py::error_already_set &e) {
51       if (e.matches(PyExc_ModuleNotFoundError)) {
52         continue;
53       }
54       throw;
55     }
56     break;
57   }
58 
59   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
60   // may have occurred, which may do anything.
61   loadedDialectModulesCache.insert(dialectNamespace);
62 }
63 
64 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
65                                          py::function pyFunc) {
66   py::object &found = attributeBuilderMap[attributeKind];
67   if (found) {
68     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
69                               attributeKind + "' is already registered")
70                                  .str());
71   }
72   found = std::move(pyFunc);
73 }
74 
75 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
76                                     py::object pyClass) {
77   py::object &found = dialectClassMap[dialectNamespace];
78   if (found) {
79     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
80                                              dialectNamespace +
81                                              "' is already registered.");
82   }
83   found = std::move(pyClass);
84 }
85 
86 void PyGlobals::registerOperationImpl(const std::string &operationName,
87                                       py::object pyClass,
88                                       py::object rawOpViewClass) {
89   py::object &found = operationClassMap[operationName];
90   if (found) {
91     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
92                                              operationName +
93                                              "' is already registered.");
94   }
95   found = std::move(pyClass);
96   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
97 }
98 
99 std::optional<py::function>
100 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
101   // Fast match against the class map first (common case).
102   const auto foundIt = attributeBuilderMap.find(attributeKind);
103   if (foundIt != attributeBuilderMap.end()) {
104     if (foundIt->second.is_none())
105       return std::nullopt;
106     assert(foundIt->second && "py::function is defined");
107     return foundIt->second;
108   }
109 
110   // Not found and loading did not yield a registration. Negative cache.
111   attributeBuilderMap[attributeKind] = py::none();
112   return std::nullopt;
113 }
114 
115 std::optional<py::object>
116 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
117   loadDialectModule(dialectNamespace);
118   // Fast match against the class map first (common case).
119   const auto foundIt = dialectClassMap.find(dialectNamespace);
120   if (foundIt != dialectClassMap.end()) {
121     if (foundIt->second.is_none())
122       return std::nullopt;
123     assert(foundIt->second && "py::object is defined");
124     return foundIt->second;
125   }
126 
127   // Not found and loading did not yield a registration. Negative cache.
128   dialectClassMap[dialectNamespace] = py::none();
129   return std::nullopt;
130 }
131 
132 std::optional<pybind11::object>
133 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
134   {
135     auto foundIt = rawOpViewClassMapCache.find(operationName);
136     if (foundIt != rawOpViewClassMapCache.end()) {
137       if (foundIt->second.is_none())
138         return std::nullopt;
139       assert(foundIt->second && "py::object is defined");
140       return foundIt->second;
141     }
142   }
143 
144   // Not found. Load the dialect namespace.
145   auto split = operationName.split('.');
146   llvm::StringRef dialectNamespace = split.first;
147   loadDialectModule(dialectNamespace);
148 
149   // Attempt to find from the canonical map and cache.
150   {
151     auto foundIt = rawOpViewClassMap.find(operationName);
152     if (foundIt != rawOpViewClassMap.end()) {
153       if (foundIt->second.is_none())
154         return std::nullopt;
155       assert(foundIt->second && "py::object is defined");
156       // Positive cache.
157       rawOpViewClassMapCache[operationName] = foundIt->second;
158       return foundIt->second;
159     }
160     // Negative cache.
161     rawOpViewClassMap[operationName] = py::none();
162     return std::nullopt;
163   }
164 }
165 
166 void PyGlobals::clearImportCache() {
167   loadedDialectModulesCache.clear();
168   rawOpViewClassMapCache.clear();
169 }
170