1 //===- TypeSupport.h --------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines support types for registering dialect extended types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_TYPESUPPORT_H 14 #define MLIR_IR_TYPESUPPORT_H 15 16 #include "mlir/IR/MLIRContext.h" 17 #include "mlir/IR/StorageUniquerSupport.h" 18 #include "llvm/ADT/Twine.h" 19 20 namespace mlir { 21 class Dialect; 22 class MLIRContext; 23 24 //===----------------------------------------------------------------------===// 25 // AbstractType 26 //===----------------------------------------------------------------------===// 27 28 /// This class contains all of the static information common to all instances of 29 /// a registered Type. 30 class AbstractType { 31 public: 32 using HasTraitFn = llvm::unique_function<bool(TypeID) const>; 33 using WalkImmediateSubElementsFn = function_ref<void( 34 Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>; 35 using ReplaceImmediateSubElementsFn = 36 function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>; 37 38 /// Look up the specified abstract type in the MLIRContext and return a 39 /// reference to it. 40 static const AbstractType &lookup(TypeID typeID, MLIRContext *context); 41 42 /// Look up the specified abstract type in the MLIRContext and return a 43 /// reference to it if it exists. 44 static std::optional<std::reference_wrapper<const AbstractType>> 45 lookup(StringRef name, MLIRContext *context); 46 47 /// This method is used by Dialect objects when they register the list of 48 /// types they contain. 49 template <typename T> get(Dialect & dialect)50 static AbstractType get(Dialect &dialect) { 51 return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), 52 T::getWalkImmediateSubElementsFn(), 53 T::getReplaceImmediateSubElementsFn(), T::getTypeID(), 54 T::name); 55 } 56 57 /// This method is used by Dialect objects to register types with 58 /// custom TypeIDs. 59 /// The use of this method is in general discouraged in favor of 60 /// 'get<CustomType>(dialect)'; 61 static AbstractType get(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)62 get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, 63 HasTraitFn &&hasTrait, 64 WalkImmediateSubElementsFn walkImmediateSubElementsFn, 65 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, 66 TypeID typeID, StringRef name) { 67 return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), 68 walkImmediateSubElementsFn, 69 replaceImmediateSubElementsFn, typeID, name); 70 } 71 72 /// Return the dialect this type was registered to. getDialect()73 Dialect &getDialect() const { return const_cast<Dialect &>(dialect); } 74 75 /// Returns an instance of the concept object for the given interface if it 76 /// was registered to this type, null otherwise. This should not be used 77 /// directly. 78 template <typename T> getInterface()79 typename T::Concept *getInterface() const { 80 return interfaceMap.lookup<T>(); 81 } 82 83 /// Returns true if the type has the interface with the given ID. hasInterface(TypeID interfaceID)84 bool hasInterface(TypeID interfaceID) const { 85 return interfaceMap.contains(interfaceID); 86 } 87 88 /// Returns true if the type has a particular trait. 89 template <template <typename T> class Trait> hasTrait()90 bool hasTrait() const { 91 return hasTraitFn(TypeID::get<Trait>()); 92 } 93 94 /// Returns true if the type has a particular trait. hasTrait(TypeID traitID)95 bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } 96 97 /// Walk the immediate sub-elements of the given type. 98 void walkImmediateSubElements(Type type, 99 function_ref<void(Attribute)> walkAttrsFn, 100 function_ref<void(Type)> walkTypesFn) const; 101 102 /// Replace the immediate sub-elements of the given type. 103 Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs, 104 ArrayRef<Type> replTypes) const; 105 106 /// Return the unique identifier representing the concrete type class. getTypeID()107 TypeID getTypeID() const { return typeID; } 108 109 /// Return the unique name representing the type. getName()110 StringRef getName() const { return name; } 111 112 private: AbstractType(Dialect & dialect,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,WalkImmediateSubElementsFn walkImmediateSubElementsFn,ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,TypeID typeID,StringRef name)113 AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, 114 HasTraitFn &&hasTrait, 115 WalkImmediateSubElementsFn walkImmediateSubElementsFn, 116 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, 117 TypeID typeID, StringRef name) 118 : dialect(dialect), interfaceMap(std::move(interfaceMap)), 119 hasTraitFn(std::move(hasTrait)), 120 walkImmediateSubElementsFn(walkImmediateSubElementsFn), 121 replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), 122 typeID(typeID), name(name) {} 123 124 /// Give StorageUserBase access to the mutable lookup. 125 template <typename ConcreteT, typename BaseT, typename StorageT, 126 typename UniquerT, template <typename T> class... Traits> 127 friend class detail::StorageUserBase; 128 129 /// Look up the specified abstract type in the MLIRContext and return a 130 /// (mutable) pointer to it. Return a null pointer if the type could not 131 /// be found in the context. 132 static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); 133 134 /// This is the dialect that this type was registered to. 135 const Dialect &dialect; 136 137 /// This is a collection of the interfaces registered to this type. 138 detail::InterfaceMap interfaceMap; 139 140 /// Function to check if the type has a particular trait. 141 HasTraitFn hasTraitFn; 142 143 /// Function to walk the immediate sub-elements of this type. 144 WalkImmediateSubElementsFn walkImmediateSubElementsFn; 145 146 /// Function to replace the immediate sub-elements of this type. 147 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn; 148 149 /// The unique identifier of the derived Type class. 150 const TypeID typeID; 151 152 /// The unique name of this type. The string is not owned by the context, so 153 /// The lifetime of this string should outlive the MLIR context. 154 const StringRef name; 155 }; 156 157 //===----------------------------------------------------------------------===// 158 // TypeStorage 159 //===----------------------------------------------------------------------===// 160 161 namespace detail { 162 struct TypeUniquer; 163 } // namespace detail 164 165 /// Base storage class appearing in a Type. 166 class TypeStorage : public StorageUniquer::BaseStorage { 167 friend detail::TypeUniquer; 168 friend StorageUniquer; 169 170 public: 171 /// Return the abstract type descriptor for this type. getAbstractType()172 const AbstractType &getAbstractType() { 173 assert(abstractType && "Malformed type storage object."); 174 return *abstractType; 175 } 176 177 protected: 178 /// This constructor is used by derived classes as part of the TypeUniquer. TypeStorage()179 TypeStorage() {} 180 181 private: 182 /// Set the abstract type for this storage instance. This is used by the 183 /// TypeUniquer when initializing a newly constructed type storage object. initialize(const AbstractType & abstractTy)184 void initialize(const AbstractType &abstractTy) { 185 abstractType = const_cast<AbstractType *>(&abstractTy); 186 } 187 188 /// The abstract description for this type. 189 AbstractType *abstractType{nullptr}; 190 }; 191 192 /// Default storage type for types that require no additional initialization or 193 /// storage. 194 using DefaultTypeStorage = TypeStorage; 195 196 //===----------------------------------------------------------------------===// 197 // TypeStorageAllocator 198 //===----------------------------------------------------------------------===// 199 200 /// This is a utility allocator used to allocate memory for instances of derived 201 /// Types. 202 using TypeStorageAllocator = StorageUniquer::StorageAllocator; 203 204 //===----------------------------------------------------------------------===// 205 // TypeUniquer 206 //===----------------------------------------------------------------------===// 207 namespace detail { 208 /// A utility class to get, or create, unique instances of types within an 209 /// MLIRContext. This class manages all creation and uniquing of types. 210 struct TypeUniquer { 211 /// Get an uniqued instance of a type T. 212 template <typename T, typename... Args> getTypeUniquer213 static T get(MLIRContext *ctx, Args &&...args) { 214 return getWithTypeID<T, Args...>(ctx, T::getTypeID(), 215 std::forward<Args>(args)...); 216 } 217 218 /// Get an uniqued instance of a parametric type T. 219 /// The use of this method is in general discouraged in favor of 220 /// 'get<T, Args>(ctx, args)'. 221 template <typename T, typename... Args> 222 static std::enable_if_t< 223 !std::is_same<typename T::ImplType, TypeStorage>::value, T> getWithTypeIDTypeUniquer224 getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { 225 #ifndef NDEBUG 226 if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) 227 llvm::report_fatal_error( 228 llvm::Twine("can't create type '") + llvm::getTypeName<T>() + 229 "' because storage uniquer isn't initialized: the dialect was likely " 230 "not loaded, or the type wasn't added with addTypes<...>() " 231 "in the Dialect::initialize() method."); 232 #endif 233 return ctx->getTypeUniquer().get<typename T::ImplType>( 234 [&, typeID](TypeStorage *storage) { 235 storage->initialize(AbstractType::lookup(typeID, ctx)); 236 }, 237 typeID, std::forward<Args>(args)...); 238 } 239 /// Get an uniqued instance of a singleton type T. 240 /// The use of this method is in general discouraged in favor of 241 /// 'get<T, Args>(ctx, args)'. 242 template <typename T> 243 static std::enable_if_t< 244 std::is_same<typename T::ImplType, TypeStorage>::value, T> getWithTypeIDTypeUniquer245 getWithTypeID(MLIRContext *ctx, TypeID typeID) { 246 #ifndef NDEBUG 247 if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) 248 llvm::report_fatal_error( 249 llvm::Twine("can't create type '") + llvm::getTypeName<T>() + 250 "' because storage uniquer isn't initialized: the dialect was likely " 251 "not loaded, or the type wasn't added with addTypes<...>() " 252 "in the Dialect::initialize() method."); 253 #endif 254 return ctx->getTypeUniquer().get<typename T::ImplType>(typeID); 255 } 256 257 /// Change the mutable component of the given type instance in the provided 258 /// context. 259 template <typename T, typename... Args> mutateTypeUniquer260 static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, 261 Args &&...args) { 262 assert(impl && "cannot mutate null type"); 263 return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, 264 std::forward<Args>(args)...); 265 } 266 267 /// Register a type instance T with the uniquer. 268 template <typename T> registerTypeTypeUniquer269 static void registerType(MLIRContext *ctx) { 270 registerType<T>(ctx, T::getTypeID()); 271 } 272 273 /// Register a parametric type instance T with the uniquer. 274 /// The use of this method is in general discouraged in favor of 275 /// 'registerType<T>(ctx)'. 276 template <typename T> 277 static std::enable_if_t< 278 !std::is_same<typename T::ImplType, TypeStorage>::value> registerTypeTypeUniquer279 registerType(MLIRContext *ctx, TypeID typeID) { 280 ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>( 281 typeID); 282 } 283 /// Register a singleton type instance T with the uniquer. 284 /// The use of this method is in general discouraged in favor of 285 /// 'registerType<T>(ctx)'. 286 template <typename T> 287 static std::enable_if_t< 288 std::is_same<typename T::ImplType, TypeStorage>::value> registerTypeTypeUniquer289 registerType(MLIRContext *ctx, TypeID typeID) { 290 ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>( 291 typeID, [&ctx, typeID](TypeStorage *storage) { 292 storage->initialize(AbstractType::lookup(typeID, ctx)); 293 }); 294 } 295 }; 296 } // namespace detail 297 298 } // namespace mlir 299 300 #endif 301