1 //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_ATTRIBUTES_H 10 #define MLIR_IR_ATTRIBUTES_H 11 12 #include "mlir/IR/AttributeSupport.h" 13 #include "llvm/Support/PointerLikeTypeTraits.h" 14 15 namespace mlir { 16 class AsmState; 17 class StringAttr; 18 19 /// Attributes are known-constant values of operations. 20 /// 21 /// Instances of the Attribute class are references to immortal key-value pairs 22 /// with immutable, uniqued keys owned by MLIRContext. As such, an Attribute is 23 /// a thin wrapper around an underlying storage pointer. Attributes are usually 24 /// passed by value. 25 class Attribute { 26 public: 27 /// Utility class for implementing attributes. 28 template <typename ConcreteType, typename BaseType, typename StorageType, 29 template <typename T> class... Traits> 30 using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType, 31 detail::AttributeUniquer, Traits...>; 32 33 using ImplType = AttributeStorage; 34 using ValueType = void; 35 using AbstractTy = AbstractAttribute; 36 37 constexpr Attribute() = default; 38 /* implicit */ Attribute(const ImplType *impl) 39 : impl(const_cast<ImplType *>(impl)) {} 40 41 Attribute(const Attribute &other) = default; 42 Attribute &operator=(const Attribute &other) = default; 43 44 bool operator==(Attribute other) const { return impl == other.impl; } 45 bool operator!=(Attribute other) const { return !(*this == other); } 46 explicit operator bool() const { return impl; } 47 48 bool operator!() const { return impl == nullptr; } 49 50 /// Casting utility functions. These are deprecated and will be removed, 51 /// please prefer using the `llvm` namespace variants instead. 52 template <typename... Tys> 53 [[deprecated("Use mlir::isa<U>() instead")]] 54 bool isa() const; 55 template <typename... Tys> 56 [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] 57 bool isa_and_nonnull() const; 58 template <typename U> 59 [[deprecated("Use mlir::dyn_cast<U>() instead")]] 60 U dyn_cast() const; 61 template <typename U> 62 [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] 63 U dyn_cast_or_null() const; 64 template <typename U> 65 [[deprecated("Use mlir::cast<U>() instead")]] 66 U cast() const; 67 68 /// Return a unique identifier for the concrete attribute type. This is used 69 /// to support dynamic type casting. 70 TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } 71 72 /// Return the context this attribute belongs to. 73 MLIRContext *getContext() const; 74 75 /// Get the dialect this attribute is registered to. 76 Dialect &getDialect() const { 77 return impl->getAbstractAttribute().getDialect(); 78 } 79 80 /// Print the attribute. If `elideType` is set, the attribute is printed 81 /// without a trailing colon type if it has one. 82 void print(raw_ostream &os, bool elideType = false) const; 83 void print(raw_ostream &os, AsmState &state, bool elideType = false) const; 84 void dump() const; 85 86 /// Print the attribute without dialect wrapping. 87 void printStripped(raw_ostream &os) const; 88 void printStripped(raw_ostream &os, AsmState &state) const; 89 90 /// Get an opaque pointer to the attribute. 91 const void *getAsOpaquePointer() const { return impl; } 92 /// Construct an attribute from the opaque pointer representation. 93 static Attribute getFromOpaquePointer(const void *ptr) { 94 return Attribute(reinterpret_cast<const ImplType *>(ptr)); 95 } 96 97 friend ::llvm::hash_code hash_value(Attribute arg); 98 99 /// Returns true if `InterfaceT` has been promised by the dialect or 100 /// implemented. 101 template <typename InterfaceT> 102 bool hasPromiseOrImplementsInterface() { 103 return dialect_extension_detail::hasPromisedInterface( 104 getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || 105 mlir::isa<InterfaceT>(*this); 106 } 107 108 /// Returns true if the type was registered with a particular trait. 109 template <template <typename T> class Trait> 110 bool hasTrait() { 111 return getAbstractAttribute().hasTrait<Trait>(); 112 } 113 114 /// Return the abstract descriptor for this attribute. 115 const AbstractTy &getAbstractAttribute() const { 116 return impl->getAbstractAttribute(); 117 } 118 119 /// Walk all of the immediately nested sub-attributes and sub-types. This 120 /// method does not recurse into sub elements. 121 void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn, 122 function_ref<void(Type)> walkTypesFn) const { 123 getAbstractAttribute().walkImmediateSubElements(*this, walkAttrsFn, 124 walkTypesFn); 125 } 126 127 /// Replace the immediately nested sub-attributes and sub-types with those 128 /// provided. The order of the provided elements is derived from the order of 129 /// the elements returned by the callbacks of `walkImmediateSubElements`. The 130 /// element at index 0 would replace the very first attribute given by 131 /// `walkImmediateSubElements`. On success, the new instance with the values 132 /// replaced is returned. If replacement fails, nullptr is returned. 133 auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 134 ArrayRef<Type> replTypes) const { 135 return getAbstractAttribute().replaceImmediateSubElements(*this, replAttrs, 136 replTypes); 137 } 138 139 /// Walk this attribute and all attibutes/types nested within using the 140 /// provided walk functions. See `AttrTypeWalker` for information on the 141 /// supported walk function types. 142 template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns> 143 auto walk(WalkFns &&...walkFns) { 144 AttrTypeWalker walker; 145 (walker.addWalk(std::forward<WalkFns>(walkFns)), ...); 146 return walker.walk<Order>(*this); 147 } 148 149 /// Recursively replace all of the nested sub-attributes and sub-types using 150 /// the provided map functions. Returns nullptr in the case of failure. See 151 /// `AttrTypeReplacer` for information on the support replacement function 152 /// types. 153 template <typename... ReplacementFns> 154 auto replace(ReplacementFns &&...replacementFns) { 155 AttrTypeReplacer replacer; 156 (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), 157 ...); 158 return replacer.replace(*this); 159 } 160 161 /// Return the internal Attribute implementation. 162 ImplType *getImpl() const { return impl; } 163 164 protected: 165 ImplType *impl{nullptr}; 166 }; 167 168 inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) { 169 attr.print(os); 170 return os; 171 } 172 173 template <typename... Tys> 174 bool Attribute::isa() const { 175 return llvm::isa<Tys...>(*this); 176 } 177 178 template <typename... Tys> 179 bool Attribute::isa_and_nonnull() const { 180 return llvm::isa_and_present<Tys...>(*this); 181 } 182 183 template <typename U> 184 U Attribute::dyn_cast() const { 185 return llvm::dyn_cast<U>(*this); 186 } 187 188 template <typename U> 189 U Attribute::dyn_cast_or_null() const { 190 return llvm::dyn_cast_if_present<U>(*this); 191 } 192 193 template <typename U> 194 U Attribute::cast() const { 195 return llvm::cast<U>(*this); 196 } 197 198 inline ::llvm::hash_code hash_value(Attribute arg) { 199 return DenseMapInfo<const Attribute::ImplType *>::getHashValue(arg.impl); 200 } 201 202 //===----------------------------------------------------------------------===// 203 // NamedAttribute 204 //===----------------------------------------------------------------------===// 205 206 /// NamedAttribute represents a combination of a name and an Attribute value. 207 class NamedAttribute { 208 public: 209 NamedAttribute(StringAttr name, Attribute value); 210 NamedAttribute(StringRef name, Attribute value); 211 212 /// Return the name of the attribute. 213 StringAttr getName() const; 214 215 /// Return the dialect of the name of this attribute, if the name is prefixed 216 /// by a dialect namespace. For example, `llvm.fast_math` would return the 217 /// LLVM dialect (if it is loaded). Returns nullptr if the dialect isn't 218 /// loaded, or if the name is not prefixed by a dialect namespace. 219 Dialect *getNameDialect() const; 220 221 /// Return the value of the attribute. 222 Attribute getValue() const { return value; } 223 224 /// Set the name of this attribute. 225 void setName(StringAttr newName); 226 227 /// Set the value of this attribute. 228 void setValue(Attribute newValue) { 229 assert(value && "expected valid attribute value"); 230 value = newValue; 231 } 232 233 /// Compare this attribute to the provided attribute, ordering by name. 234 bool operator<(const NamedAttribute &rhs) const; 235 /// Compare this attribute to the provided string, ordering by name. 236 bool operator<(StringRef rhs) const; 237 238 bool operator==(const NamedAttribute &rhs) const { 239 return name == rhs.name && value == rhs.value; 240 } 241 bool operator!=(const NamedAttribute &rhs) const { return !(*this == rhs); } 242 243 private: 244 NamedAttribute(Attribute name, Attribute value) : name(name), value(value) {} 245 246 /// Allow access to internals to enable hashing. 247 friend ::llvm::hash_code hash_value(const NamedAttribute &arg); 248 friend DenseMapInfo<NamedAttribute>; 249 250 /// The name of the attribute. This is represented as a StringAttr, but 251 /// type-erased to Attribute in the field. 252 Attribute name; 253 /// The value of the attribute. 254 Attribute value; 255 }; 256 257 inline ::llvm::hash_code hash_value(const NamedAttribute &arg) { 258 using AttrPairT = std::pair<Attribute, Attribute>; 259 return DenseMapInfo<AttrPairT>::getHashValue(AttrPairT(arg.name, arg.value)); 260 } 261 262 /// Allow walking and replacing the subelements of a NamedAttribute. 263 template <> 264 struct AttrTypeSubElementHandler<NamedAttribute> { 265 template <typename T> 266 static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { 267 walker.walk(param.getName()); 268 walker.walk(param.getValue()); 269 } 270 template <typename T> 271 static T replace(T param, AttrSubElementReplacements &attrRepls, 272 TypeSubElementReplacements &typeRepls) { 273 ArrayRef<Attribute> paramRepls = attrRepls.take_front(2); 274 return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]); 275 } 276 }; 277 278 //===----------------------------------------------------------------------===// 279 // AttributeTraitBase 280 //===----------------------------------------------------------------------===// 281 282 namespace AttributeTrait { 283 /// This class represents the base of an attribute trait. 284 template <typename ConcreteType, template <typename> class TraitType> 285 using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>; 286 } // namespace AttributeTrait 287 288 //===----------------------------------------------------------------------===// 289 // AttributeInterface 290 //===----------------------------------------------------------------------===// 291 292 /// This class represents the base of an attribute interface. See the definition 293 /// of `detail::Interface` for requirements on the `Traits` type. 294 template <typename ConcreteType, typename Traits> 295 class AttributeInterface 296 : public detail::Interface<ConcreteType, Attribute, Traits, Attribute, 297 AttributeTrait::TraitBase> { 298 public: 299 using Base = AttributeInterface<ConcreteType, Traits>; 300 using InterfaceBase = detail::Interface<ConcreteType, Attribute, Traits, 301 Attribute, AttributeTrait::TraitBase>; 302 using InterfaceBase::InterfaceBase; 303 304 protected: 305 /// Returns the impl interface instance for the given type. 306 static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { 307 #ifndef NDEBUG 308 // Check that the current interface isn't an unresolved promise for the 309 // given attribute. 310 dialect_extension_detail::handleUseOfUndefinedPromisedInterface( 311 attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(), 312 llvm::getTypeName<ConcreteType>()); 313 #endif 314 315 return attr.getAbstractAttribute().getInterface<ConcreteType>(); 316 } 317 318 /// Allow access to 'getInterfaceFor'. 319 friend InterfaceBase; 320 }; 321 322 //===----------------------------------------------------------------------===// 323 // Core AttributeTrait 324 //===----------------------------------------------------------------------===// 325 326 namespace AttributeTrait { 327 /// This trait is used to determine if an attribute is mutable or not. It is 328 /// attached on an attribute if the corresponding ConcreteType defines a 329 /// `mutate` function with proper signature. 330 template <typename ConcreteType> 331 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>; 332 333 /// This trait is used to determine if an attribute is a location or not. It is 334 /// attached to an attribute by the user if they intend the attribute to be used 335 /// as a location. 336 template <typename ConcreteType> 337 struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> { 338 }; 339 } // namespace AttributeTrait 340 341 } // namespace mlir. 342 343 namespace llvm { 344 345 // Attribute hash just like pointers. 346 template <> 347 struct DenseMapInfo<mlir::Attribute> { 348 static mlir::Attribute getEmptyKey() { 349 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); 350 return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)); 351 } 352 static mlir::Attribute getTombstoneKey() { 353 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); 354 return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)); 355 } 356 static unsigned getHashValue(mlir::Attribute val) { 357 return mlir::hash_value(val); 358 } 359 static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) { 360 return LHS == RHS; 361 } 362 }; 363 template <typename T> 364 struct DenseMapInfo< 365 T, std::enable_if_t<std::is_base_of<mlir::Attribute, T>::value && 366 !mlir::detail::IsInterface<T>::value>> 367 : public DenseMapInfo<mlir::Attribute> { 368 static T getEmptyKey() { 369 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); 370 return T::getFromOpaquePointer(pointer); 371 } 372 static T getTombstoneKey() { 373 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); 374 return T::getFromOpaquePointer(pointer); 375 } 376 }; 377 378 /// Allow LLVM to steal the low bits of Attributes. 379 template <> 380 struct PointerLikeTypeTraits<mlir::Attribute> { 381 static inline void *getAsVoidPointer(mlir::Attribute attr) { 382 return const_cast<void *>(attr.getAsOpaquePointer()); 383 } 384 static inline mlir::Attribute getFromVoidPointer(void *ptr) { 385 return mlir::Attribute::getFromOpaquePointer(ptr); 386 } 387 static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits< 388 mlir::AttributeStorage *>::NumLowBitsAvailable; 389 }; 390 391 template <> 392 struct DenseMapInfo<mlir::NamedAttribute> { 393 static mlir::NamedAttribute getEmptyKey() { 394 auto emptyAttr = llvm::DenseMapInfo<mlir::Attribute>::getEmptyKey(); 395 return mlir::NamedAttribute(emptyAttr, emptyAttr); 396 } 397 static mlir::NamedAttribute getTombstoneKey() { 398 auto tombAttr = llvm::DenseMapInfo<mlir::Attribute>::getTombstoneKey(); 399 return mlir::NamedAttribute(tombAttr, tombAttr); 400 } 401 static unsigned getHashValue(mlir::NamedAttribute val) { 402 return mlir::hash_value(val); 403 } 404 static bool isEqual(mlir::NamedAttribute lhs, mlir::NamedAttribute rhs) { 405 return lhs == rhs; 406 } 407 }; 408 409 /// Add support for llvm style casts. We provide a cast between To and From if 410 /// From is mlir::Attribute or derives from it. 411 template <typename To, typename From> 412 struct CastInfo<To, From, 413 std::enable_if_t<std::is_same_v<mlir::Attribute, 414 std::remove_const_t<From>> || 415 std::is_base_of_v<mlir::Attribute, From>>> 416 : NullableValueCastFailed<To>, 417 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { 418 /// Arguments are taken as mlir::Attribute here and not as `From`, because 419 /// when casting from an intermediate type of the hierarchy to one of its 420 /// children, the val.getTypeID() inside T::classof will use the static 421 /// getTypeID of the parent instead of the non-static Type::getTypeID that 422 /// returns the dynamic ID. This means that T::classof would end up comparing 423 /// the static TypeID of the children to the static TypeID of its parent, 424 /// making it impossible to downcast from the parent to the child. 425 static inline bool isPossible(mlir::Attribute ty) { 426 /// Return a constant true instead of a dynamic true when casting to self or 427 /// up the hierarchy. 428 if constexpr (std::is_base_of_v<To, From>) { 429 return true; 430 } else { 431 return To::classof(ty); 432 } 433 } 434 static inline To doCast(mlir::Attribute attr) { return To(attr.getImpl()); } 435 }; 436 437 } // namespace llvm 438 439 #endif 440