1 //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines utility classes for interfacing with StorageUniquer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H 14 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H 15 16 #include "mlir/IR/AttrTypeSubElements.h" 17 #include "mlir/IR/DialectRegistry.h" 18 #include "mlir/Support/InterfaceSupport.h" 19 #include "mlir/Support/StorageUniquer.h" 20 #include "mlir/Support/TypeID.h" 21 #include "llvm/ADT/FunctionExtras.h" 22 23 namespace mlir { 24 class InFlightDiagnostic; 25 class Location; 26 class MLIRContext; 27 28 namespace detail { 29 /// Utility method to generate a callback that can be used to generate a 30 /// diagnostic when checking the construction invariants of a storage object. 31 /// This is defined out-of-line to avoid the need to include Location.h. 32 llvm::unique_function<InFlightDiagnostic()> 33 getDefaultDiagnosticEmitFn(MLIRContext *ctx); 34 llvm::unique_function<InFlightDiagnostic()> 35 getDefaultDiagnosticEmitFn(const Location &loc); 36 37 //===----------------------------------------------------------------------===// 38 // StorageUserTraitBase 39 //===----------------------------------------------------------------------===// 40 41 /// Helper class for implementing traits for storage classes. Clients are not 42 /// expected to interact with this directly, so its members are all protected. 43 template <typename ConcreteType, template <typename> class TraitType> 44 class StorageUserTraitBase { 45 protected: 46 /// Return the derived instance. 47 ConcreteType getInstance() const { 48 // We have to cast up to the trait type, then to the concrete type because 49 // the concrete type will multiply derive from the (content free) TraitBase 50 // class, and we need to be able to disambiguate the path for the C++ 51 // compiler. 52 auto *trait = static_cast<const TraitType<ConcreteType> *>(this); 53 return *static_cast<const ConcreteType *>(trait); 54 } 55 }; 56 57 namespace StorageUserTrait { 58 /// This trait is used to determine if a storage user, like Type, is mutable 59 /// or not. A storage user is mutable if ImplType of the derived class defines 60 /// a `mutate` function with a proper signature. Note that this trait is not 61 /// supposed to be used publicly. Users should use alias names like 62 /// `TypeTrait::IsMutable` instead. 63 template <typename ConcreteType> 64 struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {}; 65 } // namespace StorageUserTrait 66 67 //===----------------------------------------------------------------------===// 68 // StorageUserBase 69 //===----------------------------------------------------------------------===// 70 71 namespace storage_user_base_impl { 72 /// Returns true if this given Trait ID matches the IDs of any of the provided 73 /// trait types `Traits`. 74 template <template <typename T> class... Traits> 75 bool hasTrait(TypeID traitID) { 76 TypeID traitIDs[] = {TypeID::get<Traits>()...}; 77 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) 78 if (traitIDs[i] == traitID) 79 return true; 80 return false; 81 } 82 83 // We specialize for the empty case to not define an empty array. 84 template <> 85 inline bool hasTrait(TypeID traitID) { 86 return false; 87 } 88 } // namespace storage_user_base_impl 89 90 /// Utility class for implementing users of storage classes uniqued by a 91 /// StorageUniquer. Clients are not expected to interact with this class 92 /// directly. 93 template <typename ConcreteT, typename BaseT, typename StorageT, 94 typename UniquerT, template <typename T> class... Traits> 95 class StorageUserBase : public BaseT, public Traits<ConcreteT>... { 96 public: 97 using BaseT::BaseT; 98 99 /// Utility declarations for the concrete attribute class. 100 using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>; 101 using ImplType = StorageT; 102 using HasTraitFn = bool (*)(TypeID); 103 104 /// Return a unique identifier for the concrete type. 105 static TypeID getTypeID() { return TypeID::get<ConcreteT>(); } 106 107 /// Provide an implementation of 'classof' that compares the type id of the 108 /// provided value with that of the concrete type. 109 template <typename T> 110 static bool classof(T val) { 111 static_assert(std::is_convertible<ConcreteT, T>::value, 112 "casting from a non-convertible type"); 113 return val.getTypeID() == getTypeID(); 114 } 115 116 /// Returns an interface map for the interfaces registered to this storage 117 /// user. This should not be used directly. 118 static detail::InterfaceMap getInterfaceMap() { 119 return detail::InterfaceMap::template get<Traits<ConcreteT>...>(); 120 } 121 122 /// Returns the function that returns true if the given Trait ID matches the 123 /// IDs of any of the traits defined by the storage user. 124 static HasTraitFn getHasTraitFn() { 125 return [](TypeID id) { 126 return storage_user_base_impl::hasTrait<Traits...>(id); 127 }; 128 } 129 130 /// Returns a function that walks immediate sub elements of a given instance 131 /// of the storage user. 132 static auto getWalkImmediateSubElementsFn() { 133 return [](auto instance, function_ref<void(Attribute)> walkAttrsFn, 134 function_ref<void(Type)> walkTypesFn) { 135 ::mlir::detail::walkImmediateSubElementsImpl( 136 llvm::cast<ConcreteT>(instance), walkAttrsFn, walkTypesFn); 137 }; 138 } 139 140 /// Returns a function that replaces immediate sub elements of a given 141 /// instance of the storage user. 142 static auto getReplaceImmediateSubElementsFn() { 143 return [](auto instance, ArrayRef<Attribute> replAttrs, 144 ArrayRef<Type> replTypes) { 145 return ::mlir::detail::replaceImmediateSubElementsImpl( 146 llvm::cast<ConcreteT>(instance), replAttrs, replTypes); 147 }; 148 } 149 150 /// Attach the given models as implementations of the corresponding interfaces 151 /// for the concrete storage user class. The type must be registered with the 152 /// context, i.e. the dialect to which the type belongs must be loaded. The 153 /// call will abort otherwise. 154 template <typename... IfaceModels> 155 static void attachInterface(MLIRContext &context) { 156 typename ConcreteT::AbstractTy *abstract = 157 ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(), 158 &context); 159 if (!abstract) 160 llvm::report_fatal_error("Registering an interface for an attribute/type " 161 "that is not itself registered."); 162 163 // Handle the case where the models resolve a promised interface. 164 (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( 165 abstract->getDialect(), abstract->getTypeID(), 166 IfaceModels::Interface::getInterfaceID()), 167 ...); 168 169 (checkInterfaceTarget<IfaceModels>(), ...); 170 abstract->interfaceMap.template insertModels<IfaceModels...>(); 171 } 172 173 /// Get or create a new ConcreteT instance within the ctx. This 174 /// function is guaranteed to return a non null object and will assert if 175 /// the arguments provided are invalid. 176 template <typename... Args> 177 static ConcreteT get(MLIRContext *ctx, Args &&...args) { 178 // Ensure that the invariants are correct for construction. 179 assert(succeeded( 180 ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...))); 181 return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...); 182 } 183 184 /// Get or create a new ConcreteT instance within the ctx, defined at 185 /// the given, potentially unknown, location. If the arguments provided are 186 /// invalid, errors are emitted using the provided location and a null object 187 /// is returned. 188 template <typename... Args> 189 static ConcreteT getChecked(const Location &loc, Args &&...args) { 190 return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), 191 std::forward<Args>(args)...); 192 } 193 194 /// Get or create a new ConcreteT instance within the ctx. If the arguments 195 /// provided are invalid, errors are emitted using the provided `emitError` 196 /// and a null object is returned. 197 template <typename... Args> 198 static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 199 MLIRContext *ctx, Args... args) { 200 // If the construction invariants fail then we return a null attribute. 201 if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) 202 return ConcreteT(); 203 return UniquerT::template get<ConcreteT>(ctx, args...); 204 } 205 206 /// Get an instance of the concrete type from a void pointer. 207 static ConcreteT getFromOpaquePointer(const void *ptr) { 208 return ConcreteT((const typename BaseT::ImplType *)ptr); 209 } 210 211 /// Utility for easy access to the storage instance. 212 ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); } 213 214 protected: 215 /// Mutate the current storage instance. This will not change the unique key. 216 /// The arguments are forwarded to 'ConcreteT::mutate'. 217 template <typename... Args> 218 LogicalResult mutate(Args &&...args) { 219 static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>, 220 ConcreteT>::value, 221 "The `mutate` function expects mutable trait " 222 "(e.g. TypeTrait::IsMutable) to be attached on parent."); 223 return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(), 224 std::forward<Args>(args)...); 225 } 226 227 /// Default implementation that just returns success. 228 template <typename... Args> 229 static LogicalResult verifyInvariants(Args... args) { 230 return success(); 231 } 232 233 private: 234 /// Trait to check if T provides a 'ConcreteEntity' type alias. 235 template <typename T> 236 using has_concrete_entity_t = typename T::ConcreteEntity; 237 238 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to 239 /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on 240 /// the missing typedef. Useful for checking if the interface is targeting the 241 /// right class. 242 template <typename T, 243 bool = llvm::is_detected<has_concrete_entity_t, T>::value> 244 struct IfaceTargetOrConcreteT { 245 using type = typename T::ConcreteEntity; 246 }; 247 template <typename T> 248 struct IfaceTargetOrConcreteT<T, false> { 249 using type = ConcreteT; 250 }; 251 252 /// A hook for static assertion that the external interface model T is 253 /// targeting a base class of the concrete attribute/type. The model can also 254 /// be a fallback model that works for every attribute/type. 255 template <typename T> 256 static void checkInterfaceTarget() { 257 static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type, 258 ConcreteT>::value, 259 "attaching an interface to the wrong attribute/type kind"); 260 } 261 }; 262 } // namespace detail 263 } // namespace mlir 264 265 #endif 266