xref: /llvm-project/mlir/include/mlir/IR/TypeSupport.h (revision 3dbac2c007c114a720300d2a4d79abe9ca1351e7)
1 //===- TypeSupport.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines support types for registering dialect extended types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPESUPPORT_H
14 #define MLIR_IR_TYPESUPPORT_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/StorageUniquerSupport.h"
18 #include "llvm/ADT/Twine.h"
19 
20 namespace mlir {
21 class Dialect;
22 class MLIRContext;
23 
24 //===----------------------------------------------------------------------===//
25 // AbstractType
26 //===----------------------------------------------------------------------===//
27 
28 /// This class contains all of the static information common to all instances of
29 /// a registered Type.
30 class AbstractType {
31 public:
32   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
33   using WalkImmediateSubElementsFn = function_ref<void(
34       Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
35   using ReplaceImmediateSubElementsFn =
36       function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>;
37 
38   /// Look up the specified abstract type in the MLIRContext and return a
39   /// reference to it.
40   static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
41 
42   /// Look up the specified abstract type in the MLIRContext and return a
43   /// reference to it if it exists.
44   static std::optional<std::reference_wrapper<const AbstractType>>
45   lookup(StringRef name, MLIRContext *context);
46 
47   /// This method is used by Dialect objects when they register the list of
48   /// types they contain.
49   template <typename T>
get(Dialect & dialect)50   static AbstractType get(Dialect &dialect) {
51     return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
52                         T::getWalkImmediateSubElementsFn(),
53                         T::getReplaceImmediateSubElementsFn(), T::getTypeID(),
54                         T::name);
55   }
56 
57   /// This method is used by Dialect objects to register types with
58   /// custom TypeIDs.
59   /// The use of this method is in general discouraged in favor of
60   /// 'get<CustomType>(dialect)';
61   static AbstractType
get(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)62   get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
63       HasTraitFn &&hasTrait,
64       WalkImmediateSubElementsFn walkImmediateSubElementsFn,
65       ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
66       TypeID typeID, StringRef name) {
67     return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
68                         walkImmediateSubElementsFn,
69                         replaceImmediateSubElementsFn, typeID, name);
70   }
71 
72   /// Return the dialect this type was registered to.
getDialect()73   Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
74 
75   /// Returns an instance of the concept object for the given interface if it
76   /// was registered to this type, null otherwise. This should not be used
77   /// directly.
78   template <typename T>
getInterface()79   typename T::Concept *getInterface() const {
80     return interfaceMap.lookup<T>();
81   }
82 
83   /// Returns true if the type has the interface with the given ID.
hasInterface(TypeID interfaceID)84   bool hasInterface(TypeID interfaceID) const {
85     return interfaceMap.contains(interfaceID);
86   }
87 
88   /// Returns true if the type has a particular trait.
89   template <template <typename T> class Trait>
hasTrait()90   bool hasTrait() const {
91     return hasTraitFn(TypeID::get<Trait>());
92   }
93 
94   /// Returns true if the type has a particular trait.
hasTrait(TypeID traitID)95   bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
96 
97   /// Walk the immediate sub-elements of the given type.
98   void walkImmediateSubElements(Type type,
99                                 function_ref<void(Attribute)> walkAttrsFn,
100                                 function_ref<void(Type)> walkTypesFn) const;
101 
102   /// Replace the immediate sub-elements of the given type.
103   Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs,
104                                    ArrayRef<Type> replTypes) const;
105 
106   /// Return the unique identifier representing the concrete type class.
getTypeID()107   TypeID getTypeID() const { return typeID; }
108 
109   /// Return the unique name representing the type.
getName()110   StringRef getName() const { return name; }
111 
112 private:
AbstractType(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)113   AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
114                HasTraitFn &&hasTrait,
115                WalkImmediateSubElementsFn walkImmediateSubElementsFn,
116                ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
117                TypeID typeID, StringRef name)
118       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
119         hasTraitFn(std::move(hasTrait)),
120         walkImmediateSubElementsFn(walkImmediateSubElementsFn),
121         replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
122         typeID(typeID), name(name) {}
123 
124   /// Give StorageUserBase access to the mutable lookup.
125   template <typename ConcreteT, typename BaseT, typename StorageT,
126             typename UniquerT, template <typename T> class... Traits>
127   friend class detail::StorageUserBase;
128 
129   /// Look up the specified abstract type in the MLIRContext and return a
130   /// (mutable) pointer to it. Return a null pointer if the type could not
131   /// be found in the context.
132   static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context);
133 
134   /// This is the dialect that this type was registered to.
135   const Dialect &dialect;
136 
137   /// This is a collection of the interfaces registered to this type.
138   detail::InterfaceMap interfaceMap;
139 
140   /// Function to check if the type has a particular trait.
141   HasTraitFn hasTraitFn;
142 
143   /// Function to walk the immediate sub-elements of this type.
144   WalkImmediateSubElementsFn walkImmediateSubElementsFn;
145 
146   /// Function to replace the immediate sub-elements of this type.
147   ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
148 
149   /// The unique identifier of the derived Type class.
150   const TypeID typeID;
151 
152   /// The unique name of this type. The string is not owned by the context, so
153   /// The lifetime of this string should outlive the MLIR context.
154   const StringRef name;
155 };
156 
157 //===----------------------------------------------------------------------===//
158 // TypeStorage
159 //===----------------------------------------------------------------------===//
160 
161 namespace detail {
162 struct TypeUniquer;
163 } // namespace detail
164 
165 /// Base storage class appearing in a Type.
166 class TypeStorage : public StorageUniquer::BaseStorage {
167   friend detail::TypeUniquer;
168   friend StorageUniquer;
169 
170 public:
171   /// Return the abstract type descriptor for this type.
getAbstractType()172   const AbstractType &getAbstractType() {
173     assert(abstractType && "Malformed type storage object.");
174     return *abstractType;
175   }
176 
177 protected:
178   /// This constructor is used by derived classes as part of the TypeUniquer.
TypeStorage()179   TypeStorage() {}
180 
181 private:
182   /// Set the abstract type for this storage instance. This is used by the
183   /// TypeUniquer when initializing a newly constructed type storage object.
initialize(const AbstractType & abstractTy)184   void initialize(const AbstractType &abstractTy) {
185     abstractType = const_cast<AbstractType *>(&abstractTy);
186   }
187 
188   /// The abstract description for this type.
189   AbstractType *abstractType{nullptr};
190 };
191 
192 /// Default storage type for types that require no additional initialization or
193 /// storage.
194 using DefaultTypeStorage = TypeStorage;
195 
196 //===----------------------------------------------------------------------===//
197 // TypeStorageAllocator
198 //===----------------------------------------------------------------------===//
199 
200 /// This is a utility allocator used to allocate memory for instances of derived
201 /// Types.
202 using TypeStorageAllocator = StorageUniquer::StorageAllocator;
203 
204 //===----------------------------------------------------------------------===//
205 // TypeUniquer
206 //===----------------------------------------------------------------------===//
207 namespace detail {
208 /// A utility class to get, or create, unique instances of types within an
209 /// MLIRContext. This class manages all creation and uniquing of types.
210 struct TypeUniquer {
211   /// Get an uniqued instance of a type T.
212   template <typename T, typename... Args>
getTypeUniquer213   static T get(MLIRContext *ctx, Args &&...args) {
214     return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
215                                      std::forward<Args>(args)...);
216   }
217 
218   /// Get an uniqued instance of a parametric type T.
219   /// The use of this method is in general discouraged in favor of
220   /// 'get<T, Args>(ctx, args)'.
221   template <typename T, typename... Args>
222   static std::enable_if_t<
223       !std::is_same<typename T::ImplType, TypeStorage>::value, T>
getWithTypeIDTypeUniquer224   getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
225 #ifndef NDEBUG
226     if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID))
227       llvm::report_fatal_error(
228           llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
229           "' because storage uniquer isn't initialized: the dialect was likely "
230           "not loaded, or the type wasn't added with addTypes<...>() "
231           "in the Dialect::initialize() method.");
232 #endif
233     return ctx->getTypeUniquer().get<typename T::ImplType>(
234         [&, typeID](TypeStorage *storage) {
235           storage->initialize(AbstractType::lookup(typeID, ctx));
236         },
237         typeID, std::forward<Args>(args)...);
238   }
239   /// Get an uniqued instance of a singleton type T.
240   /// The use of this method is in general discouraged in favor of
241   /// 'get<T, Args>(ctx, args)'.
242   template <typename T>
243   static std::enable_if_t<
244       std::is_same<typename T::ImplType, TypeStorage>::value, T>
getWithTypeIDTypeUniquer245   getWithTypeID(MLIRContext *ctx, TypeID typeID) {
246 #ifndef NDEBUG
247     if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID))
248       llvm::report_fatal_error(
249           llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
250           "' because storage uniquer isn't initialized: the dialect was likely "
251           "not loaded, or the type wasn't added with addTypes<...>() "
252           "in the Dialect::initialize() method.");
253 #endif
254     return ctx->getTypeUniquer().get<typename T::ImplType>(typeID);
255   }
256 
257   /// Change the mutable component of the given type instance in the provided
258   /// context.
259   template <typename T, typename... Args>
mutateTypeUniquer260   static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
261                               Args &&...args) {
262     assert(impl && "cannot mutate null type");
263     return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
264                                         std::forward<Args>(args)...);
265   }
266 
267   /// Register a type instance T with the uniquer.
268   template <typename T>
registerTypeTypeUniquer269   static void registerType(MLIRContext *ctx) {
270     registerType<T>(ctx, T::getTypeID());
271   }
272 
273   /// Register a parametric type instance T with the uniquer.
274   /// The use of this method is in general discouraged in favor of
275   /// 'registerType<T>(ctx)'.
276   template <typename T>
277   static std::enable_if_t<
278       !std::is_same<typename T::ImplType, TypeStorage>::value>
registerTypeTypeUniquer279   registerType(MLIRContext *ctx, TypeID typeID) {
280     ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
281         typeID);
282   }
283   /// Register a singleton type instance T with the uniquer.
284   /// The use of this method is in general discouraged in favor of
285   /// 'registerType<T>(ctx)'.
286   template <typename T>
287   static std::enable_if_t<
288       std::is_same<typename T::ImplType, TypeStorage>::value>
registerTypeTypeUniquer289   registerType(MLIRContext *ctx, TypeID typeID) {
290     ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
291         typeID, [&ctx, typeID](TypeStorage *storage) {
292           storage->initialize(AbstractType::lookup(typeID, ctx));
293         });
294   }
295 };
296 } // namespace detail
297 
298 } // namespace mlir
299 
300 #endif
301