xref: /llvm-project/mlir/lib/Tools/PDLL/AST/Types.cpp (revision d2353695f8cb864f88475d3a921249b0dcbcc6f4)
//===- Types.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #include "mlir/Tools/PDLL/AST/Types.h"
10 #include "TypeDetail.h"
11 #include "mlir/Tools/PDLL/AST/Context.h"
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::pdll;
16 using namespace mlir::pdll::ast;
17 
18 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)19 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
20 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
21 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
22 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
23 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
24 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
25 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)

//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//

31 TypeID Type::getTypeID() const { return impl->typeID; }
32 
refineWith(Type other) const33 Type Type::refineWith(Type other) const {
34   if (*this == other)
35     return *this;
36 
37   // Operation types are compatible if the operation names don't conflict.
38   if (auto opTy = mlir::dyn_cast<OperationType>(*this)) {
39     auto otherOpTy = mlir::dyn_cast<ast::OperationType>(other);
40     if (!otherOpTy)
41       return nullptr;
42     if (!otherOpTy.getName())
43       return *this;
44     if (!opTy.getName())
45       return other;
46 
47     return nullptr;
48   }
49 
50   return nullptr;
51 }

//===----------------------------------------------------------------------===//
// AttributeType
//===----------------------------------------------------------------------===//

get(Context & context)57 AttributeType AttributeType::get(Context &context) {
58   return context.getTypeUniquer().get<ImplTy>();
59 }

//===----------------------------------------------------------------------===//
// ConstraintType
//===----------------------------------------------------------------------===//

get(Context & context)65 ConstraintType ConstraintType::get(Context &context) {
66   return context.getTypeUniquer().get<ImplTy>();
67 }

//===----------------------------------------------------------------------===//
// OperationType
//===----------------------------------------------------------------------===//

get(Context & context,std::optional<StringRef> name,const ods::Operation * odsOp)73 OperationType OperationType::get(Context &context,
74                                  std::optional<StringRef> name,
75                                  const ods::Operation *odsOp) {
76   return context.getTypeUniquer().get<ImplTy>(
77       /*initFn=*/function_ref<void(ImplTy *)>(),
78       std::make_pair(name.value_or(""), odsOp));
79 }
80 
getName() const81 std::optional<StringRef> OperationType::getName() const {
82   StringRef name = getImplAs<ImplTy>()->getValue().first;
83   return name.empty() ? std::optional<StringRef>()
84                       : std::optional<StringRef>(name);
85 }
86 
getODSOperation() const87 const ods::Operation *OperationType::getODSOperation() const {
88   return getImplAs<ImplTy>()->getValue().second;
89 }

//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//

get(Context & context,Type elementType)95 RangeType RangeType::get(Context &context, Type elementType) {
96   return context.getTypeUniquer().get<ImplTy>(
97       /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
98 }
99 
getElementType() const100 Type RangeType::getElementType() const {
101   return getImplAs<ImplTy>()->getValue();
102 }

//===----------------------------------------------------------------------===//
// TypeRangeType

classof(Type type)107 bool TypeRangeType::classof(Type type) {
108   RangeType range = mlir::dyn_cast<RangeType>(type);
109   return range && mlir::isa<TypeType>(range.getElementType());
110 }
111 
get(Context & context)112 TypeRangeType TypeRangeType::get(Context &context) {
113   return mlir::cast<TypeRangeType>(
114       RangeType::get(context, TypeType::get(context)));
115 }

//===----------------------------------------------------------------------===//
// ValueRangeType

classof(Type type)120 bool ValueRangeType::classof(Type type) {
121   RangeType range = mlir::dyn_cast<RangeType>(type);
122   return range && mlir::isa<ValueType>(range.getElementType());
123 }
124 
get(Context & context)125 ValueRangeType ValueRangeType::get(Context &context) {
126   return mlir::cast<ValueRangeType>(
127       RangeType::get(context, ValueType::get(context)));
128 }

//===----------------------------------------------------------------------===//
// RewriteType
//===----------------------------------------------------------------------===//

get(Context & context)134 RewriteType RewriteType::get(Context &context) {
135   return context.getTypeUniquer().get<ImplTy>();
136 }

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

get(Context & context,ArrayRef<Type> elementTypes,ArrayRef<StringRef> elementNames)142 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes,
143                          ArrayRef<StringRef> elementNames) {
144   assert(elementTypes.size() == elementNames.size());
145   return context.getTypeUniquer().get<ImplTy>(
146       /*initFn=*/function_ref<void(ImplTy *)>(), elementTypes, elementNames);
147 }
get(Context & context,ArrayRef<Type> elementTypes)148 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) {
149   SmallVector<StringRef> elementNames(elementTypes.size());
150   return get(context, elementTypes, elementNames);
151 }
152 
getElementTypes() const153 ArrayRef<Type> TupleType::getElementTypes() const {
154   return getImplAs<ImplTy>()->getValue().first;
155 }
156 
getElementNames() const157 ArrayRef<StringRef> TupleType::getElementNames() const {
158   return getImplAs<ImplTy>()->getValue().second;
159 }

//===----------------------------------------------------------------------===//
// TypeType
//===----------------------------------------------------------------------===//

get(Context & context)165 TypeType TypeType::get(Context &context) {
166   return context.getTypeUniquer().get<ImplTy>();
167 }

//===----------------------------------------------------------------------===//
// ValueType
//===----------------------------------------------------------------------===//

get(Context & context)173 ValueType ValueType::get(Context &context) {
174   return context.getTypeUniquer().get<ImplTy>();
175 }
176