xref: /llvm-project/mlir/include/mlir/IR/AttributeSupport.h (revision 3dbac2c007c114a720300d2a4d79abe9ca1351e7)
1 //===- AttributeSupport.h ---------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines support types for registering dialect extended attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_ATTRIBUTESUPPORT_H
14 #define MLIR_IR_ATTRIBUTESUPPORT_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/StorageUniquerSupport.h"
18 #include "mlir/IR/Types.h"
19 #include "llvm/ADT/PointerIntPair.h"
20 #include "llvm/ADT/Twine.h"
21 
22 namespace mlir {
23 //===----------------------------------------------------------------------===//
24 // AbstractAttribute
25 //===----------------------------------------------------------------------===//
26 
27 /// This class contains all of the static information common to all instances of
28 /// a registered Attribute.
29 class AbstractAttribute {
30 public:
31   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
32   using WalkImmediateSubElementsFn = function_ref<void(
33       Attribute, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
34   using ReplaceImmediateSubElementsFn =
35       function_ref<Attribute(Attribute, ArrayRef<Attribute>, ArrayRef<Type>)>;
36 
37   /// Look up the specified abstract attribute in the MLIRContext and return a
38   /// reference to it.
39   static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context);
40 
41   /// Look up the specified abstract attribute in the MLIRContext and return a
42   /// reference to it if it exists.
43   static std::optional<std::reference_wrapper<const AbstractAttribute>>
44   lookup(StringRef name, MLIRContext *context);
45 
46   /// This method is used by Dialect objects when they register the list of
47   /// attributes they contain.
48   template <typename T>
get(Dialect & dialect)49   static AbstractAttribute get(Dialect &dialect) {
50     return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
51                              T::getWalkImmediateSubElementsFn(),
52                              T::getReplaceImmediateSubElementsFn(),
53                              T::getTypeID(), T::name);
54   }
55 
56   /// This method is used by Dialect objects to register attributes with
57   /// custom TypeIDs.
58   /// The use of this method is in general discouraged in favor of
59   /// 'get<CustomAttribute>(dialect)'.
60   static AbstractAttribute
get(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)61   get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
62       HasTraitFn &&hasTrait,
63       WalkImmediateSubElementsFn walkImmediateSubElementsFn,
64       ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
65       TypeID typeID, StringRef name) {
66     return AbstractAttribute(dialect, std::move(interfaceMap),
67                              std::move(hasTrait), walkImmediateSubElementsFn,
68                              replaceImmediateSubElementsFn, typeID, name);
69   }
70 
71   /// Return the dialect this attribute was registered to.
getDialect()72   Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
73 
74   /// Returns an instance of the concept object for the given interface if it
75   /// was registered to this attribute, null otherwise. This should not be used
76   /// directly.
77   template <typename T>
getInterface()78   typename T::Concept *getInterface() const {
79     return interfaceMap.lookup<T>();
80   }
81 
82   /// Returns true if the attribute has the interface with the given ID
83   /// registered.
hasInterface(TypeID interfaceID)84   bool hasInterface(TypeID interfaceID) const {
85     return interfaceMap.contains(interfaceID);
86   }
87 
88   /// Returns true if the attribute has a particular trait.
89   template <template <typename T> class Trait>
hasTrait()90   bool hasTrait() const {
91     return hasTraitFn(TypeID::get<Trait>());
92   }
93 
94   /// Returns true if the attribute has a particular trait.
hasTrait(TypeID traitID)95   bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
96 
97   /// Walk the immediate sub-elements of this attribute.
98   void walkImmediateSubElements(Attribute attr,
99                                 function_ref<void(Attribute)> walkAttrsFn,
100                                 function_ref<void(Type)> walkTypesFn) const;
101 
102   /// Replace the immediate sub-elements of this attribute.
103   Attribute replaceImmediateSubElements(Attribute attr,
104                                         ArrayRef<Attribute> replAttrs,
105                                         ArrayRef<Type> replTypes) const;
106 
107   /// Return the unique identifier representing the concrete attribute class.
getTypeID()108   TypeID getTypeID() const { return typeID; }
109 
110   /// Return the unique name representing the type.
getName()111   StringRef getName() const { return name; }
112 
113 private:
AbstractAttribute(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTraitFn,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)114   AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
115                     HasTraitFn &&hasTraitFn,
116                     WalkImmediateSubElementsFn walkImmediateSubElementsFn,
117                     ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
118                     TypeID typeID, StringRef name)
119       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
120         hasTraitFn(std::move(hasTraitFn)),
121         walkImmediateSubElementsFn(walkImmediateSubElementsFn),
122         replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
123         typeID(typeID), name(name) {}
124 
125   /// Give StorageUserBase access to the mutable lookup.
126   template <typename ConcreteT, typename BaseT, typename StorageT,
127             typename UniquerT, template <typename T> class... Traits>
128   friend class detail::StorageUserBase;
129 
130   /// Look up the specified abstract attribute in the MLIRContext and return a
131   /// (mutable) pointer to it. Return a null pointer if the attribute could not
132   /// be found in the context.
133   static AbstractAttribute *lookupMutable(TypeID typeID, MLIRContext *context);
134 
135   /// This is the dialect that this attribute was registered to.
136   const Dialect &dialect;
137 
138   /// This is a collection of the interfaces registered to this attribute.
139   detail::InterfaceMap interfaceMap;
140 
141   /// Function to check if the attribute has a particular trait.
142   HasTraitFn hasTraitFn;
143 
144   /// Function to walk the immediate sub-elements of this attribute.
145   WalkImmediateSubElementsFn walkImmediateSubElementsFn;
146 
147   /// Function to replace the immediate sub-elements of this attribute.
148   ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
149 
150   /// The unique identifier of the derived Attribute class.
151   const TypeID typeID;
152 
153   /// The unique name of this attribute. The string is not owned by the context,
154   /// so the lifetime of this string should outlive the MLIR context.
155   const StringRef name;
156 };
157 
158 //===----------------------------------------------------------------------===//
159 // AttributeStorage
160 //===----------------------------------------------------------------------===//
161 
162 namespace detail {
163 class AttributeUniquer;
164 class DistinctAttributeUniquer;
165 } // namespace detail
166 
167 /// Base storage class appearing in an attribute. Derived storage classes should
168 /// only be constructed within the context of the AttributeUniquer.
169 class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
170   friend detail::AttributeUniquer;
171   friend detail::DistinctAttributeUniquer;
172   friend StorageUniquer;
173 
174 public:
175   /// Return the abstract descriptor for this attribute.
getAbstractAttribute()176   const AbstractAttribute &getAbstractAttribute() const {
177     assert(abstractAttribute && "Malformed attribute storage object.");
178     return *abstractAttribute;
179   }
180 
181 protected:
182   /// Set the abstract attribute for this storage instance. This is used by the
183   /// AttributeUniquer when initializing a newly constructed storage object.
initializeAbstractAttribute(const AbstractAttribute & abstractAttr)184   void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
185     abstractAttribute = &abstractAttr;
186   }
187 
188   /// Default initialization for attribute storage classes that require no
189   /// additional initialization.
initialize(MLIRContext * context)190   void initialize(MLIRContext *context) {}
191 
192 private:
193   /// The abstract descriptor for this attribute.
194   const AbstractAttribute *abstractAttribute = nullptr;
195 };
196 
197 /// Default storage type for attributes that require no additional
198 /// initialization or storage.
199 using DefaultAttributeStorage = AttributeStorage;
200 
201 //===----------------------------------------------------------------------===//
202 // AttributeStorageAllocator
203 //===----------------------------------------------------------------------===//
204 
205 // This is a utility allocator used to allocate memory for instances of derived
206 // Attributes.
207 using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
208 
209 //===----------------------------------------------------------------------===//
210 // AttributeUniquer
211 //===----------------------------------------------------------------------===//
212 namespace detail {
213 // A utility class to get, or create, unique instances of attributes within an
214 // MLIRContext. This class manages all creation and uniquing of attributes.
215 class AttributeUniquer {
216 public:
217   /// Get an uniqued instance of an attribute T.
218   template <typename T, typename... Args>
get(MLIRContext * ctx,Args &&...args)219   static T get(MLIRContext *ctx, Args &&...args) {
220     return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
221                                      std::forward<Args>(args)...);
222   }
223 
224   /// Get an uniqued instance of a parametric attribute T.
225   /// The use of this method is in general discouraged in favor of
226   /// 'get<T, Args>(ctx, args)'.
227   template <typename T, typename... Args>
228   static std::enable_if_t<
229       !std::is_same<typename T::ImplType, AttributeStorage>::value, T>
getWithTypeID(MLIRContext * ctx,TypeID typeID,Args &&...args)230   getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
231 #ifndef NDEBUG
232     if (!ctx->getAttributeUniquer().isParametricStorageInitialized(typeID))
233       llvm::report_fatal_error(
234           llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
235           "' because storage uniquer isn't initialized: the dialect was likely "
236           "not loaded, or the attribute wasn't added with addAttributes<...>() "
237           "in the Dialect::initialize() method.");
238 #endif
239     return ctx->getAttributeUniquer().get<typename T::ImplType>(
240         [typeID, ctx](AttributeStorage *storage) {
241           initializeAttributeStorage(storage, ctx, typeID);
242 
243           // Execute any additional attribute storage initialization with the
244           // context.
245           static_cast<typename T::ImplType *>(storage)->initialize(ctx);
246         },
247         typeID, std::forward<Args>(args)...);
248   }
249   /// Get an uniqued instance of a singleton attribute T.
250   /// The use of this method is in general discouraged in favor of
251   /// 'get<T, Args>(ctx, args)'.
252   template <typename T>
253   static std::enable_if_t<
254       std::is_same<typename T::ImplType, AttributeStorage>::value, T>
getWithTypeID(MLIRContext * ctx,TypeID typeID)255   getWithTypeID(MLIRContext *ctx, TypeID typeID) {
256 #ifndef NDEBUG
257     if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(typeID))
258       llvm::report_fatal_error(
259           llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
260           "' because storage uniquer isn't initialized: the dialect was likely "
261           "not loaded, or the attribute wasn't added with addAttributes<...>() "
262           "in the Dialect::initialize() method.");
263 #endif
264     return ctx->getAttributeUniquer().get<typename T::ImplType>(typeID);
265   }
266 
267   template <typename T, typename... Args>
mutate(MLIRContext * ctx,typename T::ImplType * impl,Args &&...args)268   static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
269                               Args &&...args) {
270     assert(impl && "cannot mutate null attribute");
271     return ctx->getAttributeUniquer().mutate(T::getTypeID(), impl,
272                                              std::forward<Args>(args)...);
273   }
274 
275   /// Register an attribute instance T with the uniquer.
276   template <typename T>
registerAttribute(MLIRContext * ctx)277   static void registerAttribute(MLIRContext *ctx) {
278     registerAttribute<T>(ctx, T::getTypeID());
279   }
280 
281   /// Register a parametric attribute instance T with the uniquer.
282   /// The use of this method is in general discouraged in favor of
283   /// 'registerAttribute<T>(ctx)'.
284   template <typename T>
285   static std::enable_if_t<
286       !std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext * ctx,TypeID typeID)287   registerAttribute(MLIRContext *ctx, TypeID typeID) {
288     ctx->getAttributeUniquer()
289         .registerParametricStorageType<typename T::ImplType>(typeID);
290   }
291   /// Register a singleton attribute instance T with the uniquer.
292   /// The use of this method is in general discouraged in favor of
293   /// 'registerAttribute<T>(ctx)'.
294   template <typename T>
295   static std::enable_if_t<
296       std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext * ctx,TypeID typeID)297   registerAttribute(MLIRContext *ctx, TypeID typeID) {
298     ctx->getAttributeUniquer()
299         .registerSingletonStorageType<typename T::ImplType>(
300             typeID, [ctx, typeID](AttributeStorage *storage) {
301               initializeAttributeStorage(storage, ctx, typeID);
302             });
303   }
304 
305 private:
306   /// Initialize the given attribute storage instance.
307   static void initializeAttributeStorage(AttributeStorage *storage,
308                                          MLIRContext *ctx, TypeID attrID);
309 };
310 
311 // Internal function called by ODS generated code.
312 // Default initializes the type within a FailureOr<T> if T is default
313 // constructible and returns a reference to the instance.
314 // Otherwise, returns a reference to the FailureOr<T>.
315 template <class T>
decltype(auto)316 decltype(auto) unwrapForCustomParse(FailureOr<T> &failureOr) {
317   if constexpr (std::is_default_constructible_v<T>)
318     return failureOr.emplace();
319   else
320     return failureOr;
321 }
322 
323 } // namespace detail
324 
325 } // namespace mlir
326 
327 #endif
328