xref: /llvm-project/mlir/lib/Bindings/Python/IRModule.cpp (revision b57acb9a405c289069345a498ebfc1d1b9b110de)
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 
15 #include "mlir-c/Bindings/Python/Interop.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24 
25 PyGlobals *PyGlobals::instance = nullptr;
26 
27 PyGlobals::PyGlobals() {
28   assert(!instance && "PyGlobals already constructed");
29   instance = this;
30   // The default search path include {mlir.}dialects, where {mlir.} is the
31   // package prefix configured at compile time.
32   dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
33 }
34 
35 PyGlobals::~PyGlobals() { instance = nullptr; }
36 
37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
38   if (loadedDialectModulesCache.contains(dialectNamespace))
39     return;
40   // Since re-entrancy is possible, make a copy of the search prefixes.
41   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
42   py::object loaded;
43   for (std::string moduleName : localSearchPrefixes) {
44     moduleName.push_back('.');
45     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
46 
47     try {
48       loaded = py::module::import(moduleName.c_str());
49     } catch (py::error_already_set &e) {
50       if (e.matches(PyExc_ModuleNotFoundError)) {
51         continue;
52       }
53       throw;
54     }
55     break;
56   }
57 
58   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
59   // may have occurred, which may do anything.
60   loadedDialectModulesCache.insert(dialectNamespace);
61 }
62 
63 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
64                                          py::function pyFunc) {
65   py::function &found = attributeBuilderMap[attributeKind];
66   if (found) {
67     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
68                               attributeKind + "' is already registered")
69                                  .str());
70   }
71   found = std::move(pyFunc);
72 }
73 
74 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
75                                     py::object pyClass) {
76   py::object &found = dialectClassMap[dialectNamespace];
77   if (found) {
78     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
79                                              dialectNamespace +
80                                              "' is already registered.");
81   }
82   found = std::move(pyClass);
83 }
84 
85 void PyGlobals::registerOperationImpl(const std::string &operationName,
86                                       py::object pyClass,
87                                       py::object rawOpViewClass) {
88   py::object &found = operationClassMap[operationName];
89   if (found) {
90     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
91                                              operationName +
92                                              "' is already registered.");
93   }
94   found = std::move(pyClass);
95   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
96 }
97 
98 std::optional<py::function>
99 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
100   // Fast match against the class map first (common case).
101   const auto foundIt = attributeBuilderMap.find(attributeKind);
102   if (foundIt != attributeBuilderMap.end()) {
103     if (foundIt->second.is_none())
104       return std::nullopt;
105     assert(foundIt->second && "py::function is defined");
106     return foundIt->second;
107   }
108 
109   // Not found and loading did not yield a registration. Negative cache.
110   attributeBuilderMap[attributeKind] = py::none();
111   return std::nullopt;
112 }
113 
114 llvm::Optional<py::object>
115 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
116   loadDialectModule(dialectNamespace);
117   // Fast match against the class map first (common case).
118   const auto foundIt = dialectClassMap.find(dialectNamespace);
119   if (foundIt != dialectClassMap.end()) {
120     if (foundIt->second.is_none())
121       return std::nullopt;
122     assert(foundIt->second && "py::object is defined");
123     return foundIt->second;
124   }
125 
126   // Not found and loading did not yield a registration. Negative cache.
127   dialectClassMap[dialectNamespace] = py::none();
128   return std::nullopt;
129 }
130 
131 llvm::Optional<pybind11::object>
132 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
133   {
134     auto foundIt = rawOpViewClassMapCache.find(operationName);
135     if (foundIt != rawOpViewClassMapCache.end()) {
136       if (foundIt->second.is_none())
137         return std::nullopt;
138       assert(foundIt->second && "py::object is defined");
139       return foundIt->second;
140     }
141   }
142 
143   // Not found. Load the dialect namespace.
144   auto split = operationName.split('.');
145   llvm::StringRef dialectNamespace = split.first;
146   loadDialectModule(dialectNamespace);
147 
148   // Attempt to find from the canonical map and cache.
149   {
150     auto foundIt = rawOpViewClassMap.find(operationName);
151     if (foundIt != rawOpViewClassMap.end()) {
152       if (foundIt->second.is_none())
153         return std::nullopt;
154       assert(foundIt->second && "py::object is defined");
155       // Positive cache.
156       rawOpViewClassMapCache[operationName] = foundIt->second;
157       return foundIt->second;
158     }
159     // Negative cache.
160     rawOpViewClassMap[operationName] = py::none();
161     return std::nullopt;
162   }
163 }
164 
165 void PyGlobals::clearImportCache() {
166   loadedDialectModulesCache.clear();
167   rawOpViewClassMapCache.clear();
168 }
169