1013b9322SStella Laurenzo //===- Globals.h - MLIR Python extension globals --------------------------===// 2013b9322SStella Laurenzo // 3013b9322SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4013b9322SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5013b9322SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6013b9322SStella Laurenzo // 7013b9322SStella Laurenzo //===----------------------------------------------------------------------===// 8013b9322SStella Laurenzo 9013b9322SStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H 10013b9322SStella Laurenzo #define MLIR_BINDINGS_PYTHON_GLOBALS_H 11013b9322SStella Laurenzo 12b56d1ec6SPeter Hawkins #include <optional> 13b56d1ec6SPeter Hawkins #include <string> 14b56d1ec6SPeter Hawkins #include <vector> 15013b9322SStella Laurenzo 16b56d1ec6SPeter Hawkins #include "NanobindUtils.h" 17bfb1ba75Smax #include "mlir-c/IR.h" 18bfb1ba75Smax #include "mlir/CAPI/Support.h" 19bfb1ba75Smax #include "llvm/ADT/DenseMap.h" 20013b9322SStella Laurenzo #include "llvm/ADT/StringRef.h" 21013b9322SStella Laurenzo #include "llvm/ADT/StringSet.h" 22013b9322SStella Laurenzo 23013b9322SStella Laurenzo namespace mlir { 24013b9322SStella Laurenzo namespace python { 25013b9322SStella Laurenzo 26013b9322SStella Laurenzo /// Globals that are always accessible once the extension has been initialized. 27*f136c800Svfdev /// Methods of this class are thread-safe. 28013b9322SStella Laurenzo class PyGlobals { 29013b9322SStella Laurenzo public: 30013b9322SStella Laurenzo PyGlobals(); 31013b9322SStella Laurenzo ~PyGlobals(); 32013b9322SStella Laurenzo 33013b9322SStella Laurenzo /// Most code should get the globals via this static accessor. 34013b9322SStella Laurenzo static PyGlobals &get() { 35013b9322SStella Laurenzo assert(instance && "PyGlobals is null"); 36013b9322SStella Laurenzo return *instance; 37013b9322SStella Laurenzo } 38013b9322SStella Laurenzo 39013b9322SStella Laurenzo /// Get and set the list of parent modules to search for dialect 40013b9322SStella Laurenzo /// implementation classes. 41*f136c800Svfdev std::vector<std::string> getDialectSearchPrefixes() { 42*f136c800Svfdev nanobind::ft_lock_guard lock(mutex); 43013b9322SStella Laurenzo return dialectSearchPrefixes; 44013b9322SStella Laurenzo } 45013b9322SStella Laurenzo void setDialectSearchPrefixes(std::vector<std::string> newValues) { 46*f136c800Svfdev nanobind::ft_lock_guard lock(mutex); 47013b9322SStella Laurenzo dialectSearchPrefixes.swap(newValues); 48013b9322SStella Laurenzo } 49*f136c800Svfdev void addDialectSearchPrefix(std::string value) { 50*f136c800Svfdev nanobind::ft_lock_guard lock(mutex); 51*f136c800Svfdev dialectSearchPrefixes.push_back(std::move(value)); 52*f136c800Svfdev } 53013b9322SStella Laurenzo 54013b9322SStella Laurenzo /// Loads a python module corresponding to the given dialect namespace. 55013b9322SStella Laurenzo /// No-ops if the module has already been loaded or is not found. Raises 56013b9322SStella Laurenzo /// an error on any evaluation issues. 57013b9322SStella Laurenzo /// Note that this returns void because it is expected that the module 58013b9322SStella Laurenzo /// contains calls to decorators and helpers that register the salient 595192e299SMaksim Levental /// entities. Returns true if dialect is successfully loaded. 605192e299SMaksim Levental bool loadDialectModule(llvm::StringRef dialectNamespace); 61013b9322SStella Laurenzo 62b57acb9aSJacques Pienaar /// Adds a user-friendly Attribute builder. 6392233062Smax /// Raises an exception if the mapping already exists and replace == false. 64b57acb9aSJacques Pienaar /// This is intended to be called by implementation code. 65b57acb9aSJacques Pienaar void registerAttributeBuilder(const std::string &attributeKind, 66b56d1ec6SPeter Hawkins nanobind::callable pyFunc, 6792233062Smax bool replace = false); 68b57acb9aSJacques Pienaar 69bfb1ba75Smax /// Adds a user-friendly type caster. Raises an exception if the mapping 70bfb1ba75Smax /// already exists and replace == false. This is intended to be called by 71bfb1ba75Smax /// implementation code. 72b56d1ec6SPeter Hawkins void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, 73bfb1ba75Smax bool replace = false); 74bfb1ba75Smax 757c850867SMaksim Levental /// Adds a user-friendly value caster. Raises an exception if the mapping 767c850867SMaksim Levental /// already exists and replace == false. This is intended to be called by 777c850867SMaksim Levental /// implementation code. 787c850867SMaksim Levental void registerValueCaster(MlirTypeID mlirTypeID, 79b56d1ec6SPeter Hawkins nanobind::callable valueCaster, 807c850867SMaksim Levental bool replace = false); 817c850867SMaksim Levental 82013b9322SStella Laurenzo /// Adds a concrete implementation dialect class. 83013b9322SStella Laurenzo /// Raises an exception if the mapping already exists. 84013b9322SStella Laurenzo /// This is intended to be called by implementation code. 85013b9322SStella Laurenzo void registerDialectImpl(const std::string &dialectNamespace, 86b56d1ec6SPeter Hawkins nanobind::object pyClass); 87013b9322SStella Laurenzo 88013b9322SStella Laurenzo /// Adds a concrete implementation operation class. 89a2288a89SMaksim Levental /// Raises an exception if the mapping already exists and replace == false. 90013b9322SStella Laurenzo /// This is intended to be called by implementation code. 91013b9322SStella Laurenzo void registerOperationImpl(const std::string &operationName, 92b56d1ec6SPeter Hawkins nanobind::object pyClass, bool replace = false); 93013b9322SStella Laurenzo 94b57acb9aSJacques Pienaar /// Returns the custom Attribute builder for Attribute kind. 95b56d1ec6SPeter Hawkins std::optional<nanobind::callable> 96b57acb9aSJacques Pienaar lookupAttributeBuilder(const std::string &attributeKind); 97b57acb9aSJacques Pienaar 98bfb1ba75Smax /// Returns the custom type caster for MlirTypeID mlirTypeID. 99b56d1ec6SPeter Hawkins std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID, 100bfb1ba75Smax MlirDialect dialect); 101bfb1ba75Smax 1027c850867SMaksim Levental /// Returns the custom value caster for MlirTypeID mlirTypeID. 103b56d1ec6SPeter Hawkins std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID, 1047c850867SMaksim Levental MlirDialect dialect); 1057c850867SMaksim Levental 106013b9322SStella Laurenzo /// Looks up a registered dialect class by namespace. Note that this may 107013b9322SStella Laurenzo /// trigger loading of the defining module and can arbitrarily re-enter. 108b56d1ec6SPeter Hawkins std::optional<nanobind::object> 109013b9322SStella Laurenzo lookupDialectClass(const std::string &dialectNamespace); 110013b9322SStella Laurenzo 111a7f8b7cdSRahul Kayaith /// Looks up a registered operation class (deriving from OpView) by operation 112a7f8b7cdSRahul Kayaith /// name. Note that this may trigger a load of the dialect, which can 113a7f8b7cdSRahul Kayaith /// arbitrarily re-enter. 114b56d1ec6SPeter Hawkins std::optional<nanobind::object> 115a7f8b7cdSRahul Kayaith lookupOperationClass(llvm::StringRef operationName); 1168260db75SStella Laurenzo 117013b9322SStella Laurenzo private: 118013b9322SStella Laurenzo static PyGlobals *instance; 119*f136c800Svfdev 120*f136c800Svfdev nanobind::ft_mutex mutex; 121*f136c800Svfdev 122013b9322SStella Laurenzo /// Module name prefixes to search under for dialect implementation modules. 123013b9322SStella Laurenzo std::vector<std::string> dialectSearchPrefixes; 124013b9322SStella Laurenzo /// Map of dialect namespace to external dialect class object. 125b56d1ec6SPeter Hawkins llvm::StringMap<nanobind::object> dialectClassMap; 126013b9322SStella Laurenzo /// Map of full operation name to external operation class object. 127b56d1ec6SPeter Hawkins llvm::StringMap<nanobind::object> operationClassMap; 128b57acb9aSJacques Pienaar /// Map of attribute ODS name to custom builder. 129b56d1ec6SPeter Hawkins llvm::StringMap<nanobind::callable> attributeBuilderMap; 130bfb1ba75Smax /// Map of MlirTypeID to custom type caster. 131b56d1ec6SPeter Hawkins llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap; 1327c850867SMaksim Levental /// Map of MlirTypeID to custom value caster. 133b56d1ec6SPeter Hawkins llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap; 1348260db75SStella Laurenzo /// Set of dialect namespaces that we have attempted to import implementation 1358260db75SStella Laurenzo /// modules for. 1365192e299SMaksim Levental llvm::StringSet<> loadedDialectModules; 137013b9322SStella Laurenzo }; 138013b9322SStella Laurenzo 139013b9322SStella Laurenzo } // namespace python 140013b9322SStella Laurenzo } // namespace mlir 141013b9322SStella Laurenzo 142013b9322SStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_GLOBALS_H 143