1 //===- Types.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_
10 #define MLIR_TOOLS_PDLL_AST_TYPES_H_
11
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/StorageUniquer.h"
14 #include <optional>
15
16 namespace mlir {
17 namespace pdll {
18 namespace ods {
19 class Operation;
20 } // namespace ods
21
22 namespace ast {
23 class Context;
24
25 namespace detail {
26 struct AttributeTypeStorage;
27 struct ConstraintTypeStorage;
28 struct OperationTypeStorage;
29 struct RangeTypeStorage;
30 struct RewriteTypeStorage;
31 struct TupleTypeStorage;
32 struct TypeTypeStorage;
33 struct ValueTypeStorage;
34 } // namespace detail
35
36 //===----------------------------------------------------------------------===//
37 // Type
38 //===----------------------------------------------------------------------===//
39
40 class Type {
41 public:
42 /// This class represents the internal storage of the Type class.
43 struct Storage;
44
45 /// This class provides several utilities when defining derived type classes.
46 template <typename ImplT, typename BaseT = Type>
47 class TypeBase : public BaseT {
48 public:
49 using Base = TypeBase<ImplT, BaseT>;
50 using ImplTy = ImplT;
51 using BaseT::BaseT;
52
53 /// Provide type casting support.
classof(Type type)54 static bool classof(Type type) {
55 return type.getTypeID() == TypeID::get<ImplTy>();
56 }
57 };
58
impl(impl)59 Type(Storage *impl = nullptr) : impl(impl) {}
60
61 bool operator==(const Type &other) const { return impl == other.impl; }
62 bool operator!=(const Type &other) const { return !(*this == other); }
63 explicit operator bool() const { return impl; }
64
65 /// Provide type casting support.
66 template <typename U>
67 [[deprecated("Use mlir::isa<U>() instead")]]
isa()68 bool isa() const {
69 assert(impl && "isa<> used on a null type.");
70 return U::classof(*this);
71 }
72 template <typename U, typename V, typename... Others>
73 [[deprecated("Use mlir::isa<U>() instead")]]
isa()74 bool isa() const {
75 return isa<U>() || isa<V, Others...>();
76 }
77 template <typename U>
78 [[deprecated("Use mlir::dyn_cast<U>() instead")]]
dyn_cast()79 U dyn_cast() const {
80 return isa<U>() ? U(impl) : U(nullptr);
81 }
82 template <typename U>
83 [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
dyn_cast_or_null()84 U dyn_cast_or_null() const {
85 return (impl && isa<U>()) ? U(impl) : U(nullptr);
86 }
87 template <typename U>
88 [[deprecated("Use mlir::cast<U>() instead")]]
cast()89 U cast() const {
90 assert(isa<U>());
91 return U(impl);
92 }
93
94 /// Return the internal storage instance of this type.
getImpl()95 Storage *getImpl() const { return impl; }
96
97 /// Return the TypeID instance of this type.
98 TypeID getTypeID() const;
99
100 /// Print this type to the given stream.
101 void print(raw_ostream &os) const;
102
103 /// Try to refine this type with the one provided. Given two compatible types,
104 /// this will return a merged type contains as much detail from the two types.
105 /// For example, if refining two operation types and one contains a name,
106 /// while the other doesn't, the refined type contains the name. If the two
107 /// types are incompatible, null is returned.
108 Type refineWith(Type other) const;
109
110 protected:
111 /// Return the internal storage instance of this type reinterpreted as the
112 /// given derived storage type.
113 template <typename T>
getImplAs()114 const T *getImplAs() const {
115 return static_cast<const T *>(impl);
116 }
117
118 private:
119 Storage *impl;
120 };
121
hash_value(Type type)122 inline llvm::hash_code hash_value(Type type) {
123 return DenseMapInfo<Type::Storage *>::getHashValue(type.getImpl());
124 }
125
126 inline raw_ostream &operator<<(raw_ostream &os, Type type) {
127 type.print(os);
128 return os;
129 }
130
131 //===----------------------------------------------------------------------===//
132 // AttributeType
133 //===----------------------------------------------------------------------===//
134
135 /// This class represents a PDLL type that corresponds to an mlir::Attribute.
136 class AttributeType : public Type::TypeBase<detail::AttributeTypeStorage> {
137 public:
138 using Base::Base;
139
140 /// Return an instance of the Attribute type.
141 static AttributeType get(Context &context);
142 };
143
144 //===----------------------------------------------------------------------===//
145 // ConstraintType
146 //===----------------------------------------------------------------------===//
147
148 /// This class represents a PDLL type that corresponds to a constraint. This
149 /// type has no MLIR C++ API correspondance.
150 class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
151 public:
152 using Base::Base;
153
154 /// Return an instance of the Constraint type.
155 static ConstraintType get(Context &context);
156 };
157
158 //===----------------------------------------------------------------------===//
159 // OperationType
160 //===----------------------------------------------------------------------===//
161
162 /// This class represents a PDLL type that corresponds to an mlir::Operation.
163 class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
164 public:
165 using Base::Base;
166
167 /// Return an instance of the Operation type with an optional operation name.
168 /// If no name is provided, this type may refer to any operation.
169 static OperationType get(Context &context,
170 std::optional<StringRef> name = std::nullopt,
171 const ods::Operation *odsOp = nullptr);
172
173 /// Return the name of this operation type, or std::nullopt if it doesn't have
174 /// on.
175 std::optional<StringRef> getName() const;
176
177 /// Return the ODS operation that this type refers to, or nullptr if the ODS
178 /// operation is unknown.
179 const ods::Operation *getODSOperation() const;
180 };
181
182 //===----------------------------------------------------------------------===//
183 // RangeType
184 //===----------------------------------------------------------------------===//
185
186 /// This class represents a PDLL type that corresponds to a range of elements
187 /// with a given element type.
188 class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
189 public:
190 using Base::Base;
191
192 /// Return an instance of the Range type with the given element type.
193 static RangeType get(Context &context, Type elementType);
194
195 /// Return the element type of this range.
196 Type getElementType() const;
197 };
198
199 //===----------------------------------------------------------------------===//
200 // TypeRangeType
201
202 /// This class represents a PDLL type that corresponds to an mlir::TypeRange.
203 class TypeRangeType : public RangeType {
204 public:
205 using RangeType::RangeType;
206
207 /// Provide type casting support.
208 static bool classof(Type type);
209
210 /// Return an instance of the TypeRange type.
211 static TypeRangeType get(Context &context);
212 };
213
214 //===----------------------------------------------------------------------===//
215 // ValueRangeType
216
217 /// This class represents a PDLL type that corresponds to an mlir::ValueRange.
218 class ValueRangeType : public RangeType {
219 public:
220 using RangeType::RangeType;
221
222 /// Provide type casting support.
223 static bool classof(Type type);
224
225 /// Return an instance of the ValueRange type.
226 static ValueRangeType get(Context &context);
227 };
228
229 //===----------------------------------------------------------------------===//
230 // RewriteType
231 //===----------------------------------------------------------------------===//
232
233 /// This class represents a PDLL type that corresponds to a rewrite reference.
234 /// This type has no MLIR C++ API correspondance.
235 class RewriteType : public Type::TypeBase<detail::RewriteTypeStorage> {
236 public:
237 using Base::Base;
238
239 /// Return an instance of the Rewrite type.
240 static RewriteType get(Context &context);
241 };
242
243 //===----------------------------------------------------------------------===//
244 // TupleType
245 //===----------------------------------------------------------------------===//
246
247 /// This class represents a PDLL tuple type, i.e. an ordered set of element
248 /// types with optional names.
249 class TupleType : public Type::TypeBase<detail::TupleTypeStorage> {
250 public:
251 using Base::Base;
252
253 /// Return an instance of the Tuple type.
254 static TupleType get(Context &context, ArrayRef<Type> elementTypes,
255 ArrayRef<StringRef> elementNames);
256 static TupleType get(Context &context,
257 ArrayRef<Type> elementTypes = std::nullopt);
258
259 /// Return the element types of this tuple.
260 ArrayRef<Type> getElementTypes() const;
261
262 /// Return the element names of this tuple.
263 ArrayRef<StringRef> getElementNames() const;
264
265 /// Return the number of elements within this tuple.
size()266 size_t size() const { return getElementTypes().size(); }
267
268 /// Return if the tuple has no elements.
empty()269 bool empty() const { return size() == 0; }
270 };
271
272 //===----------------------------------------------------------------------===//
273 // TypeType
274 //===----------------------------------------------------------------------===//
275
276 /// This class represents a PDLL type that corresponds to an mlir::Type.
277 class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
278 public:
279 using Base::Base;
280
281 /// Return an instance of the Type type.
282 static TypeType get(Context &context);
283 };
284
285 //===----------------------------------------------------------------------===//
286 // ValueType
287 //===----------------------------------------------------------------------===//
288
289 /// This class represents a PDLL type that corresponds to an mlir::Value.
290 class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
291 public:
292 using Base::Base;
293
294 /// Return an instance of the Value type.
295 static ValueType get(Context &context);
296 };
297
298 } // namespace ast
299 } // namespace pdll
300 } // namespace mlir
301
// One explicit TypeID declaration per concrete type storage class. These
// storage classes are the `ImplTy` that Type::TypeBase::classof passes to
// TypeID::get<ImplTy>(). NOTE(review): presumably the explicit declarations
// make that TypeID resolve consistently across translation units; confirm
// against the definition of MLIR_DECLARE_EXPLICIT_TYPE_ID.
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
310
311 namespace llvm {
312 template <>
313 struct DenseMapInfo<mlir::pdll::ast::Type> {
314 static mlir::pdll::ast::Type getEmptyKey() {
315 void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
316 return mlir::pdll::ast::Type(
317 static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
318 }
319 static mlir::pdll::ast::Type getTombstoneKey() {
320 void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
321 return mlir::pdll::ast::Type(
322 static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
323 }
324 static unsigned getHashValue(mlir::pdll::ast::Type val) {
325 return llvm::hash_value(val.getImpl());
326 }
327 static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
328 return lhs == rhs;
329 }
330 };
331
332 /// Add support for llvm style casts.
333 /// We provide a cast between To and From if From is mlir::pdll::ast::Type or
334 /// derives from it
335 template <typename To, typename From>
336 struct CastInfo<
337 To, From,
338 std::enable_if_t<
339 std::is_same_v<mlir::pdll::ast::Type, std::remove_const_t<From>> ||
340 std::is_base_of_v<mlir::pdll::ast::Type, From>>>
341 : NullableValueCastFailed<To>,
342 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
343 static inline bool isPossible(mlir::pdll::ast::Type ty) {
344 /// Return a constant true instead of a dynamic true when casting to self or
345 /// up the hierarchy.
346 if constexpr (std::is_base_of_v<To, From>) {
347 return true;
348 } else {
349 return To::classof(ty);
350 };
351 }
352 static inline To doCast(mlir::pdll::ast::Type ty) { return To(ty.getImpl()); }
353 };
354 } // namespace llvm
355
356 #endif // MLIR_TOOLS_PDLL_AST_TYPES_H_
357