xref: /llvm-project/mlir/include/mlir/IR/DialectInterface.h (revision cb3bc5be9c20d893adf94cdf436092657ab5ab40)
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include <vector>
16 
17 namespace mlir {
18 class Dialect;
19 class MLIRContext;
20 class Operation;
21 
22 //===----------------------------------------------------------------------===//
23 // DialectInterface
24 //===----------------------------------------------------------------------===//
25 namespace detail {
26 /// The base class used for all derived interface types. This class provides
27 /// utilities necessary for registration.
28 template <typename ConcreteType, typename BaseT>
29 class DialectInterfaceBase : public BaseT {
30 public:
31   using Base = DialectInterfaceBase<ConcreteType, BaseT>;
32 
33   /// Get a unique id for the derived interface type.
34   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
35 
36 protected:
37   DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
38 };
39 } // namespace detail
40 
41 /// This class represents an interface overridden for a single dialect.
42 class DialectInterface {
43 public:
44   virtual ~DialectInterface();
45 
46   /// The base class used for all derived interface types. This class provides
47   /// utilities necessary for registration.
48   template <typename ConcreteType>
49   using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;
50 
51   /// Return the dialect that this interface represents.
52   Dialect *getDialect() const { return dialect; }
53 
54   /// Return the context that holds the parent dialect of this interface.
55   MLIRContext *getContext() const;
56 
57   /// Return the derived interface id.
58   TypeID getID() const { return interfaceID; }
59 
60 protected:
61   DialectInterface(Dialect *dialect, TypeID id)
62       : dialect(dialect), interfaceID(id) {}
63 
64 private:
65   /// The dialect that represents this interface.
66   Dialect *dialect;
67 
68   /// The unique identifier for the derived interface type.
69   TypeID interfaceID;
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // DialectInterfaceCollection
74 //===----------------------------------------------------------------------===//
75 
76 namespace detail {
77 /// This class is the base class for a collection of instances for a specific
78 /// interface kind.
79 class DialectInterfaceCollectionBase {
80   /// DenseMap info for dialect interfaces that allows lookup by the dialect.
81   struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
82     using DenseMapInfo<const DialectInterface *>::isEqual;
83 
84     static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
85     static unsigned getHashValue(const DialectInterface *key) {
86       return getHashValue(key->getDialect());
87     }
88 
89     static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
90       if (rhs == getEmptyKey() || rhs == getTombstoneKey())
91         return false;
92       return lhs == rhs->getDialect();
93     }
94   };
95 
96   /// A set of registered dialect interface instances.
97   using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
98   using InterfaceVectorT = std::vector<const DialectInterface *>;
99 
100 public:
101   DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind,
102                                  StringRef interfaceName);
103   virtual ~DialectInterfaceCollectionBase();
104 
105 protected:
106   /// Get the interface for the dialect of given operation, or null if one
107   /// is not registered.
108   const DialectInterface *getInterfaceFor(Operation *op) const;
109 
110   /// Get the interface for the given dialect.
111   const DialectInterface *getInterfaceFor(Dialect *dialect) const {
112     auto it = interfaces.find_as(dialect);
113     return it == interfaces.end() ? nullptr : *it;
114   }
115 
116   /// An iterator class that iterates the held interface objects of the given
117   /// derived interface type.
118   template <typename InterfaceT>
119   struct iterator
120       : public llvm::mapped_iterator_base<iterator<InterfaceT>,
121                                           InterfaceVectorT::const_iterator,
122                                           const InterfaceT &> {
123     using llvm::mapped_iterator_base<iterator<InterfaceT>,
124                                      InterfaceVectorT::const_iterator,
125                                      const InterfaceT &>::mapped_iterator_base;
126 
127     /// Map the element to the iterator result type.
128     const InterfaceT &mapElement(const DialectInterface *interface) const {
129       return *static_cast<const InterfaceT *>(interface);
130     }
131   };
132 
133   /// Iterator access to the held interfaces.
134   template <typename InterfaceT>
135   iterator<InterfaceT> interface_begin() const {
136     return iterator<InterfaceT>(orderedInterfaces.begin());
137   }
138   template <typename InterfaceT>
139   iterator<InterfaceT> interface_end() const {
140     return iterator<InterfaceT>(orderedInterfaces.end());
141   }
142 
143 private:
144   /// A set of registered dialect interface instances.
145   InterfaceSetT interfaces;
146   /// An ordered list of the registered interface instances, necessary for
147   /// deterministic iteration.
148   // NOTE: SetVector does not provide find access, so it can't be used here.
149   InterfaceVectorT orderedInterfaces;
150 };
151 } // namespace detail
152 
153 /// A collection of dialect interfaces within a context, for a given concrete
154 /// interface type.
155 template <typename InterfaceType>
156 class DialectInterfaceCollection
157     : public detail::DialectInterfaceCollectionBase {
158 public:
159   using Base = DialectInterfaceCollection<InterfaceType>;
160 
161   /// Collect the registered dialect interfaces within the provided context.
162   DialectInterfaceCollection(MLIRContext *ctx)
163       : detail::DialectInterfaceCollectionBase(
164             ctx, InterfaceType::getInterfaceID(),
165             llvm::getTypeName<InterfaceType>()) {}
166 
167   /// Get the interface for a given object, or null if one is not registered.
168   /// The object may be a dialect or an operation instance.
169   template <typename Object>
170   const InterfaceType *getInterfaceFor(Object *obj) const {
171     return static_cast<const InterfaceType *>(
172         detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
173   }
174 
175   /// Iterator access to the held interfaces.
176   using iterator =
177       detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
178   iterator begin() const { return interface_begin<InterfaceType>(); }
179   iterator end() const { return interface_end<InterfaceType>(); }
180 
181 private:
182   using detail::DialectInterfaceCollectionBase::interface_begin;
183   using detail::DialectInterfaceCollectionBase::interface_end;
184 };
185 
186 } // namespace mlir
187 
188 #endif
189