xref: /llvm-project/mlir/include/mlir/IR/Attributes.h (revision 98de5dfe6a8cbb70f21de545acec4710a77294ed)
1 //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_ATTRIBUTES_H
10 #define MLIR_IR_ATTRIBUTES_H
11 
12 #include "mlir/IR/AttributeSupport.h"
13 #include "llvm/Support/PointerLikeTypeTraits.h"
14 
15 namespace mlir {
16 class AsmState;
17 class StringAttr;
18 
19 /// Attributes are known-constant values of operations.
20 ///
21 /// Instances of the Attribute class are references to immortal key-value pairs
22 /// with immutable, uniqued keys owned by MLIRContext. As such, an Attribute is
23 /// a thin wrapper around an underlying storage pointer. Attributes are usually
24 /// passed by value.
25 class Attribute {
26 public:
27   /// Utility class for implementing attributes.
28   template <typename ConcreteType, typename BaseType, typename StorageType,
29             template <typename T> class... Traits>
30   using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
31                                            detail::AttributeUniquer, Traits...>;
32 
33   using ImplType = AttributeStorage;
34   using ValueType = void;
35   using AbstractTy = AbstractAttribute;
36 
37   constexpr Attribute() = default;
38   /* implicit */ Attribute(const ImplType *impl)
39       : impl(const_cast<ImplType *>(impl)) {}
40 
41   Attribute(const Attribute &other) = default;
42   Attribute &operator=(const Attribute &other) = default;
43 
44   bool operator==(Attribute other) const { return impl == other.impl; }
45   bool operator!=(Attribute other) const { return !(*this == other); }
46   explicit operator bool() const { return impl; }
47 
48   bool operator!() const { return impl == nullptr; }
49 
50   /// Casting utility functions. These are deprecated and will be removed,
51   /// please prefer using the `llvm` namespace variants instead.
52   template <typename... Tys>
53   [[deprecated("Use mlir::isa<U>() instead")]]
54   bool isa() const;
55   template <typename... Tys>
56   [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
57   bool isa_and_nonnull() const;
58   template <typename U>
59   [[deprecated("Use mlir::dyn_cast<U>() instead")]]
60   U dyn_cast() const;
61   template <typename U>
62   [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
63   U dyn_cast_or_null() const;
64   template <typename U>
65   [[deprecated("Use mlir::cast<U>() instead")]]
66   U cast() const;
67 
68   /// Return a unique identifier for the concrete attribute type. This is used
69   /// to support dynamic type casting.
70   TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
71 
72   /// Return the context this attribute belongs to.
73   MLIRContext *getContext() const;
74 
75   /// Get the dialect this attribute is registered to.
76   Dialect &getDialect() const {
77     return impl->getAbstractAttribute().getDialect();
78   }
79 
80   /// Print the attribute. If `elideType` is set, the attribute is printed
81   /// without a trailing colon type if it has one.
82   void print(raw_ostream &os, bool elideType = false) const;
83   void print(raw_ostream &os, AsmState &state, bool elideType = false) const;
84   void dump() const;
85 
86   /// Print the attribute without dialect wrapping.
87   void printStripped(raw_ostream &os) const;
88   void printStripped(raw_ostream &os, AsmState &state) const;
89 
90   /// Get an opaque pointer to the attribute.
91   const void *getAsOpaquePointer() const { return impl; }
92   /// Construct an attribute from the opaque pointer representation.
93   static Attribute getFromOpaquePointer(const void *ptr) {
94     return Attribute(reinterpret_cast<const ImplType *>(ptr));
95   }
96 
97   friend ::llvm::hash_code hash_value(Attribute arg);
98 
99   /// Returns true if `InterfaceT` has been promised by the dialect or
100   /// implemented.
101   template <typename InterfaceT>
102   bool hasPromiseOrImplementsInterface() {
103     return dialect_extension_detail::hasPromisedInterface(
104                getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
105            mlir::isa<InterfaceT>(*this);
106   }
107 
108   /// Returns true if the type was registered with a particular trait.
109   template <template <typename T> class Trait>
110   bool hasTrait() {
111     return getAbstractAttribute().hasTrait<Trait>();
112   }
113 
114   /// Return the abstract descriptor for this attribute.
115   const AbstractTy &getAbstractAttribute() const {
116     return impl->getAbstractAttribute();
117   }
118 
119   /// Walk all of the immediately nested sub-attributes and sub-types. This
120   /// method does not recurse into sub elements.
121   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
122                                 function_ref<void(Type)> walkTypesFn) const {
123     getAbstractAttribute().walkImmediateSubElements(*this, walkAttrsFn,
124                                                     walkTypesFn);
125   }
126 
127   /// Replace the immediately nested sub-attributes and sub-types with those
128   /// provided. The order of the provided elements is derived from the order of
129   /// the elements returned by the callbacks of `walkImmediateSubElements`. The
130   /// element at index 0 would replace the very first attribute given by
131   /// `walkImmediateSubElements`. On success, the new instance with the values
132   /// replaced is returned. If replacement fails, nullptr is returned.
133   auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
134                                    ArrayRef<Type> replTypes) const {
135     return getAbstractAttribute().replaceImmediateSubElements(*this, replAttrs,
136                                                               replTypes);
137   }
138 
139   /// Walk this attribute and all attibutes/types nested within using the
140   /// provided walk functions. See `AttrTypeWalker` for information on the
141   /// supported walk function types.
142   template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
143   auto walk(WalkFns &&...walkFns) {
144     AttrTypeWalker walker;
145     (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
146     return walker.walk<Order>(*this);
147   }
148 
149   /// Recursively replace all of the nested sub-attributes and sub-types using
150   /// the provided map functions. Returns nullptr in the case of failure. See
151   /// `AttrTypeReplacer` for information on the support replacement function
152   /// types.
153   template <typename... ReplacementFns>
154   auto replace(ReplacementFns &&...replacementFns) {
155     AttrTypeReplacer replacer;
156     (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
157      ...);
158     return replacer.replace(*this);
159   }
160 
161   /// Return the internal Attribute implementation.
162   ImplType *getImpl() const { return impl; }
163 
164 protected:
165   ImplType *impl{nullptr};
166 };
167 
168 inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
169   attr.print(os);
170   return os;
171 }
172 
173 template <typename... Tys>
174 bool Attribute::isa() const {
175   return llvm::isa<Tys...>(*this);
176 }
177 
178 template <typename... Tys>
179 bool Attribute::isa_and_nonnull() const {
180   return llvm::isa_and_present<Tys...>(*this);
181 }
182 
183 template <typename U>
184 U Attribute::dyn_cast() const {
185   return llvm::dyn_cast<U>(*this);
186 }
187 
188 template <typename U>
189 U Attribute::dyn_cast_or_null() const {
190   return llvm::dyn_cast_if_present<U>(*this);
191 }
192 
193 template <typename U>
194 U Attribute::cast() const {
195   return llvm::cast<U>(*this);
196 }
197 
198 inline ::llvm::hash_code hash_value(Attribute arg) {
199   return DenseMapInfo<const Attribute::ImplType *>::getHashValue(arg.impl);
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // NamedAttribute
204 //===----------------------------------------------------------------------===//
205 
206 /// NamedAttribute represents a combination of a name and an Attribute value.
207 class NamedAttribute {
208 public:
209   NamedAttribute(StringAttr name, Attribute value);
210   NamedAttribute(StringRef name, Attribute value);
211 
212   /// Return the name of the attribute.
213   StringAttr getName() const;
214 
215   /// Return the dialect of the name of this attribute, if the name is prefixed
216   /// by a dialect namespace. For example, `llvm.fast_math` would return the
217   /// LLVM dialect (if it is loaded). Returns nullptr if the dialect isn't
218   /// loaded, or if the name is not prefixed by a dialect namespace.
219   Dialect *getNameDialect() const;
220 
221   /// Return the value of the attribute.
222   Attribute getValue() const { return value; }
223 
224   /// Set the name of this attribute.
225   void setName(StringAttr newName);
226 
227   /// Set the value of this attribute.
228   void setValue(Attribute newValue) {
229     assert(value && "expected valid attribute value");
230     value = newValue;
231   }
232 
233   /// Compare this attribute to the provided attribute, ordering by name.
234   bool operator<(const NamedAttribute &rhs) const;
235   /// Compare this attribute to the provided string, ordering by name.
236   bool operator<(StringRef rhs) const;
237 
238   bool operator==(const NamedAttribute &rhs) const {
239     return name == rhs.name && value == rhs.value;
240   }
241   bool operator!=(const NamedAttribute &rhs) const { return !(*this == rhs); }
242 
243 private:
244   NamedAttribute(Attribute name, Attribute value) : name(name), value(value) {}
245 
246   /// Allow access to internals to enable hashing.
247   friend ::llvm::hash_code hash_value(const NamedAttribute &arg);
248   friend DenseMapInfo<NamedAttribute>;
249 
250   /// The name of the attribute. This is represented as a StringAttr, but
251   /// type-erased to Attribute in the field.
252   Attribute name;
253   /// The value of the attribute.
254   Attribute value;
255 };
256 
257 inline ::llvm::hash_code hash_value(const NamedAttribute &arg) {
258   using AttrPairT = std::pair<Attribute, Attribute>;
259   return DenseMapInfo<AttrPairT>::getHashValue(AttrPairT(arg.name, arg.value));
260 }
261 
262 /// Allow walking and replacing the subelements of a NamedAttribute.
263 template <>
264 struct AttrTypeSubElementHandler<NamedAttribute> {
265   template <typename T>
266   static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
267     walker.walk(param.getName());
268     walker.walk(param.getValue());
269   }
270   template <typename T>
271   static T replace(T param, AttrSubElementReplacements &attrRepls,
272                    TypeSubElementReplacements &typeRepls) {
273     ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
274     return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
275   }
276 };
277 
278 //===----------------------------------------------------------------------===//
279 // AttributeTraitBase
280 //===----------------------------------------------------------------------===//
281 
282 namespace AttributeTrait {
283 /// This class represents the base of an attribute trait.
284 template <typename ConcreteType, template <typename> class TraitType>
285 using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
286 } // namespace AttributeTrait
287 
288 //===----------------------------------------------------------------------===//
289 // AttributeInterface
290 //===----------------------------------------------------------------------===//
291 
292 /// This class represents the base of an attribute interface. See the definition
293 /// of `detail::Interface` for requirements on the `Traits` type.
294 template <typename ConcreteType, typename Traits>
295 class AttributeInterface
296     : public detail::Interface<ConcreteType, Attribute, Traits, Attribute,
297                                AttributeTrait::TraitBase> {
298 public:
299   using Base = AttributeInterface<ConcreteType, Traits>;
300   using InterfaceBase = detail::Interface<ConcreteType, Attribute, Traits,
301                                           Attribute, AttributeTrait::TraitBase>;
302   using InterfaceBase::InterfaceBase;
303 
304 protected:
305   /// Returns the impl interface instance for the given type.
306   static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
307 #ifndef NDEBUG
308     // Check that the current interface isn't an unresolved promise for the
309     // given attribute.
310     dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
311         attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(),
312         llvm::getTypeName<ConcreteType>());
313 #endif
314 
315     return attr.getAbstractAttribute().getInterface<ConcreteType>();
316   }
317 
318   /// Allow access to 'getInterfaceFor'.
319   friend InterfaceBase;
320 };
321 
322 //===----------------------------------------------------------------------===//
323 // Core AttributeTrait
324 //===----------------------------------------------------------------------===//
325 
326 namespace AttributeTrait {
327 /// This trait is used to determine if an attribute is mutable or not. It is
328 /// attached on an attribute if the corresponding ConcreteType defines a
329 /// `mutate` function with proper signature.
330 template <typename ConcreteType>
331 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
332 
333 /// This trait is used to determine if an attribute is a location or not. It is
334 /// attached to an attribute by the user if they intend the attribute to be used
335 /// as a location.
336 template <typename ConcreteType>
337 struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
338 };
339 } // namespace AttributeTrait
340 
341 } // namespace mlir.
342 
343 namespace llvm {
344 
345 // Attribute hash just like pointers.
346 template <>
347 struct DenseMapInfo<mlir::Attribute> {
348   static mlir::Attribute getEmptyKey() {
349     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
350     return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
351   }
352   static mlir::Attribute getTombstoneKey() {
353     auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
354     return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
355   }
356   static unsigned getHashValue(mlir::Attribute val) {
357     return mlir::hash_value(val);
358   }
359   static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
360     return LHS == RHS;
361   }
362 };
363 template <typename T>
364 struct DenseMapInfo<
365     T, std::enable_if_t<std::is_base_of<mlir::Attribute, T>::value &&
366                         !mlir::detail::IsInterface<T>::value>>
367     : public DenseMapInfo<mlir::Attribute> {
368   static T getEmptyKey() {
369     const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
370     return T::getFromOpaquePointer(pointer);
371   }
372   static T getTombstoneKey() {
373     const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
374     return T::getFromOpaquePointer(pointer);
375   }
376 };
377 
378 /// Allow LLVM to steal the low bits of Attributes.
379 template <>
380 struct PointerLikeTypeTraits<mlir::Attribute> {
381   static inline void *getAsVoidPointer(mlir::Attribute attr) {
382     return const_cast<void *>(attr.getAsOpaquePointer());
383   }
384   static inline mlir::Attribute getFromVoidPointer(void *ptr) {
385     return mlir::Attribute::getFromOpaquePointer(ptr);
386   }
387   static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits<
388       mlir::AttributeStorage *>::NumLowBitsAvailable;
389 };
390 
391 template <>
392 struct DenseMapInfo<mlir::NamedAttribute> {
393   static mlir::NamedAttribute getEmptyKey() {
394     auto emptyAttr = llvm::DenseMapInfo<mlir::Attribute>::getEmptyKey();
395     return mlir::NamedAttribute(emptyAttr, emptyAttr);
396   }
397   static mlir::NamedAttribute getTombstoneKey() {
398     auto tombAttr = llvm::DenseMapInfo<mlir::Attribute>::getTombstoneKey();
399     return mlir::NamedAttribute(tombAttr, tombAttr);
400   }
401   static unsigned getHashValue(mlir::NamedAttribute val) {
402     return mlir::hash_value(val);
403   }
404   static bool isEqual(mlir::NamedAttribute lhs, mlir::NamedAttribute rhs) {
405     return lhs == rhs;
406   }
407 };
408 
409 /// Add support for llvm style casts. We provide a cast between To and From if
410 /// From is mlir::Attribute or derives from it.
411 template <typename To, typename From>
412 struct CastInfo<To, From,
413                 std::enable_if_t<std::is_same_v<mlir::Attribute,
414                                                 std::remove_const_t<From>> ||
415                                  std::is_base_of_v<mlir::Attribute, From>>>
416     : NullableValueCastFailed<To>,
417       DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
418   /// Arguments are taken as mlir::Attribute here and not as `From`, because
419   /// when casting from an intermediate type of the hierarchy to one of its
420   /// children, the val.getTypeID() inside T::classof will use the static
421   /// getTypeID of the parent instead of the non-static Type::getTypeID that
422   /// returns the dynamic ID. This means that T::classof would end up comparing
423   /// the static TypeID of the children to the static TypeID of its parent,
424   /// making it impossible to downcast from the parent to the child.
425   static inline bool isPossible(mlir::Attribute ty) {
426     /// Return a constant true instead of a dynamic true when casting to self or
427     /// up the hierarchy.
428     if constexpr (std::is_base_of_v<To, From>) {
429       return true;
430     } else {
431       return To::classof(ty);
432     }
433   }
434   static inline To doCast(mlir::Attribute attr) { return To(attr.getImpl()); }
435 };
436 
437 } // namespace llvm
438 
439 #endif
440