//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_TYPES_H
#define MLIR_IR_TYPES_H

#include "mlir/IR/TypeSupport.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

namespace mlir {
class AsmState;

/// Instances of the Type class are uniqued, have an immutable identifier and an
/// optional mutable component. They wrap a pointer to the storage object owned
/// by MLIRContext. Therefore, instances of Type are passed around by value.
///
/// Some types are "primitives" meaning they do not have any parameters, for
/// example the Index type. Parametric types have additional information that
/// differentiates the types of the same class, for example the Integer type has
/// bitwidth, making i8 and i16 belong to the same kind but be different
/// instances of the IntegerType. Type parameters are part of the unique
/// immutable key. The mutable component of the type can be modified after the
/// type is created, but cannot affect the identity of the type.
///
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
///
/// Derived type classes may implement the following implementation hooks:
///  * Optional:
///    - static LogicalResult verifyInvariants(
///                                function_ref<InFlightDiagnostic()> emitError,
///                                Args... args)
///      * This method is invoked when calling the 'TypeBase::get/getChecked'
///        methods to ensure that the arguments passed in are valid to construct
///        a type instance with.
43 /// * This method is expected to return failure if a type cannot be 44 /// constructed with 'args', success otherwise. 45 /// * 'args' must correspond with the arguments passed into the 46 /// 'TypeBase::get' call. 47 /// 48 /// 49 /// Type storage objects inherit from TypeStorage and contain the following: 50 /// - The dialect that defined the type. 51 /// - Any parameters of the type. 52 /// - An optional mutable component. 53 /// For non-parametric types, a convenience DefaultTypeStorage is provided. 54 /// Parametric storage types must derive TypeStorage and respect the following: 55 /// - Define a type alias, KeyTy, to a type that uniquely identifies the 56 /// instance of the type. 57 /// * The key type must be constructible from the values passed into the 58 /// detail::TypeUniquer::get call. 59 /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the 60 /// storage class must define a hashing method: 61 /// 'static unsigned hashKey(const KeyTy &)' 62 /// 63 /// - Provide a method, 'bool operator==(const KeyTy &) const', to 64 /// compare the storage instance against an instance of the key type. 65 /// 66 /// - Provide a static construction method: 67 /// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)' 68 /// that builds a unique instance of the derived storage. The arguments to 69 /// this function are an allocator to store any uniqued data within the 70 /// context and the key type for this storage. 71 /// 72 /// - If they have a mutable component, this component must not be a part of 73 /// the key. 74 class Type { 75 public: 76 /// Utility class for implementing types. 77 template <typename ConcreteType, typename BaseType, typename StorageType, 78 template <typename T> class... 
Traits> 79 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType, 80 detail::TypeUniquer, Traits...>; 81 82 using ImplType = TypeStorage; 83 84 using AbstractTy = AbstractType; 85 86 constexpr Type() = default; 87 /* implicit */ Type(const ImplType *impl) 88 : impl(const_cast<ImplType *>(impl)) {} 89 90 Type(const Type &other) = default; 91 Type &operator=(const Type &other) = default; 92 93 bool operator==(Type other) const { return impl == other.impl; } 94 bool operator!=(Type other) const { return !(*this == other); } 95 explicit operator bool() const { return impl; } 96 97 bool operator!() const { return impl == nullptr; } 98 99 template <typename... Tys> 100 [[deprecated("Use mlir::isa<U>() instead")]] 101 bool isa() const; 102 template <typename... Tys> 103 [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] 104 bool isa_and_nonnull() const; 105 template <typename U> 106 [[deprecated("Use mlir::dyn_cast<U>() instead")]] 107 U dyn_cast() const; 108 template <typename U> 109 [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] 110 U dyn_cast_or_null() const; 111 template <typename U> 112 [[deprecated("Use mlir::cast<U>() instead")]] 113 U cast() const; 114 115 /// Return a unique identifier for the concrete type. This is used to support 116 /// dynamic type casting. 117 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); } 118 119 /// Return the MLIRContext in which this type was uniqued. 120 MLIRContext *getContext() const; 121 122 /// Get the dialect this type is registered to. 123 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); } 124 125 // Convenience predicates. This is only for floating point types, 126 // derived types should use isa/dyn_cast. 
127 bool isIndex() const; 128 bool isBF16() const; 129 bool isF16() const; 130 bool isTF32() const; 131 bool isF32() const; 132 bool isF64() const; 133 bool isF80() const; 134 bool isF128() const; 135 136 /// Return true if this is an integer type (with the specified width). 137 bool isInteger() const; 138 bool isInteger(unsigned width) const; 139 /// Return true if this is a signless integer type (with the specified width). 140 bool isSignlessInteger() const; 141 bool isSignlessInteger(unsigned width) const; 142 /// Return true if this is a signed integer type (with the specified width). 143 bool isSignedInteger() const; 144 bool isSignedInteger(unsigned width) const; 145 /// Return true if this is an unsigned integer type (with the specified 146 /// width). 147 bool isUnsignedInteger() const; 148 bool isUnsignedInteger(unsigned width) const; 149 150 /// Return the bit width of an integer or a float type, assert failure on 151 /// other types. 152 unsigned getIntOrFloatBitWidth() const; 153 154 /// Return true if this is a signless integer or index type. 155 bool isSignlessIntOrIndex() const; 156 /// Return true if this is a signless integer, index, or float type. 157 bool isSignlessIntOrIndexOrFloat() const; 158 /// Return true of this is a signless integer or a float type. 159 bool isSignlessIntOrFloat() const; 160 161 /// Return true if this is an integer (of any signedness) or an index type. 162 bool isIntOrIndex() const; 163 /// Return true if this is an integer (of any signedness) or a float type. 164 bool isIntOrFloat() const; 165 /// Return true if this is an integer (of any signedness), index, or float 166 /// type. 167 bool isIntOrIndexOrFloat() const; 168 169 /// Print the current type. 170 void print(raw_ostream &os) const; 171 void print(raw_ostream &os, AsmState &state) const; 172 void dump() const; 173 174 friend ::llvm::hash_code hash_value(Type arg); 175 176 /// Methods for supporting PointerLikeTypeTraits. 
177 const void *getAsOpaquePointer() const { 178 return static_cast<const void *>(impl); 179 } 180 static Type getFromOpaquePointer(const void *pointer) { 181 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer))); 182 } 183 184 /// Returns true if `InterfaceT` has been promised by the dialect or 185 /// implemented. 186 template <typename InterfaceT> 187 bool hasPromiseOrImplementsInterface() { 188 return dialect_extension_detail::hasPromisedInterface( 189 getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || 190 mlir::isa<InterfaceT>(*this); 191 } 192 193 /// Returns true if the type was registered with a particular trait. 194 template <template <typename T> class Trait> 195 bool hasTrait() { 196 return getAbstractType().hasTrait<Trait>(); 197 } 198 199 /// Return the abstract type descriptor for this type. 200 const AbstractTy &getAbstractType() const { return impl->getAbstractType(); } 201 202 /// Return the Type implementation. 203 ImplType *getImpl() const { return impl; } 204 205 /// Walk all of the immediately nested sub-attributes and sub-types. This 206 /// method does not recurse into sub elements. 207 void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn, 208 function_ref<void(Type)> walkTypesFn) const { 209 getAbstractType().walkImmediateSubElements(*this, walkAttrsFn, walkTypesFn); 210 } 211 212 /// Replace the immediately nested sub-attributes and sub-types with those 213 /// provided. The order of the provided elements is derived from the order of 214 /// the elements returned by the callbacks of `walkImmediateSubElements`. The 215 /// element at index 0 would replace the very first attribute given by 216 /// `walkImmediateSubElements`. On success, the new instance with the values 217 /// replaced is returned. If replacement fails, nullptr is returned. 
218 auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 219 ArrayRef<Type> replTypes) const { 220 return getAbstractType().replaceImmediateSubElements(*this, replAttrs, 221 replTypes); 222 } 223 224 /// Walk this type and all attibutes/types nested within using the 225 /// provided walk functions. See `AttrTypeWalker` for information on the 226 /// supported walk function types. 227 template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns> 228 auto walk(WalkFns &&...walkFns) { 229 AttrTypeWalker walker; 230 (walker.addWalk(std::forward<WalkFns>(walkFns)), ...); 231 return walker.walk<Order>(*this); 232 } 233 234 /// Recursively replace all of the nested sub-attributes and sub-types using 235 /// the provided map functions. Returns nullptr in the case of failure. See 236 /// `AttrTypeReplacer` for information on the support replacement function 237 /// types. 238 template <typename... ReplacementFns> 239 auto replace(ReplacementFns &&...replacementFns) { 240 AttrTypeReplacer replacer; 241 (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), 242 ...); 243 return replacer.replace(*this); 244 } 245 246 protected: 247 ImplType *impl{nullptr}; 248 }; 249 250 inline raw_ostream &operator<<(raw_ostream &os, Type type) { 251 type.print(os); 252 return os; 253 } 254 255 //===----------------------------------------------------------------------===// 256 // TypeTraitBase 257 //===----------------------------------------------------------------------===// 258 259 namespace TypeTrait { 260 /// This class represents the base of a type trait. 
261 template <typename ConcreteType, template <typename> class TraitType> 262 using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>; 263 } // namespace TypeTrait 264 265 //===----------------------------------------------------------------------===// 266 // TypeInterface 267 //===----------------------------------------------------------------------===// 268 269 /// This class represents the base of a type interface. See the definition of 270 /// `detail::Interface` for requirements on the `Traits` type. 271 template <typename ConcreteType, typename Traits> 272 class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type, 273 TypeTrait::TraitBase> { 274 public: 275 using Base = TypeInterface<ConcreteType, Traits>; 276 using InterfaceBase = 277 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>; 278 using InterfaceBase::InterfaceBase; 279 280 protected: 281 /// Returns the impl interface instance for the given type. 282 static typename InterfaceBase::Concept *getInterfaceFor(Type type) { 283 #ifndef NDEBUG 284 // Check that the current interface isn't an unresolved promise for the 285 // given type. 286 dialect_extension_detail::handleUseOfUndefinedPromisedInterface( 287 type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(), 288 llvm::getTypeName<ConcreteType>()); 289 #endif 290 291 return type.getAbstractType().getInterface<ConcreteType>(); 292 } 293 294 /// Allow access to 'getInterfaceFor'. 295 friend InterfaceBase; 296 }; 297 298 //===----------------------------------------------------------------------===// 299 // Core TypeTrait 300 //===----------------------------------------------------------------------===// 301 302 /// This trait is used to determine if a type is mutable or not. It is attached 303 /// on a type if the corresponding ImplType defines a `mutate` function with 304 /// a proper signature. 
305 namespace TypeTrait { 306 template <typename ConcreteType> 307 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>; 308 } // namespace TypeTrait 309 310 //===----------------------------------------------------------------------===// 311 // Type Utils 312 //===----------------------------------------------------------------------===// 313 314 // Make Type hashable. 315 inline ::llvm::hash_code hash_value(Type arg) { 316 return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl); 317 } 318 319 template <typename... Tys> 320 bool Type::isa() const { 321 return llvm::isa<Tys...>(*this); 322 } 323 324 template <typename... Tys> 325 bool Type::isa_and_nonnull() const { 326 return llvm::isa_and_present<Tys...>(*this); 327 } 328 329 template <typename U> 330 U Type::dyn_cast() const { 331 return llvm::dyn_cast<U>(*this); 332 } 333 334 template <typename U> 335 U Type::dyn_cast_or_null() const { 336 return llvm::dyn_cast_or_null<U>(*this); 337 } 338 339 template <typename U> 340 U Type::cast() const { 341 return llvm::cast<U>(*this); 342 } 343 344 } // namespace mlir 345 346 namespace llvm { 347 348 // Type hash just like pointers. 
349 template <> 350 struct DenseMapInfo<mlir::Type> { 351 static mlir::Type getEmptyKey() { 352 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); 353 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); 354 } 355 static mlir::Type getTombstoneKey() { 356 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); 357 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); 358 } 359 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); } 360 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } 361 }; 362 template <typename T> 363 struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value && 364 !mlir::detail::IsInterface<T>::value>> 365 : public DenseMapInfo<mlir::Type> { 366 static T getEmptyKey() { 367 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); 368 return T::getFromOpaquePointer(pointer); 369 } 370 static T getTombstoneKey() { 371 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); 372 return T::getFromOpaquePointer(pointer); 373 } 374 }; 375 376 /// We align TypeStorage by 8, so allow LLVM to steal the low bits. 377 template <> 378 struct PointerLikeTypeTraits<mlir::Type> { 379 public: 380 static inline void *getAsVoidPointer(mlir::Type I) { 381 return const_cast<void *>(I.getAsOpaquePointer()); 382 } 383 static inline mlir::Type getFromVoidPointer(void *P) { 384 return mlir::Type::getFromOpaquePointer(P); 385 } 386 static constexpr int NumLowBitsAvailable = 3; 387 }; 388 389 /// Add support for llvm style casts. 
390 /// We provide a cast between To and From if From is mlir::Type or derives from 391 /// it 392 template <typename To, typename From> 393 struct CastInfo< 394 To, From, 395 std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> || 396 std::is_base_of_v<mlir::Type, From>>> 397 : NullableValueCastFailed<To>, 398 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { 399 /// Arguments are taken as mlir::Type here and not as `From`, because when 400 /// casting from an intermediate type of the hierarchy to one of its children, 401 /// the val.getTypeID() inside T::classof will use the static getTypeID of the 402 /// parent instead of the non-static Type::getTypeID that returns the dynamic 403 /// ID. This means that T::classof would end up comparing the static TypeID of 404 /// the children to the static TypeID of its parent, making it impossible to 405 /// downcast from the parent to the child. 406 static inline bool isPossible(mlir::Type ty) { 407 /// Return a constant true instead of a dynamic true when casting to self or 408 /// up the hierarchy. 409 if constexpr (std::is_base_of_v<To, From>) { 410 return true; 411 } else { 412 return To::classof(ty); 413 }; 414 } 415 static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); } 416 }; 417 418 } // namespace llvm 419 420 #endif // MLIR_IR_TYPES_H 421