xref: /llvm-project/mlir/lib/Bindings/Python/Globals.h (revision f136c800b60dbfacdbb645e7e92acba52e2f279f)
1013b9322SStella Laurenzo //===- Globals.h - MLIR Python extension globals --------------------------===//
2013b9322SStella Laurenzo //
3013b9322SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4013b9322SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5013b9322SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6013b9322SStella Laurenzo //
7013b9322SStella Laurenzo //===----------------------------------------------------------------------===//
8013b9322SStella Laurenzo 
9013b9322SStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
10013b9322SStella Laurenzo #define MLIR_BINDINGS_PYTHON_GLOBALS_H
11013b9322SStella Laurenzo 
12b56d1ec6SPeter Hawkins #include <optional>
13b56d1ec6SPeter Hawkins #include <string>
14b56d1ec6SPeter Hawkins #include <vector>
15013b9322SStella Laurenzo 
16b56d1ec6SPeter Hawkins #include "NanobindUtils.h"
17bfb1ba75Smax #include "mlir-c/IR.h"
18bfb1ba75Smax #include "mlir/CAPI/Support.h"
19bfb1ba75Smax #include "llvm/ADT/DenseMap.h"
20013b9322SStella Laurenzo #include "llvm/ADT/StringRef.h"
21013b9322SStella Laurenzo #include "llvm/ADT/StringSet.h"
22013b9322SStella Laurenzo 
23013b9322SStella Laurenzo namespace mlir {
24013b9322SStella Laurenzo namespace python {
25013b9322SStella Laurenzo 
26013b9322SStella Laurenzo /// Globals that are always accessible once the extension has been initialized.
27*f136c800Svfdev /// Methods of this class are thread-safe.
28013b9322SStella Laurenzo class PyGlobals {
29013b9322SStella Laurenzo public:
30013b9322SStella Laurenzo   PyGlobals();
31013b9322SStella Laurenzo   ~PyGlobals();
32013b9322SStella Laurenzo 
33013b9322SStella Laurenzo   /// Most code should get the globals via this static accessor.
34013b9322SStella Laurenzo   static PyGlobals &get() {
35013b9322SStella Laurenzo     assert(instance && "PyGlobals is null");
36013b9322SStella Laurenzo     return *instance;
37013b9322SStella Laurenzo   }
38013b9322SStella Laurenzo 
39013b9322SStella Laurenzo   /// Get and set the list of parent modules to search for dialect
40013b9322SStella Laurenzo   /// implementation classes.
41*f136c800Svfdev   std::vector<std::string> getDialectSearchPrefixes() {
42*f136c800Svfdev     nanobind::ft_lock_guard lock(mutex);
43013b9322SStella Laurenzo     return dialectSearchPrefixes;
44013b9322SStella Laurenzo   }
45013b9322SStella Laurenzo   void setDialectSearchPrefixes(std::vector<std::string> newValues) {
46*f136c800Svfdev     nanobind::ft_lock_guard lock(mutex);
47013b9322SStella Laurenzo     dialectSearchPrefixes.swap(newValues);
48013b9322SStella Laurenzo   }
49*f136c800Svfdev   void addDialectSearchPrefix(std::string value) {
50*f136c800Svfdev     nanobind::ft_lock_guard lock(mutex);
51*f136c800Svfdev     dialectSearchPrefixes.push_back(std::move(value));
52*f136c800Svfdev   }
53013b9322SStella Laurenzo 
54013b9322SStella Laurenzo   /// Loads a python module corresponding to the given dialect namespace.
55013b9322SStella Laurenzo   /// No-ops if the module has already been loaded or is not found. Raises
56013b9322SStella Laurenzo   /// an error on any evaluation issues.
57013b9322SStella Laurenzo   /// Note that this returns void because it is expected that the module
58013b9322SStella Laurenzo   /// contains calls to decorators and helpers that register the salient
595192e299SMaksim Levental   /// entities. Returns true if dialect is successfully loaded.
605192e299SMaksim Levental   bool loadDialectModule(llvm::StringRef dialectNamespace);
61013b9322SStella Laurenzo 
62b57acb9aSJacques Pienaar   /// Adds a user-friendly Attribute builder.
6392233062Smax   /// Raises an exception if the mapping already exists and replace == false.
64b57acb9aSJacques Pienaar   /// This is intended to be called by implementation code.
65b57acb9aSJacques Pienaar   void registerAttributeBuilder(const std::string &attributeKind,
66b56d1ec6SPeter Hawkins                                 nanobind::callable pyFunc,
6792233062Smax                                 bool replace = false);
68b57acb9aSJacques Pienaar 
69bfb1ba75Smax   /// Adds a user-friendly type caster. Raises an exception if the mapping
70bfb1ba75Smax   /// already exists and replace == false. This is intended to be called by
71bfb1ba75Smax   /// implementation code.
72b56d1ec6SPeter Hawkins   void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
73bfb1ba75Smax                           bool replace = false);
74bfb1ba75Smax 
757c850867SMaksim Levental   /// Adds a user-friendly value caster. Raises an exception if the mapping
767c850867SMaksim Levental   /// already exists and replace == false. This is intended to be called by
777c850867SMaksim Levental   /// implementation code.
787c850867SMaksim Levental   void registerValueCaster(MlirTypeID mlirTypeID,
79b56d1ec6SPeter Hawkins                            nanobind::callable valueCaster,
807c850867SMaksim Levental                            bool replace = false);
817c850867SMaksim Levental 
82013b9322SStella Laurenzo   /// Adds a concrete implementation dialect class.
83013b9322SStella Laurenzo   /// Raises an exception if the mapping already exists.
84013b9322SStella Laurenzo   /// This is intended to be called by implementation code.
85013b9322SStella Laurenzo   void registerDialectImpl(const std::string &dialectNamespace,
86b56d1ec6SPeter Hawkins                            nanobind::object pyClass);
87013b9322SStella Laurenzo 
88013b9322SStella Laurenzo   /// Adds a concrete implementation operation class.
89a2288a89SMaksim Levental   /// Raises an exception if the mapping already exists and replace == false.
90013b9322SStella Laurenzo   /// This is intended to be called by implementation code.
91013b9322SStella Laurenzo   void registerOperationImpl(const std::string &operationName,
92b56d1ec6SPeter Hawkins                              nanobind::object pyClass, bool replace = false);
93013b9322SStella Laurenzo 
94b57acb9aSJacques Pienaar   /// Returns the custom Attribute builder for Attribute kind.
95b56d1ec6SPeter Hawkins   std::optional<nanobind::callable>
96b57acb9aSJacques Pienaar   lookupAttributeBuilder(const std::string &attributeKind);
97b57acb9aSJacques Pienaar 
98bfb1ba75Smax   /// Returns the custom type caster for MlirTypeID mlirTypeID.
99b56d1ec6SPeter Hawkins   std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
100bfb1ba75Smax                                                      MlirDialect dialect);
101bfb1ba75Smax 
1027c850867SMaksim Levental   /// Returns the custom value caster for MlirTypeID mlirTypeID.
103b56d1ec6SPeter Hawkins   std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
1047c850867SMaksim Levental                                                       MlirDialect dialect);
1057c850867SMaksim Levental 
106013b9322SStella Laurenzo   /// Looks up a registered dialect class by namespace. Note that this may
107013b9322SStella Laurenzo   /// trigger loading of the defining module and can arbitrarily re-enter.
108b56d1ec6SPeter Hawkins   std::optional<nanobind::object>
109013b9322SStella Laurenzo   lookupDialectClass(const std::string &dialectNamespace);
110013b9322SStella Laurenzo 
111a7f8b7cdSRahul Kayaith   /// Looks up a registered operation class (deriving from OpView) by operation
112a7f8b7cdSRahul Kayaith   /// name. Note that this may trigger a load of the dialect, which can
113a7f8b7cdSRahul Kayaith   /// arbitrarily re-enter.
114b56d1ec6SPeter Hawkins   std::optional<nanobind::object>
115a7f8b7cdSRahul Kayaith   lookupOperationClass(llvm::StringRef operationName);
1168260db75SStella Laurenzo 
117013b9322SStella Laurenzo private:
118013b9322SStella Laurenzo   static PyGlobals *instance;
119*f136c800Svfdev 
120*f136c800Svfdev   nanobind::ft_mutex mutex;
121*f136c800Svfdev 
122013b9322SStella Laurenzo   /// Module name prefixes to search under for dialect implementation modules.
123013b9322SStella Laurenzo   std::vector<std::string> dialectSearchPrefixes;
124013b9322SStella Laurenzo   /// Map of dialect namespace to external dialect class object.
125b56d1ec6SPeter Hawkins   llvm::StringMap<nanobind::object> dialectClassMap;
126013b9322SStella Laurenzo   /// Map of full operation name to external operation class object.
127b56d1ec6SPeter Hawkins   llvm::StringMap<nanobind::object> operationClassMap;
128b57acb9aSJacques Pienaar   /// Map of attribute ODS name to custom builder.
129b56d1ec6SPeter Hawkins   llvm::StringMap<nanobind::callable> attributeBuilderMap;
130bfb1ba75Smax   /// Map of MlirTypeID to custom type caster.
131b56d1ec6SPeter Hawkins   llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
1327c850867SMaksim Levental   /// Map of MlirTypeID to custom value caster.
133b56d1ec6SPeter Hawkins   llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
1348260db75SStella Laurenzo   /// Set of dialect namespaces that we have attempted to import implementation
1358260db75SStella Laurenzo   /// modules for.
1365192e299SMaksim Levental   llvm::StringSet<> loadedDialectModules;
137013b9322SStella Laurenzo };
138013b9322SStella Laurenzo 
139013b9322SStella Laurenzo } // namespace python
140013b9322SStella Laurenzo } // namespace mlir
141013b9322SStella Laurenzo 
142013b9322SStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_GLOBALS_H
143