xref: /llvm-project/mlir/include/mlir/IR/Types.h (revision 7a77f14c0abfbecbfb800ea8d974e66d81ee516a)
1 //===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_TYPES_H
10 #define MLIR_IR_TYPES_H
11 
12 #include "mlir/IR/TypeSupport.h"
13 #include "llvm/ADT/ArrayRef.h"
14 #include "llvm/ADT/DenseMapInfo.h"
15 #include "llvm/Support/PointerLikeTypeTraits.h"
16 
17 namespace mlir {
18 class AsmState;
19 
20 /// Instances of the Type class are uniqued, have an immutable identifier and an
21 /// optional mutable component.  They wrap a pointer to the storage object owned
22 /// by MLIRContext.  Therefore, instances of Type are passed around by value.
23 ///
24 /// Some types are "primitives" meaning they do not have any parameters, for
25 /// example the Index type.  Parametric types have additional information that
26 /// differentiates the types of the same class, for example the Integer type has
27 /// bitwidth, making i8 and i16 belong to the same kind by be different
28 /// instances of the IntegerType. Type parameters are part of the unique
29 /// immutable key.  The mutable component of the type can be modified after the
30 /// type is created, but cannot affect the identity of the type.
31 ///
32 /// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
33 ///
34 /// Derived type classes are expected to implement several required
35 /// implementation hooks:
36 ///  * Optional:
37 ///    - static LogicalResult verifyInvariants(
38 ///                                function_ref<InFlightDiagnostic()> emitError,
39 ///                                Args... args)
40 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
41 ///        methods to ensure that the arguments passed in are valid to construct
42 ///        a type instance with.
43 ///      * This method is expected to return failure if a type cannot be
44 ///        constructed with 'args', success otherwise.
45 ///      * 'args' must correspond with the arguments passed into the
46 ///        'TypeBase::get' call.
47 ///
48 ///
49 /// Type storage objects inherit from TypeStorage and contain the following:
50 ///    - The dialect that defined the type.
51 ///    - Any parameters of the type.
52 ///    - An optional mutable component.
53 /// For non-parametric types, a convenience DefaultTypeStorage is provided.
54 /// Parametric storage types must derive TypeStorage and respect the following:
55 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
56 ///      instance of the type.
57 ///      * The key type must be constructible from the values passed into the
58 ///        detail::TypeUniquer::get call.
59 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
60 ///        storage class must define a hashing method:
61 ///         'static unsigned hashKey(const KeyTy &)'
62 ///
63 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
64 ///      compare the storage instance against an instance of the key type.
65 ///
66 ///    - Provide a static construction method:
67 ///        'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
68 ///      that builds a unique instance of the derived storage. The arguments to
69 ///      this function are an allocator to store any uniqued data within the
70 ///      context and the key type for this storage.
71 ///
72 ///    - If they have a mutable component, this component must not be a part of
73 ///      the key.
74 class Type {
75 public:
76   /// Utility class for implementing types.
77   template <typename ConcreteType, typename BaseType, typename StorageType,
78             template <typename T> class... Traits>
79   using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
80                                            detail::TypeUniquer, Traits...>;
81 
82   using ImplType = TypeStorage;
83 
84   using AbstractTy = AbstractType;
85 
86   constexpr Type() = default;
87   /* implicit */ Type(const ImplType *impl)
88       : impl(const_cast<ImplType *>(impl)) {}
89 
90   Type(const Type &other) = default;
91   Type &operator=(const Type &other) = default;
92 
93   bool operator==(Type other) const { return impl == other.impl; }
94   bool operator!=(Type other) const { return !(*this == other); }
95   explicit operator bool() const { return impl; }
96 
97   bool operator!() const { return impl == nullptr; }
98 
99   template <typename... Tys>
100   [[deprecated("Use mlir::isa<U>() instead")]]
101   bool isa() const;
102   template <typename... Tys>
103   [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
104   bool isa_and_nonnull() const;
105   template <typename U>
106   [[deprecated("Use mlir::dyn_cast<U>() instead")]]
107   U dyn_cast() const;
108   template <typename U>
109   [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
110   U dyn_cast_or_null() const;
111   template <typename U>
112   [[deprecated("Use mlir::cast<U>() instead")]]
113   U cast() const;
114 
115   /// Return a unique identifier for the concrete type. This is used to support
116   /// dynamic type casting.
117   TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
118 
119   /// Return the MLIRContext in which this type was uniqued.
120   MLIRContext *getContext() const;
121 
122   /// Get the dialect this type is registered to.
123   Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
124 
125   // Convenience predicates.  This is only for floating point types,
126   // derived types should use isa/dyn_cast.
127   bool isIndex() const;
128   bool isBF16() const;
129   bool isF16() const;
130   bool isTF32() const;
131   bool isF32() const;
132   bool isF64() const;
133   bool isF80() const;
134   bool isF128() const;
135 
136   /// Return true if this is an integer type (with the specified width).
137   bool isInteger() const;
138   bool isInteger(unsigned width) const;
139   /// Return true if this is a signless integer type (with the specified width).
140   bool isSignlessInteger() const;
141   bool isSignlessInteger(unsigned width) const;
142   /// Return true if this is a signed integer type (with the specified width).
143   bool isSignedInteger() const;
144   bool isSignedInteger(unsigned width) const;
145   /// Return true if this is an unsigned integer type (with the specified
146   /// width).
147   bool isUnsignedInteger() const;
148   bool isUnsignedInteger(unsigned width) const;
149 
150   /// Return the bit width of an integer or a float type, assert failure on
151   /// other types.
152   unsigned getIntOrFloatBitWidth() const;
153 
154   /// Return true if this is a signless integer or index type.
155   bool isSignlessIntOrIndex() const;
156   /// Return true if this is a signless integer, index, or float type.
157   bool isSignlessIntOrIndexOrFloat() const;
158   /// Return true of this is a signless integer or a float type.
159   bool isSignlessIntOrFloat() const;
160 
161   /// Return true if this is an integer (of any signedness) or an index type.
162   bool isIntOrIndex() const;
163   /// Return true if this is an integer (of any signedness) or a float type.
164   bool isIntOrFloat() const;
165   /// Return true if this is an integer (of any signedness), index, or float
166   /// type.
167   bool isIntOrIndexOrFloat() const;
168 
169   /// Print the current type.
170   void print(raw_ostream &os) const;
171   void print(raw_ostream &os, AsmState &state) const;
172   void dump() const;
173 
174   friend ::llvm::hash_code hash_value(Type arg);
175 
176   /// Methods for supporting PointerLikeTypeTraits.
177   const void *getAsOpaquePointer() const {
178     return static_cast<const void *>(impl);
179   }
180   static Type getFromOpaquePointer(const void *pointer) {
181     return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
182   }
183 
184   /// Returns true if `InterfaceT` has been promised by the dialect or
185   /// implemented.
186   template <typename InterfaceT>
187   bool hasPromiseOrImplementsInterface() {
188     return dialect_extension_detail::hasPromisedInterface(
189                getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
190            mlir::isa<InterfaceT>(*this);
191   }
192 
193   /// Returns true if the type was registered with a particular trait.
194   template <template <typename T> class Trait>
195   bool hasTrait() {
196     return getAbstractType().hasTrait<Trait>();
197   }
198 
199   /// Return the abstract type descriptor for this type.
200   const AbstractTy &getAbstractType() const { return impl->getAbstractType(); }
201 
202   /// Return the Type implementation.
203   ImplType *getImpl() const { return impl; }
204 
205   /// Walk all of the immediately nested sub-attributes and sub-types. This
206   /// method does not recurse into sub elements.
207   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
208                                 function_ref<void(Type)> walkTypesFn) const {
209     getAbstractType().walkImmediateSubElements(*this, walkAttrsFn, walkTypesFn);
210   }
211 
212   /// Replace the immediately nested sub-attributes and sub-types with those
213   /// provided. The order of the provided elements is derived from the order of
214   /// the elements returned by the callbacks of `walkImmediateSubElements`. The
215   /// element at index 0 would replace the very first attribute given by
216   /// `walkImmediateSubElements`. On success, the new instance with the values
217   /// replaced is returned. If replacement fails, nullptr is returned.
218   auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
219                                    ArrayRef<Type> replTypes) const {
220     return getAbstractType().replaceImmediateSubElements(*this, replAttrs,
221                                                          replTypes);
222   }
223 
224   /// Walk this type and all attibutes/types nested within using the
225   /// provided walk functions. See `AttrTypeWalker` for information on the
226   /// supported walk function types.
227   template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
228   auto walk(WalkFns &&...walkFns) {
229     AttrTypeWalker walker;
230     (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
231     return walker.walk<Order>(*this);
232   }
233 
234   /// Recursively replace all of the nested sub-attributes and sub-types using
235   /// the provided map functions. Returns nullptr in the case of failure. See
236   /// `AttrTypeReplacer` for information on the support replacement function
237   /// types.
238   template <typename... ReplacementFns>
239   auto replace(ReplacementFns &&...replacementFns) {
240     AttrTypeReplacer replacer;
241     (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
242      ...);
243     return replacer.replace(*this);
244   }
245 
246 protected:
247   ImplType *impl{nullptr};
248 };
249 
250 inline raw_ostream &operator<<(raw_ostream &os, Type type) {
251   type.print(os);
252   return os;
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // TypeTraitBase
257 //===----------------------------------------------------------------------===//
258 
259 namespace TypeTrait {
260 /// This class represents the base of a type trait.
261 template <typename ConcreteType, template <typename> class TraitType>
262 using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
263 } // namespace TypeTrait
264 
265 //===----------------------------------------------------------------------===//
266 // TypeInterface
267 //===----------------------------------------------------------------------===//
268 
269 /// This class represents the base of a type interface. See the definition  of
270 /// `detail::Interface` for requirements on the `Traits` type.
271 template <typename ConcreteType, typename Traits>
272 class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
273                                                TypeTrait::TraitBase> {
274 public:
275   using Base = TypeInterface<ConcreteType, Traits>;
276   using InterfaceBase =
277       detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
278   using InterfaceBase::InterfaceBase;
279 
280 protected:
281   /// Returns the impl interface instance for the given type.
282   static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
283 #ifndef NDEBUG
284     // Check that the current interface isn't an unresolved promise for the
285     // given type.
286     dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
287         type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(),
288         llvm::getTypeName<ConcreteType>());
289 #endif
290 
291     return type.getAbstractType().getInterface<ConcreteType>();
292   }
293 
294   /// Allow access to 'getInterfaceFor'.
295   friend InterfaceBase;
296 };
297 
298 //===----------------------------------------------------------------------===//
299 // Core TypeTrait
300 //===----------------------------------------------------------------------===//
301 
302 /// This trait is used to determine if a type is mutable or not. It is attached
303 /// on a type if the corresponding ImplType defines a `mutate` function with
304 /// a proper signature.
305 namespace TypeTrait {
306 template <typename ConcreteType>
307 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
308 } // namespace TypeTrait
309 
310 //===----------------------------------------------------------------------===//
311 // Type Utils
312 //===----------------------------------------------------------------------===//
313 
314 // Make Type hashable.
315 inline ::llvm::hash_code hash_value(Type arg) {
316   return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
317 }
318 
319 template <typename... Tys>
320 bool Type::isa() const {
321   return llvm::isa<Tys...>(*this);
322 }
323 
324 template <typename... Tys>
325 bool Type::isa_and_nonnull() const {
326   return llvm::isa_and_present<Tys...>(*this);
327 }
328 
329 template <typename U>
330 U Type::dyn_cast() const {
331   return llvm::dyn_cast<U>(*this);
332 }
333 
334 template <typename U>
335 U Type::dyn_cast_or_null() const {
336   return llvm::dyn_cast_or_null<U>(*this);
337 }
338 
339 template <typename U>
340 U Type::cast() const {
341   return llvm::cast<U>(*this);
342 }
343 
344 } // namespace mlir
345 
346 namespace llvm {
347 
348 // Type hash just like pointers.
349 template <>
350 struct DenseMapInfo<mlir::Type> {
351   static mlir::Type getEmptyKey() {
352     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
353     return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
354   }
355   static mlir::Type getTombstoneKey() {
356     auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
357     return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
358   }
359   static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
360   static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
361 };
362 template <typename T>
363 struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
364                                         !mlir::detail::IsInterface<T>::value>>
365     : public DenseMapInfo<mlir::Type> {
366   static T getEmptyKey() {
367     const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
368     return T::getFromOpaquePointer(pointer);
369   }
370   static T getTombstoneKey() {
371     const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
372     return T::getFromOpaquePointer(pointer);
373   }
374 };
375 
376 /// We align TypeStorage by 8, so allow LLVM to steal the low bits.
377 template <>
378 struct PointerLikeTypeTraits<mlir::Type> {
379 public:
380   static inline void *getAsVoidPointer(mlir::Type I) {
381     return const_cast<void *>(I.getAsOpaquePointer());
382   }
383   static inline mlir::Type getFromVoidPointer(void *P) {
384     return mlir::Type::getFromOpaquePointer(P);
385   }
386   static constexpr int NumLowBitsAvailable = 3;
387 };
388 
389 /// Add support for llvm style casts.
390 /// We provide a cast between To and From if From is mlir::Type or derives from
391 /// it
392 template <typename To, typename From>
393 struct CastInfo<
394     To, From,
395     std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
396                      std::is_base_of_v<mlir::Type, From>>>
397     : NullableValueCastFailed<To>,
398       DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
399   /// Arguments are taken as mlir::Type here and not as `From`, because when
400   /// casting from an intermediate type of the hierarchy to one of its children,
401   /// the val.getTypeID() inside T::classof will use the static getTypeID of the
402   /// parent instead of the non-static Type::getTypeID that returns the dynamic
403   /// ID. This means that T::classof would end up comparing the static TypeID of
404   /// the children to the static TypeID of its parent, making it impossible to
405   /// downcast from the parent to the child.
406   static inline bool isPossible(mlir::Type ty) {
407     /// Return a constant true instead of a dynamic true when casting to self or
408     /// up the hierarchy.
409     if constexpr (std::is_base_of_v<To, From>) {
410       return true;
411     } else {
412       return To::classof(ty);
413     };
414   }
415   static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
416 };
417 
418 } // namespace llvm
419 
420 #endif // MLIR_IR_TYPES_H
421