1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_DIALECTINTERFACE_H 10 #define MLIR_IR_DIALECTINTERFACE_H 11 12 #include "mlir/Support/TypeID.h" 13 #include "llvm/ADT/DenseSet.h" 14 #include "llvm/ADT/STLExtras.h" 15 #include <vector> 16 17 namespace mlir { 18 class Dialect; 19 class MLIRContext; 20 class Operation; 21 22 //===----------------------------------------------------------------------===// 23 // DialectInterface 24 //===----------------------------------------------------------------------===// 25 namespace detail { 26 /// The base class used for all derived interface types. This class provides 27 /// utilities necessary for registration. 28 template <typename ConcreteType, typename BaseT> 29 class DialectInterfaceBase : public BaseT { 30 public: 31 using Base = DialectInterfaceBase<ConcreteType, BaseT>; 32 33 /// Get a unique id for the derived interface type. 34 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 35 36 protected: 37 DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} 38 }; 39 } // namespace detail 40 41 /// This class represents an interface overridden for a single dialect. 42 class DialectInterface { 43 public: 44 virtual ~DialectInterface(); 45 46 /// The base class used for all derived interface types. This class provides 47 /// utilities necessary for registration. 48 template <typename ConcreteType> 49 using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; 50 51 /// Return the dialect that this interface represents. 52 Dialect *getDialect() const { return dialect; } 53 54 /// Return the context that holds the parent dialect of this interface. 55 MLIRContext *getContext() const; 56 57 /// Return the derived interface id. 58 TypeID getID() const { return interfaceID; } 59 60 protected: 61 DialectInterface(Dialect *dialect, TypeID id) 62 : dialect(dialect), interfaceID(id) {} 63 64 private: 65 /// The dialect that represents this interface. 66 Dialect *dialect; 67 68 /// The unique identifier for the derived interface type. 69 TypeID interfaceID; 70 }; 71 72 //===----------------------------------------------------------------------===// 73 // DialectInterfaceCollection 74 //===----------------------------------------------------------------------===// 75 76 namespace detail { 77 /// This class is the base class for a collection of instances for a specific 78 /// interface kind. 79 class DialectInterfaceCollectionBase { 80 /// DenseMap info for dialect interfaces that allows lookup by the dialect. 81 struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { 82 using DenseMapInfo<const DialectInterface *>::isEqual; 83 84 static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } 85 static unsigned getHashValue(const DialectInterface *key) { 86 return getHashValue(key->getDialect()); 87 } 88 89 static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { 90 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 91 return false; 92 return lhs == rhs->getDialect(); 93 } 94 }; 95 96 /// A set of registered dialect interface instances. 97 using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; 98 using InterfaceVectorT = std::vector<const DialectInterface *>; 99 100 public: 101 DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, 102 StringRef interfaceName); 103 virtual ~DialectInterfaceCollectionBase(); 104 105 protected: 106 /// Get the interface for the dialect of given operation, or null if one 107 /// is not registered. 108 const DialectInterface *getInterfaceFor(Operation *op) const; 109 110 /// Get the interface for the given dialect. 111 const DialectInterface *getInterfaceFor(Dialect *dialect) const { 112 auto it = interfaces.find_as(dialect); 113 return it == interfaces.end() ? nullptr : *it; 114 } 115 116 /// An iterator class that iterates the held interface objects of the given 117 /// derived interface type. 118 template <typename InterfaceT> 119 struct iterator 120 : public llvm::mapped_iterator_base<iterator<InterfaceT>, 121 InterfaceVectorT::const_iterator, 122 const InterfaceT &> { 123 using llvm::mapped_iterator_base<iterator<InterfaceT>, 124 InterfaceVectorT::const_iterator, 125 const InterfaceT &>::mapped_iterator_base; 126 127 /// Map the element to the iterator result type. 128 const InterfaceT &mapElement(const DialectInterface *interface) const { 129 return *static_cast<const InterfaceT *>(interface); 130 } 131 }; 132 133 /// Iterator access to the held interfaces. 134 template <typename InterfaceT> 135 iterator<InterfaceT> interface_begin() const { 136 return iterator<InterfaceT>(orderedInterfaces.begin()); 137 } 138 template <typename InterfaceT> 139 iterator<InterfaceT> interface_end() const { 140 return iterator<InterfaceT>(orderedInterfaces.end()); 141 } 142 143 private: 144 /// A set of registered dialect interface instances. 145 InterfaceSetT interfaces; 146 /// An ordered list of the registered interface instances, necessary for 147 /// deterministic iteration. 148 // NOTE: SetVector does not provide find access, so it can't be used here. 149 InterfaceVectorT orderedInterfaces; 150 }; 151 } // namespace detail 152 153 /// A collection of dialect interfaces within a context, for a given concrete 154 /// interface type. 155 template <typename InterfaceType> 156 class DialectInterfaceCollection 157 : public detail::DialectInterfaceCollectionBase { 158 public: 159 using Base = DialectInterfaceCollection<InterfaceType>; 160 161 /// Collect the registered dialect interfaces within the provided context. 162 DialectInterfaceCollection(MLIRContext *ctx) 163 : detail::DialectInterfaceCollectionBase( 164 ctx, InterfaceType::getInterfaceID(), 165 llvm::getTypeName<InterfaceType>()) {} 166 167 /// Get the interface for a given object, or null if one is not registered. 168 /// The object may be a dialect or an operation instance. 169 template <typename Object> 170 const InterfaceType *getInterfaceFor(Object *obj) const { 171 return static_cast<const InterfaceType *>( 172 detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); 173 } 174 175 /// Iterator access to the held interfaces. 176 using iterator = 177 detail::DialectInterfaceCollectionBase::iterator<InterfaceType>; 178 iterator begin() const { return interface_begin<InterfaceType>(); } 179 iterator end() const { return interface_end<InterfaceType>(); } 180 181 private: 182 using detail::DialectInterfaceCollectionBase::interface_begin; 183 using detail::DialectInterfaceCollectionBase::interface_end; 184 }; 185 186 } // namespace mlir 187 188 #endif 189