1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several support classes for defining interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H 14 #define MLIR_SUPPORT_INTERFACESUPPORT_H 15 16 #include "mlir/Support/TypeID.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/TypeName.h" 20 21 namespace mlir { 22 namespace detail { 23 //===----------------------------------------------------------------------===// 24 // Interface 25 //===----------------------------------------------------------------------===// 26 27 /// This class represents an abstract interface. An interface is a simplified 28 /// mechanism for attaching concept based polymorphism to a class hierarchy. An 29 /// interface is comprised of two components: 30 /// * The derived interface class: This is what users interact with, and invoke 31 /// methods on. 32 /// * An interface `Trait` class: This is the class that is attached to the 33 /// object implementing the interface. It is the mechanism with which models 34 /// are specialized. 35 /// 36 /// Derived interfaces types must provide the following template types: 37 /// * ConcreteType: The CRTP derived type. 38 /// * ValueT: The opaque type the derived interface operates on. For example 39 /// `Operation*` for operation interfaces, or `Attribute` for 40 /// attribute interfaces. 41 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model' 42 /// class. The 'Concept' class defines an abstract virtual interface, 43 /// where as the 'Model' class implements this interface for a 44 /// specific derived T type. Both of these classes *must* not contain 45 /// non-static data. A simple example is shown below: 46 /// 47 /// ```c++ 48 /// struct ExampleInterfaceTraits { 49 /// struct Concept { 50 /// virtual unsigned getNumInputs(T t) const = 0; 51 /// }; 52 /// template <typename DerivedT> class Model { 53 /// unsigned getNumInputs(T t) const final { 54 /// return cast<DerivedT>(t).getNumInputs(); 55 /// } 56 /// }; 57 /// }; 58 /// ``` 59 /// 60 /// * BaseType: A desired base type for the interface. This is a class 61 /// that provides specific functionality for the `ValueT` 62 /// value. For instance the specific `Op` that will wrap the 63 /// `Operation*` for an `OpInterface`. 64 /// * BaseTrait: The base type for the interface trait. This is the base class 65 /// to use for the interface trait that will be attached to each 66 /// instance of `ValueT` that implements this interface. 67 /// 68 template <typename ConcreteType, typename ValueT, typename Traits, 69 typename BaseType, 70 template <typename, template <typename> class> class BaseTrait> 71 class Interface : public BaseType { 72 public: 73 using Concept = typename Traits::Concept; 74 template <typename T> 75 using Model = typename Traits::template Model<T>; 76 template <typename T> 77 using FallbackModel = typename Traits::template FallbackModel<T>; 78 using InterfaceBase = 79 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>; 80 template <typename T, typename U> 81 using ExternalModel = typename Traits::template ExternalModel<T, U>; 82 using ValueType = ValueT; 83 84 /// This is a special trait that registers a given interface with an object. 85 template <typename ConcreteT> 86 struct Trait : public BaseTrait<ConcreteT, Trait> { 87 using ModelT = Model<ConcreteT>; 88 89 /// Define an accessor for the ID of this interface. getInterfaceIDTrait90 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 91 }; 92 93 /// Construct an interface from an instance of the value type. 94 explicit Interface(ValueT t = ValueT()) BaseType(t)95 : BaseType(t), 96 conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 97 assert((!t || conceptImpl) && 98 "expected value to provide interface instance"); 99 } Interface(std::nullptr_t)100 Interface(std::nullptr_t) : BaseType(ValueT()), conceptImpl(nullptr) {} 101 102 /// Construct an interface instance from a type that implements this 103 /// interface's trait. 104 template <typename T, 105 std::enable_if_t<std::is_base_of<Trait<T>, T>::value> * = nullptr> Interface(T t)106 Interface(T t) 107 : BaseType(t), 108 conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 109 assert((!t || conceptImpl) && 110 "expected value to provide interface instance"); 111 } 112 113 /// Constructor for a known concept. Interface(ValueT t,const Concept * conceptImpl)114 Interface(ValueT t, const Concept *conceptImpl) 115 : BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) { 116 assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl); 117 } 118 119 /// Constructor for DenseMapInfo's empty key and tombstone key. Interface(ValueT t,std::nullptr_t)120 Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {} 121 122 /// Support 'classof' by checking if the given object defines the concrete 123 /// interface. classof(ValueT t)124 static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } 125 126 /// Define an accessor for the ID of this interface. getInterfaceID()127 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 128 129 protected: 130 /// Get the raw concept in the correct derived concept type. getImpl()131 const Concept *getImpl() const { return conceptImpl; } getImpl()132 Concept *getImpl() { return conceptImpl; } 133 134 private: 135 /// A pointer to the impl concept object. 136 Concept *conceptImpl; 137 }; 138 139 //===----------------------------------------------------------------------===// 140 // InterfaceMap 141 //===----------------------------------------------------------------------===// 142 143 /// Template utility that computes the number of elements within `T` that 144 /// satisfy the given predicate. 145 template <template <class> class Pred, size_t N, typename... Ts> 146 struct count_if_t_impl : public std::integral_constant<size_t, N> {}; 147 template <template <class> class Pred, size_t N, typename T, typename... Us> 148 struct count_if_t_impl<Pred, N, T, Us...> 149 : public std::integral_constant< 150 size_t, 151 count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {}; 152 template <template <class> class Pred, typename... Ts> 153 using count_if_t = count_if_t_impl<Pred, 0, Ts...>; 154 155 /// This class provides an efficient mapping between a given `Interface` type, 156 /// and a particular implementation of its concept. 157 class InterfaceMap { 158 /// Trait to check if T provides a static 'getInterfaceID' method. 159 template <typename T, typename... Args> 160 using has_get_interface_id = decltype(T::getInterfaceID()); 161 template <typename T> 162 using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>; 163 template <typename... Types> 164 using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>; 165 166 /// Trait to check if T provides a 'initializeInterfaceConcept' method. 167 template <typename T, typename... Args> 168 using has_initialize_method = 169 decltype(std::declval<T>().initializeInterfaceConcept( 170 std::declval<InterfaceMap &>())); 171 template <typename T> 172 using detect_initialize_method = llvm::is_detected<has_initialize_method, T>; 173 174 public: 175 InterfaceMap() = default; 176 InterfaceMap(InterfaceMap &&) = default; 177 InterfaceMap &operator=(InterfaceMap &&rhs) { 178 for (auto &it : interfaces) 179 free(it.second); 180 interfaces = std::move(rhs.interfaces); 181 return *this; 182 } 183 ~InterfaceMap() { 184 for (auto &it : interfaces) 185 free(it.second); 186 } 187 188 /// Construct an InterfaceMap with the given set of template types. For 189 /// convenience given that object trait lists may contain other non-interface 190 /// types, not all of the types need to be interfaces. The provided types that 191 /// do not represent interfaces are not added to the interface map. 192 template <typename... Types> 193 static InterfaceMap get() { 194 constexpr size_t numInterfaces = num_interface_types_t<Types...>::value; 195 if constexpr (numInterfaces == 0) 196 return InterfaceMap(); 197 198 InterfaceMap map; 199 (map.insertPotentialInterface<Types>(), ...); 200 return map; 201 } 202 203 /// Returns an instance of the concept object for the given interface if it 204 /// was registered to this map, null otherwise. 205 template <typename T> 206 typename T::Concept *lookup() const { 207 return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID())); 208 } 209 210 /// Returns true if the interface map contains an interface for the given id. 211 bool contains(TypeID interfaceID) const { return lookup(interfaceID); } 212 213 /// Insert the given interface models. 214 template <typename... IfaceModels> 215 void insertModels() { 216 (insertModel<IfaceModels>(), ...); 217 } 218 219 private: 220 /// Insert the given interface type into the map, ignoring it if it doesn't 221 /// actually represent an interface. 222 template <typename T> 223 inline void insertPotentialInterface() { 224 if constexpr (detect_get_interface_id<T>::value) 225 insertModel<typename T::ModelT>(); 226 } 227 228 /// Insert the given interface model into the map. 229 template <typename InterfaceModel> 230 void insertModel() { 231 // FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements 232 // interfaces in a way that isn't clean/compatible. 233 // static_assert(std::is_trivially_destructible_v<InterfaceModel>, 234 // "interface models must be trivially destructible"); 235 236 // Build the interface model, optionally initializing if necessary. 237 InterfaceModel *model = 238 new (malloc(sizeof(InterfaceModel))) InterfaceModel(); 239 if constexpr (detect_initialize_method<InterfaceModel>::value) 240 model->initializeInterfaceConcept(*this); 241 242 insert(InterfaceModel::Interface::getInterfaceID(), model); 243 } 244 /// Insert the given set of interface id and concept implementation into the 245 /// interface map. 246 void insert(TypeID interfaceId, void *conceptImpl); 247 248 /// Compare two TypeID instances by comparing the underlying pointer. 249 static bool compare(TypeID lhs, TypeID rhs) { 250 return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); 251 } 252 253 /// Returns an instance of the concept object for the given interface id if it 254 /// was registered to this map, null otherwise. 255 void *lookup(TypeID id) const { 256 const auto *it = 257 llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) { 258 return compare(it.first, id); 259 }); 260 return (it != interfaces.end() && it->first == id) ? it->second : nullptr; 261 } 262 263 /// A list of interface instances, sorted by TypeID. 264 SmallVector<std::pair<TypeID, void *>> interfaces; 265 }; 266 267 template <typename ConcreteType, typename ValueT, typename Traits, 268 typename BaseType, 269 template <typename, template <typename> class> class BaseTrait> 270 void isInterfaceImpl( 271 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &); 272 273 template <typename T> 274 using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>())); 275 276 template <typename T> 277 using IsInterface = llvm::is_detected<is_interface_t, T>; 278 279 } // namespace detail 280 } // namespace mlir 281 282 namespace llvm { 283 284 template <typename T> 285 struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> { 286 using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>; 287 288 static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); } 289 290 static T getTombstoneKey() { 291 return T(ValueTypeInfo::getTombstoneKey(), nullptr); 292 } 293 294 static unsigned getHashValue(T val) { 295 return ValueTypeInfo::getHashValue(val); 296 } 297 298 static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); } 299 }; 300 301 } // namespace llvm 302 303 #endif 304