1 //===- Globals.h - MLIR Python extension globals --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H 10 #define MLIR_BINDINGS_PYTHON_GLOBALS_H 11 12 #include <optional> 13 #include <string> 14 #include <vector> 15 16 #include "NanobindUtils.h" 17 #include "mlir-c/IR.h" 18 #include "mlir/CAPI/Support.h" 19 #include "llvm/ADT/DenseMap.h" 20 #include "llvm/ADT/StringRef.h" 21 #include "llvm/ADT/StringSet.h" 22 23 namespace mlir { 24 namespace python { 25 26 /// Globals that are always accessible once the extension has been initialized. 27 /// Methods of this class are thread-safe. 28 class PyGlobals { 29 public: 30 PyGlobals(); 31 ~PyGlobals(); 32 33 /// Most code should get the globals via this static accessor. 34 static PyGlobals &get() { 35 assert(instance && "PyGlobals is null"); 36 return *instance; 37 } 38 39 /// Get and set the list of parent modules to search for dialect 40 /// implementation classes. 41 std::vector<std::string> getDialectSearchPrefixes() { 42 nanobind::ft_lock_guard lock(mutex); 43 return dialectSearchPrefixes; 44 } 45 void setDialectSearchPrefixes(std::vector<std::string> newValues) { 46 nanobind::ft_lock_guard lock(mutex); 47 dialectSearchPrefixes.swap(newValues); 48 } 49 void addDialectSearchPrefix(std::string value) { 50 nanobind::ft_lock_guard lock(mutex); 51 dialectSearchPrefixes.push_back(std::move(value)); 52 } 53 54 /// Loads a python module corresponding to the given dialect namespace. 55 /// No-ops if the module has already been loaded or is not found. Raises 56 /// an error on any evaluation issues. 57 /// Note that this returns void because it is expected that the module 58 /// contains calls to decorators and helpers that register the salient 59 /// entities. Returns true if dialect is successfully loaded. 60 bool loadDialectModule(llvm::StringRef dialectNamespace); 61 62 /// Adds a user-friendly Attribute builder. 63 /// Raises an exception if the mapping already exists and replace == false. 64 /// This is intended to be called by implementation code. 65 void registerAttributeBuilder(const std::string &attributeKind, 66 nanobind::callable pyFunc, 67 bool replace = false); 68 69 /// Adds a user-friendly type caster. Raises an exception if the mapping 70 /// already exists and replace == false. This is intended to be called by 71 /// implementation code. 72 void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, 73 bool replace = false); 74 75 /// Adds a user-friendly value caster. Raises an exception if the mapping 76 /// already exists and replace == false. This is intended to be called by 77 /// implementation code. 78 void registerValueCaster(MlirTypeID mlirTypeID, 79 nanobind::callable valueCaster, 80 bool replace = false); 81 82 /// Adds a concrete implementation dialect class. 83 /// Raises an exception if the mapping already exists. 84 /// This is intended to be called by implementation code. 85 void registerDialectImpl(const std::string &dialectNamespace, 86 nanobind::object pyClass); 87 88 /// Adds a concrete implementation operation class. 89 /// Raises an exception if the mapping already exists and replace == false. 90 /// This is intended to be called by implementation code. 91 void registerOperationImpl(const std::string &operationName, 92 nanobind::object pyClass, bool replace = false); 93 94 /// Returns the custom Attribute builder for Attribute kind. 95 std::optional<nanobind::callable> 96 lookupAttributeBuilder(const std::string &attributeKind); 97 98 /// Returns the custom type caster for MlirTypeID mlirTypeID. 99 std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID, 100 MlirDialect dialect); 101 102 /// Returns the custom value caster for MlirTypeID mlirTypeID. 103 std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID, 104 MlirDialect dialect); 105 106 /// Looks up a registered dialect class by namespace. Note that this may 107 /// trigger loading of the defining module and can arbitrarily re-enter. 108 std::optional<nanobind::object> 109 lookupDialectClass(const std::string &dialectNamespace); 110 111 /// Looks up a registered operation class (deriving from OpView) by operation 112 /// name. Note that this may trigger a load of the dialect, which can 113 /// arbitrarily re-enter. 114 std::optional<nanobind::object> 115 lookupOperationClass(llvm::StringRef operationName); 116 117 private: 118 static PyGlobals *instance; 119 120 nanobind::ft_mutex mutex; 121 122 /// Module name prefixes to search under for dialect implementation modules. 123 std::vector<std::string> dialectSearchPrefixes; 124 /// Map of dialect namespace to external dialect class object. 125 llvm::StringMap<nanobind::object> dialectClassMap; 126 /// Map of full operation name to external operation class object. 127 llvm::StringMap<nanobind::object> operationClassMap; 128 /// Map of attribute ODS name to custom builder. 129 llvm::StringMap<nanobind::callable> attributeBuilderMap; 130 /// Map of MlirTypeID to custom type caster. 131 llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap; 132 /// Map of MlirTypeID to custom value caster. 133 llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap; 134 /// Set of dialect namespaces that we have attempted to import implementation 135 /// modules for. 136 llvm::StringSet<> loadedDialectModules; 137 }; 138 139 } // namespace python 140 } // namespace mlir 141 142 #endif // MLIR_BINDINGS_PYTHON_GLOBALS_H 143