xref: /llvm-project/mlir/include/mlir/IR/StorageUniquerSupport.h (revision a9636b7f60f283926c66e96c036f5b5d9e57c026)
1 //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utility classes for interfacing with StorageUniquer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
14 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
15 
16 #include "mlir/IR/AttrTypeSubElements.h"
17 #include "mlir/IR/DialectRegistry.h"
18 #include "mlir/Support/InterfaceSupport.h"
19 #include "mlir/Support/StorageUniquer.h"
20 #include "mlir/Support/TypeID.h"
21 #include "llvm/ADT/FunctionExtras.h"
22 
23 namespace mlir {
24 class InFlightDiagnostic;
25 class Location;
26 class MLIRContext;
27 
28 namespace detail {
29 /// Utility method to generate a callback that can be used to generate a
30 /// diagnostic when checking the construction invariants of a storage object.
31 /// This is defined out-of-line to avoid the need to include Location.h.
32 llvm::unique_function<InFlightDiagnostic()>
33 getDefaultDiagnosticEmitFn(MLIRContext *ctx);
34 llvm::unique_function<InFlightDiagnostic()>
35 getDefaultDiagnosticEmitFn(const Location &loc);
36 
37 //===----------------------------------------------------------------------===//
38 // StorageUserTraitBase
39 //===----------------------------------------------------------------------===//
40 
/// Common base for the traits that storage classes mix into their concrete
/// user types. Only derived traits are meant to use it, which is why its sole
/// member is protected rather than public.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Recover the concrete object this trait has been mixed into.
  ConcreteType getInstance() const {
    // The concrete type inherits from several content-free trait bases, so
    // the downcast is routed through the specific TraitType<ConcreteType>
    // base first; this gives the compiler one unambiguous conversion path.
    const TraitType<ConcreteType> &asTrait =
        static_cast<const TraitType<ConcreteType> &>(*this);
    const ConcreteType &asConcrete = static_cast<const ConcreteType &>(asTrait);
    return asConcrete;
  }
};
56 
57 namespace StorageUserTrait {
58 /// This trait is used to determine if a storage user, like Type, is mutable
59 /// or not. A storage user is mutable if ImplType of the derived class defines
60 /// a `mutate` function with a proper signature. Note that this trait is not
61 /// supposed to be used publicly. Users should use alias names like
62 /// `TypeTrait::IsMutable` instead.
63 template <typename ConcreteType>
64 struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
65 } // namespace StorageUserTrait
66 
67 //===----------------------------------------------------------------------===//
68 // StorageUserBase
69 //===----------------------------------------------------------------------===//
70 
71 namespace storage_user_base_impl {
72 /// Returns true if this given Trait ID matches the IDs of any of the provided
73 /// trait types `Traits`.
74 template <template <typename T> class... Traits>
75 bool hasTrait(TypeID traitID) {
76   TypeID traitIDs[] = {TypeID::get<Traits>()...};
77   for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
78     if (traitIDs[i] == traitID)
79       return true;
80   return false;
81 }
82 
83 // We specialize for the empty case to not define an empty array.
84 template <>
85 inline bool hasTrait(TypeID traitID) {
86   return false;
87 }
88 } // namespace storage_user_base_impl
89 
90 /// Utility class for implementing users of storage classes uniqued by a
91 /// StorageUniquer. Clients are not expected to interact with this class
92 /// directly.
93 template <typename ConcreteT, typename BaseT, typename StorageT,
94           typename UniquerT, template <typename T> class... Traits>
95 class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
96 public:
97   using BaseT::BaseT;
98 
99   /// Utility declarations for the concrete attribute class.
100   using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
101   using ImplType = StorageT;
102   using HasTraitFn = bool (*)(TypeID);
103 
104   /// Return a unique identifier for the concrete type.
105   static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
106 
107   /// Provide an implementation of 'classof' that compares the type id of the
108   /// provided value with that of the concrete type.
109   template <typename T>
110   static bool classof(T val) {
111     static_assert(std::is_convertible<ConcreteT, T>::value,
112                   "casting from a non-convertible type");
113     return val.getTypeID() == getTypeID();
114   }
115 
116   /// Returns an interface map for the interfaces registered to this storage
117   /// user. This should not be used directly.
118   static detail::InterfaceMap getInterfaceMap() {
119     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
120   }
121 
122   /// Returns the function that returns true if the given Trait ID matches the
123   /// IDs of any of the traits defined by the storage user.
124   static HasTraitFn getHasTraitFn() {
125     return [](TypeID id) {
126       return storage_user_base_impl::hasTrait<Traits...>(id);
127     };
128   }
129 
130   /// Returns a function that walks immediate sub elements of a given instance
131   /// of the storage user.
132   static auto getWalkImmediateSubElementsFn() {
133     return [](auto instance, function_ref<void(Attribute)> walkAttrsFn,
134               function_ref<void(Type)> walkTypesFn) {
135       ::mlir::detail::walkImmediateSubElementsImpl(
136           llvm::cast<ConcreteT>(instance), walkAttrsFn, walkTypesFn);
137     };
138   }
139 
140   /// Returns a function that replaces immediate sub elements of a given
141   /// instance of the storage user.
142   static auto getReplaceImmediateSubElementsFn() {
143     return [](auto instance, ArrayRef<Attribute> replAttrs,
144               ArrayRef<Type> replTypes) {
145       return ::mlir::detail::replaceImmediateSubElementsImpl(
146           llvm::cast<ConcreteT>(instance), replAttrs, replTypes);
147     };
148   }
149 
150   /// Attach the given models as implementations of the corresponding interfaces
151   /// for the concrete storage user class. The type must be registered with the
152   /// context, i.e. the dialect to which the type belongs must be loaded. The
153   /// call will abort otherwise.
154   template <typename... IfaceModels>
155   static void attachInterface(MLIRContext &context) {
156     typename ConcreteT::AbstractTy *abstract =
157         ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(),
158                                              &context);
159     if (!abstract)
160       llvm::report_fatal_error("Registering an interface for an attribute/type "
161                                "that is not itself registered.");
162 
163     // Handle the case where the models resolve a promised interface.
164     (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
165          abstract->getDialect(), abstract->getTypeID(),
166          IfaceModels::Interface::getInterfaceID()),
167      ...);
168 
169     (checkInterfaceTarget<IfaceModels>(), ...);
170     abstract->interfaceMap.template insertModels<IfaceModels...>();
171   }
172 
173   /// Get or create a new ConcreteT instance within the ctx. This
174   /// function is guaranteed to return a non null object and will assert if
175   /// the arguments provided are invalid.
176   template <typename... Args>
177   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
178     // Ensure that the invariants are correct for construction.
179     assert(succeeded(
180         ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
181     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
182   }
183 
184   /// Get or create a new ConcreteT instance within the ctx, defined at
185   /// the given, potentially unknown, location. If the arguments provided are
186   /// invalid, errors are emitted using the provided location and a null object
187   /// is returned.
188   template <typename... Args>
189   static ConcreteT getChecked(const Location &loc, Args &&...args) {
190     return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
191                                  std::forward<Args>(args)...);
192   }
193 
194   /// Get or create a new ConcreteT instance within the ctx. If the arguments
195   /// provided are invalid, errors are emitted using the provided `emitError`
196   /// and a null object is returned.
197   template <typename... Args>
198   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
199                               MLIRContext *ctx, Args... args) {
200     // If the construction invariants fail then we return a null attribute.
201     if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
202       return ConcreteT();
203     return UniquerT::template get<ConcreteT>(ctx, args...);
204   }
205 
206   /// Get an instance of the concrete type from a void pointer.
207   static ConcreteT getFromOpaquePointer(const void *ptr) {
208     return ConcreteT((const typename BaseT::ImplType *)ptr);
209   }
210 
211   /// Utility for easy access to the storage instance.
212   ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
213 
214 protected:
215   /// Mutate the current storage instance. This will not change the unique key.
216   /// The arguments are forwarded to 'ConcreteT::mutate'.
217   template <typename... Args>
218   LogicalResult mutate(Args &&...args) {
219     static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
220                                   ConcreteT>::value,
221                   "The `mutate` function expects mutable trait "
222                   "(e.g. TypeTrait::IsMutable) to be attached on parent.");
223     return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
224                                                 std::forward<Args>(args)...);
225   }
226 
227   /// Default implementation that just returns success.
228   template <typename... Args>
229   static LogicalResult verifyInvariants(Args... args) {
230     return success();
231   }
232 
233 private:
234   /// Trait to check if T provides a 'ConcreteEntity' type alias.
235   template <typename T>
236   using has_concrete_entity_t = typename T::ConcreteEntity;
237 
238   /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
239   /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on
240   /// the missing typedef. Useful for checking if the interface is targeting the
241   /// right class.
242   template <typename T,
243             bool = llvm::is_detected<has_concrete_entity_t, T>::value>
244   struct IfaceTargetOrConcreteT {
245     using type = typename T::ConcreteEntity;
246   };
247   template <typename T>
248   struct IfaceTargetOrConcreteT<T, false> {
249     using type = ConcreteT;
250   };
251 
252   /// A hook for static assertion that the external interface model T is
253   /// targeting a base class of the concrete attribute/type. The model can also
254   /// be a fallback model that works for every attribute/type.
255   template <typename T>
256   static void checkInterfaceTarget() {
257     static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type,
258                                   ConcreteT>::value,
259                   "attaching an interface to the wrong attribute/type kind");
260   }
261 };
262 } // namespace detail
263 } // namespace mlir
264 
265 #endif
266