xref: /llvm-project/mlir/include/mlir/Tools/PDLL/AST/Types.h (revision 7ac1fb01e9b70d09e6c4f39414bcd7c93787ef91)
1 //===- Types.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_
10 #define MLIR_TOOLS_PDLL_AST_TYPES_H_
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/StorageUniquer.h"
14 #include <optional>
15 
16 namespace mlir {
17 namespace pdll {
namespace ods {
// Forward declaration only: this header refers to ODS operations solely
// through `const ods::Operation *`, so the full definition is not needed.
class Operation;
} // namespace ods
21 
22 namespace ast {
// The AST context is only taken by reference in this header.
class Context;

namespace detail {
// Storage classes backing the concrete type classes declared below. Each one
// is also the key used for that class's TypeID (see Type::TypeBase::classof),
// and they are only declared here.
struct AttributeTypeStorage;
struct ConstraintTypeStorage;
struct OperationTypeStorage;
struct RangeTypeStorage;
struct RewriteTypeStorage;
struct TupleTypeStorage;
struct TypeTypeStorage;
struct ValueTypeStorage;
} // namespace detail
35 
36 //===----------------------------------------------------------------------===//
37 // Type
38 //===----------------------------------------------------------------------===//
39 
40 class Type {
41 public:
42   /// This class represents the internal storage of the Type class.
43   struct Storage;
44 
45   /// This class provides several utilities when defining derived type classes.
46   template <typename ImplT, typename BaseT = Type>
47   class TypeBase : public BaseT {
48   public:
49     using Base = TypeBase<ImplT, BaseT>;
50     using ImplTy = ImplT;
51     using BaseT::BaseT;
52 
53     /// Provide type casting support.
classof(Type type)54     static bool classof(Type type) {
55       return type.getTypeID() == TypeID::get<ImplTy>();
56     }
57   };
58 
impl(impl)59   Type(Storage *impl = nullptr) : impl(impl) {}
60 
61   bool operator==(const Type &other) const { return impl == other.impl; }
62   bool operator!=(const Type &other) const { return !(*this == other); }
63   explicit operator bool() const { return impl; }
64 
65   /// Provide type casting support.
66   template <typename U>
67   [[deprecated("Use mlir::isa<U>() instead")]]
isa()68   bool isa() const {
69     assert(impl && "isa<> used on a null type.");
70     return U::classof(*this);
71   }
72   template <typename U, typename V, typename... Others>
73   [[deprecated("Use mlir::isa<U>() instead")]]
isa()74   bool isa() const {
75     return isa<U>() || isa<V, Others...>();
76   }
77   template <typename U>
78   [[deprecated("Use mlir::dyn_cast<U>() instead")]]
dyn_cast()79   U dyn_cast() const {
80     return isa<U>() ? U(impl) : U(nullptr);
81   }
82   template <typename U>
83   [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
dyn_cast_or_null()84   U dyn_cast_or_null() const {
85     return (impl && isa<U>()) ? U(impl) : U(nullptr);
86   }
87   template <typename U>
88   [[deprecated("Use mlir::cast<U>() instead")]]
cast()89   U cast() const {
90     assert(isa<U>());
91     return U(impl);
92   }
93 
94   /// Return the internal storage instance of this type.
getImpl()95   Storage *getImpl() const { return impl; }
96 
97   /// Return the TypeID instance of this type.
98   TypeID getTypeID() const;
99 
100   /// Print this type to the given stream.
101   void print(raw_ostream &os) const;
102 
103   /// Try to refine this type with the one provided. Given two compatible types,
104   /// this will return a merged type contains as much detail from the two types.
105   /// For example, if refining two operation types and one contains a name,
106   /// while the other doesn't, the refined type contains the name. If the two
107   /// types are incompatible, null is returned.
108   Type refineWith(Type other) const;
109 
110 protected:
111   /// Return the internal storage instance of this type reinterpreted as the
112   /// given derived storage type.
113   template <typename T>
getImplAs()114   const T *getImplAs() const {
115     return static_cast<const T *>(impl);
116   }
117 
118 private:
119   Storage *impl;
120 };
121 
hash_value(Type type)122 inline llvm::hash_code hash_value(Type type) {
123   return DenseMapInfo<Type::Storage *>::getHashValue(type.getImpl());
124 }
125 
126 inline raw_ostream &operator<<(raw_ostream &os, Type type) {
127   type.print(os);
128   return os;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // AttributeType
133 //===----------------------------------------------------------------------===//
134 
135 /// This class represents a PDLL type that corresponds to an mlir::Attribute.
136 class AttributeType : public Type::TypeBase<detail::AttributeTypeStorage> {
137 public:
138   using Base::Base;
139 
140   /// Return an instance of the Attribute type.
141   static AttributeType get(Context &context);
142 };
143 
144 //===----------------------------------------------------------------------===//
145 // ConstraintType
146 //===----------------------------------------------------------------------===//
147 
148 /// This class represents a PDLL type that corresponds to a constraint. This
149 /// type has no MLIR C++ API correspondance.
150 class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
151 public:
152   using Base::Base;
153 
154   /// Return an instance of the Constraint type.
155   static ConstraintType get(Context &context);
156 };
157 
158 //===----------------------------------------------------------------------===//
159 // OperationType
160 //===----------------------------------------------------------------------===//
161 
162 /// This class represents a PDLL type that corresponds to an mlir::Operation.
163 class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
164 public:
165   using Base::Base;
166 
167   /// Return an instance of the Operation type with an optional operation name.
168   /// If no name is provided, this type may refer to any operation.
169   static OperationType get(Context &context,
170                            std::optional<StringRef> name = std::nullopt,
171                            const ods::Operation *odsOp = nullptr);
172 
173   /// Return the name of this operation type, or std::nullopt if it doesn't have
174   /// on.
175   std::optional<StringRef> getName() const;
176 
177   /// Return the ODS operation that this type refers to, or nullptr if the ODS
178   /// operation is unknown.
179   const ods::Operation *getODSOperation() const;
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // RangeType
184 //===----------------------------------------------------------------------===//
185 
186 /// This class represents a PDLL type that corresponds to a range of elements
187 /// with a given element type.
188 class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
189 public:
190   using Base::Base;
191 
192   /// Return an instance of the Range type with the given element type.
193   static RangeType get(Context &context, Type elementType);
194 
195   /// Return the element type of this range.
196   Type getElementType() const;
197 };
198 
199 //===----------------------------------------------------------------------===//
200 // TypeRangeType
201 
202 /// This class represents a PDLL type that corresponds to an mlir::TypeRange.
203 class TypeRangeType : public RangeType {
204 public:
205   using RangeType::RangeType;
206 
207   /// Provide type casting support.
208   static bool classof(Type type);
209 
210   /// Return an instance of the TypeRange type.
211   static TypeRangeType get(Context &context);
212 };
213 
214 //===----------------------------------------------------------------------===//
215 // ValueRangeType
216 
217 /// This class represents a PDLL type that corresponds to an mlir::ValueRange.
218 class ValueRangeType : public RangeType {
219 public:
220   using RangeType::RangeType;
221 
222   /// Provide type casting support.
223   static bool classof(Type type);
224 
225   /// Return an instance of the ValueRange type.
226   static ValueRangeType get(Context &context);
227 };
228 
229 //===----------------------------------------------------------------------===//
230 // RewriteType
231 //===----------------------------------------------------------------------===//
232 
233 /// This class represents a PDLL type that corresponds to a rewrite reference.
234 /// This type has no MLIR C++ API correspondance.
235 class RewriteType : public Type::TypeBase<detail::RewriteTypeStorage> {
236 public:
237   using Base::Base;
238 
239   /// Return an instance of the Rewrite type.
240   static RewriteType get(Context &context);
241 };
242 
243 //===----------------------------------------------------------------------===//
244 // TupleType
245 //===----------------------------------------------------------------------===//
246 
247 /// This class represents a PDLL tuple type, i.e. an ordered set of element
248 /// types with optional names.
249 class TupleType : public Type::TypeBase<detail::TupleTypeStorage> {
250 public:
251   using Base::Base;
252 
253   /// Return an instance of the Tuple type.
254   static TupleType get(Context &context, ArrayRef<Type> elementTypes,
255                        ArrayRef<StringRef> elementNames);
256   static TupleType get(Context &context,
257                        ArrayRef<Type> elementTypes = std::nullopt);
258 
259   /// Return the element types of this tuple.
260   ArrayRef<Type> getElementTypes() const;
261 
262   /// Return the element names of this tuple.
263   ArrayRef<StringRef> getElementNames() const;
264 
265   /// Return the number of elements within this tuple.
size()266   size_t size() const { return getElementTypes().size(); }
267 
268   /// Return if the tuple has no elements.
empty()269   bool empty() const { return size() == 0; }
270 };
271 
272 //===----------------------------------------------------------------------===//
273 // TypeType
274 //===----------------------------------------------------------------------===//
275 
276 /// This class represents a PDLL type that corresponds to an mlir::Type.
277 class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
278 public:
279   using Base::Base;
280 
281   /// Return an instance of the Type type.
282   static TypeType get(Context &context);
283 };
284 
285 //===----------------------------------------------------------------------===//
286 // ValueType
287 //===----------------------------------------------------------------------===//
288 
289 /// This class represents a PDLL type that corresponds to an mlir::Value.
290 class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
291 public:
292   using Base::Base;
293 
294   /// Return an instance of the Value type.
295   static ValueType get(Context &context);
296 };
297 
298 } // namespace ast
299 } // namespace pdll
300 } // namespace mlir
301 
302 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)303 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
304 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
305 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
306 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
307 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
308 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
309 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
310 
311 namespace llvm {
312 template <>
313 struct DenseMapInfo<mlir::pdll::ast::Type> {
314   static mlir::pdll::ast::Type getEmptyKey() {
315     void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
316     return mlir::pdll::ast::Type(
317         static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
318   }
319   static mlir::pdll::ast::Type getTombstoneKey() {
320     void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
321     return mlir::pdll::ast::Type(
322         static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
323   }
324   static unsigned getHashValue(mlir::pdll::ast::Type val) {
325     return llvm::hash_value(val.getImpl());
326   }
327   static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
328     return lhs == rhs;
329   }
330 };
331 
332 /// Add support for llvm style casts.
333 /// We provide a cast between To and From if From is mlir::pdll::ast::Type or
334 /// derives from it
335 template <typename To, typename From>
336 struct CastInfo<
337     To, From,
338     std::enable_if_t<
339         std::is_same_v<mlir::pdll::ast::Type, std::remove_const_t<From>> ||
340         std::is_base_of_v<mlir::pdll::ast::Type, From>>>
341     : NullableValueCastFailed<To>,
342       DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
343   static inline bool isPossible(mlir::pdll::ast::Type ty) {
344     /// Return a constant true instead of a dynamic true when casting to self or
345     /// up the hierarchy.
346     if constexpr (std::is_base_of_v<To, From>) {
347       return true;
348     } else {
349       return To::classof(ty);
350     };
351   }
352   static inline To doCast(mlir::pdll::ast::Type ty) { return To(ty.getImpl()); }
353 };
354 } // namespace llvm
355 
356 #endif // MLIR_TOOLS_PDLL_AST_TYPES_H_
357