xref: /llvm-project/mlir/include/mlir/Support/InterfaceSupport.h (revision 6089d612a580738df00c22a43e6f2c29bd216af9)
1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines several support classes for defining interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
14 #define MLIR_SUPPORT_INTERFACESUPPORT_H
15 
16 #include "mlir/Support/TypeID.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/TypeName.h"
20 
21 namespace mlir {
22 namespace detail {
23 //===----------------------------------------------------------------------===//
24 // Interface
25 //===----------------------------------------------------------------------===//
26 
27 /// This class represents an abstract interface. An interface is a simplified
28 /// mechanism for attaching concept based polymorphism to a class hierarchy. An
29 /// interface is comprised of two components:
30 /// * The derived interface class: This is what users interact with, and invoke
31 ///   methods on.
32 /// * An interface `Trait` class: This is the class that is attached to the
33 ///   object implementing the interface. It is the mechanism with which models
34 ///   are specialized.
35 ///
36 /// Derived interfaces types must provide the following template types:
37 /// * ConcreteType: The CRTP derived type.
38 /// * ValueT: The opaque type the derived interface operates on. For example
39 ///           `Operation*` for operation interfaces, or `Attribute` for
40 ///           attribute interfaces.
41 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
42 ///           class. The 'Concept' class defines an abstract virtual interface,
43 ///           where as the 'Model' class implements this interface for a
44 ///           specific derived T type. Both of these classes *must* not contain
45 ///           non-static data. A simple example is shown below:
46 ///
47 /// ```c++
48 ///    struct ExampleInterfaceTraits {
49 ///      struct Concept {
50 ///        virtual unsigned getNumInputs(T t) const = 0;
51 ///      };
52 ///      template <typename DerivedT> class Model {
53 ///        unsigned getNumInputs(T t) const final {
54 ///          return cast<DerivedT>(t).getNumInputs();
55 ///        }
56 ///      };
57 ///    };
58 /// ```
59 ///
60 /// * BaseType: A desired base type for the interface. This is a class
61 ///             that provides specific functionality for the `ValueT`
62 ///             value. For instance the specific `Op` that will wrap the
63 ///             `Operation*` for an `OpInterface`.
64 /// * BaseTrait: The base type for the interface trait. This is the base class
65 ///              to use for the interface trait that will be attached to each
66 ///              instance of `ValueT` that implements this interface.
67 ///
68 template <typename ConcreteType, typename ValueT, typename Traits,
69           typename BaseType,
70           template <typename, template <typename> class> class BaseTrait>
71 class Interface : public BaseType {
72 public:
73   using Concept = typename Traits::Concept;
74   template <typename T>
75   using Model = typename Traits::template Model<T>;
76   template <typename T>
77   using FallbackModel = typename Traits::template FallbackModel<T>;
78   using InterfaceBase =
79       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
80   template <typename T, typename U>
81   using ExternalModel = typename Traits::template ExternalModel<T, U>;
82   using ValueType = ValueT;
83 
84   /// This is a special trait that registers a given interface with an object.
85   template <typename ConcreteT>
86   struct Trait : public BaseTrait<ConcreteT, Trait> {
87     using ModelT = Model<ConcreteT>;
88 
89     /// Define an accessor for the ID of this interface.
getInterfaceIDTrait90     static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
91   };
92 
93   /// Construct an interface from an instance of the value type.
94   explicit Interface(ValueT t = ValueT())
BaseType(t)95       : BaseType(t),
96         conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
97     assert((!t || conceptImpl) &&
98            "expected value to provide interface instance");
99   }
Interface(std::nullptr_t)100   Interface(std::nullptr_t) : BaseType(ValueT()), conceptImpl(nullptr) {}
101 
102   /// Construct an interface instance from a type that implements this
103   /// interface's trait.
104   template <typename T,
105             std::enable_if_t<std::is_base_of<Trait<T>, T>::value> * = nullptr>
Interface(T t)106   Interface(T t)
107       : BaseType(t),
108         conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
109     assert((!t || conceptImpl) &&
110            "expected value to provide interface instance");
111   }
112 
113   /// Constructor for a known concept.
Interface(ValueT t,const Concept * conceptImpl)114   Interface(ValueT t, const Concept *conceptImpl)
115       : BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) {
116     assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
117   }
118 
119   /// Constructor for DenseMapInfo's empty key and tombstone key.
Interface(ValueT t,std::nullptr_t)120   Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
121 
122   /// Support 'classof' by checking if the given object defines the concrete
123   /// interface.
classof(ValueT t)124   static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
125 
126   /// Define an accessor for the ID of this interface.
getInterfaceID()127   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
128 
129 protected:
130   /// Get the raw concept in the correct derived concept type.
getImpl()131   const Concept *getImpl() const { return conceptImpl; }
getImpl()132   Concept *getImpl() { return conceptImpl; }
133 
134 private:
135   /// A pointer to the impl concept object.
136   Concept *conceptImpl;
137 };
138 
139 //===----------------------------------------------------------------------===//
140 // InterfaceMap
141 //===----------------------------------------------------------------------===//
142 
/// Template utility that computes the number of types within `Ts` that
/// satisfy the given predicate `Pred`, added to the starting count `N`.
///
/// Implemented with a C++17 fold expression, so the count is produced by a
/// single instantiation instead of one nested instantiation per type.
template <template <class> class Pred, size_t N, typename... Ts>
struct count_if_t_impl
    : public std::integral_constant<size_t,
                                    (N + ... + (Pred<Ts>::value ? 1 : 0))> {};
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
154 
155 /// This class provides an efficient mapping between a given `Interface` type,
156 /// and a particular implementation of its concept.
157 class InterfaceMap {
158   /// Trait to check if T provides a static 'getInterfaceID' method.
159   template <typename T, typename... Args>
160   using has_get_interface_id = decltype(T::getInterfaceID());
161   template <typename T>
162   using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
163   template <typename... Types>
164   using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
165 
166   /// Trait to check if T provides a 'initializeInterfaceConcept' method.
167   template <typename T, typename... Args>
168   using has_initialize_method =
169       decltype(std::declval<T>().initializeInterfaceConcept(
170           std::declval<InterfaceMap &>()));
171   template <typename T>
172   using detect_initialize_method = llvm::is_detected<has_initialize_method, T>;
173 
174 public:
175   InterfaceMap() = default;
176   InterfaceMap(InterfaceMap &&) = default;
177   InterfaceMap &operator=(InterfaceMap &&rhs) {
178     for (auto &it : interfaces)
179       free(it.second);
180     interfaces = std::move(rhs.interfaces);
181     return *this;
182   }
183   ~InterfaceMap() {
184     for (auto &it : interfaces)
185       free(it.second);
186   }
187 
188   /// Construct an InterfaceMap with the given set of template types. For
189   /// convenience given that object trait lists may contain other non-interface
190   /// types, not all of the types need to be interfaces. The provided types that
191   /// do not represent interfaces are not added to the interface map.
192   template <typename... Types>
193   static InterfaceMap get() {
194     constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
195     if constexpr (numInterfaces == 0)
196       return InterfaceMap();
197 
198     InterfaceMap map;
199     (map.insertPotentialInterface<Types>(), ...);
200     return map;
201   }
202 
203   /// Returns an instance of the concept object for the given interface if it
204   /// was registered to this map, null otherwise.
205   template <typename T>
206   typename T::Concept *lookup() const {
207     return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
208   }
209 
210   /// Returns true if the interface map contains an interface for the given id.
211   bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
212 
213   /// Insert the given interface models.
214   template <typename... IfaceModels>
215   void insertModels() {
216     (insertModel<IfaceModels>(), ...);
217   }
218 
219 private:
220   /// Insert the given interface type into the map, ignoring it if it doesn't
221   /// actually represent an interface.
222   template <typename T>
223   inline void insertPotentialInterface() {
224     if constexpr (detect_get_interface_id<T>::value)
225       insertModel<typename T::ModelT>();
226   }
227 
228   /// Insert the given interface model into the map.
229   template <typename InterfaceModel>
230   void insertModel() {
231     // FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements
232     // interfaces in a way that isn't clean/compatible.
233     // static_assert(std::is_trivially_destructible_v<InterfaceModel>,
234     //               "interface models must be trivially destructible");
235 
236     // Build the interface model, optionally initializing if necessary.
237     InterfaceModel *model =
238         new (malloc(sizeof(InterfaceModel))) InterfaceModel();
239     if constexpr (detect_initialize_method<InterfaceModel>::value)
240       model->initializeInterfaceConcept(*this);
241 
242     insert(InterfaceModel::Interface::getInterfaceID(), model);
243   }
244   /// Insert the given set of interface id and concept implementation into the
245   /// interface map.
246   void insert(TypeID interfaceId, void *conceptImpl);
247 
248   /// Compare two TypeID instances by comparing the underlying pointer.
249   static bool compare(TypeID lhs, TypeID rhs) {
250     return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
251   }
252 
253   /// Returns an instance of the concept object for the given interface id if it
254   /// was registered to this map, null otherwise.
255   void *lookup(TypeID id) const {
256     const auto *it =
257         llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
258           return compare(it.first, id);
259         });
260     return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
261   }
262 
263   /// A list of interface instances, sorted by TypeID.
264   SmallVector<std::pair<TypeID, void *>> interfaces;
265 };
266 
267 template <typename ConcreteType, typename ValueT, typename Traits,
268           typename BaseType,
269           template <typename, template <typename> class> class BaseTrait>
270 void isInterfaceImpl(
271     Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &);
272 
273 template <typename T>
274 using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>()));
275 
276 template <typename T>
277 using IsInterface = llvm::is_detected<is_interface_t, T>;
278 
279 } // namespace detail
280 } // namespace mlir
281 
282 namespace llvm {
283 
284 template <typename T>
285 struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> {
286   using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>;
287 
288   static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); }
289 
290   static T getTombstoneKey() {
291     return T(ValueTypeInfo::getTombstoneKey(), nullptr);
292   }
293 
294   static unsigned getHashValue(T val) {
295     return ValueTypeInfo::getHashValue(val);
296   }
297 
298   static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); }
299 };
300 
301 } // namespace llvm
302 
303 #endif
304