1 //===- Types.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Types.h"
10 #include "TypeDetail.h"
11 #include "mlir/Tools/PDLL/AST/Context.h"
12 #include <optional>
13
14 using namespace mlir;
15 using namespace mlir::pdll;
16 using namespace mlir::pdll::ast;
17
// Provide the out-of-line TypeID definitions for each PDLL AST type storage
// class in `mlir::pdll::ast::detail`. NOTE(review): the matching
// MLIR_DECLARE_EXPLICIT_TYPE_ID declarations presumably live in TypeDetail.h;
// confirm the two lists stay in sync when adding a new storage class.
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
26
27 //===----------------------------------------------------------------------===//
28 // Type
29 //===----------------------------------------------------------------------===//
30
31 TypeID Type::getTypeID() const { return impl->typeID; }
32
refineWith(Type other) const33 Type Type::refineWith(Type other) const {
34 if (*this == other)
35 return *this;
36
37 // Operation types are compatible if the operation names don't conflict.
38 if (auto opTy = mlir::dyn_cast<OperationType>(*this)) {
39 auto otherOpTy = mlir::dyn_cast<ast::OperationType>(other);
40 if (!otherOpTy)
41 return nullptr;
42 if (!otherOpTy.getName())
43 return *this;
44 if (!opTy.getName())
45 return other;
46
47 return nullptr;
48 }
49
50 return nullptr;
51 }
52
53 //===----------------------------------------------------------------------===//
54 // AttributeType
55 //===----------------------------------------------------------------------===//
56
get(Context & context)57 AttributeType AttributeType::get(Context &context) {
58 return context.getTypeUniquer().get<ImplTy>();
59 }
60
61 //===----------------------------------------------------------------------===//
62 // ConstraintType
63 //===----------------------------------------------------------------------===//
64
get(Context & context)65 ConstraintType ConstraintType::get(Context &context) {
66 return context.getTypeUniquer().get<ImplTy>();
67 }
68
69 //===----------------------------------------------------------------------===//
70 // OperationType
71 //===----------------------------------------------------------------------===//
72
get(Context & context,std::optional<StringRef> name,const ods::Operation * odsOp)73 OperationType OperationType::get(Context &context,
74 std::optional<StringRef> name,
75 const ods::Operation *odsOp) {
76 return context.getTypeUniquer().get<ImplTy>(
77 /*initFn=*/function_ref<void(ImplTy *)>(),
78 std::make_pair(name.value_or(""), odsOp));
79 }
80
getName() const81 std::optional<StringRef> OperationType::getName() const {
82 StringRef name = getImplAs<ImplTy>()->getValue().first;
83 return name.empty() ? std::optional<StringRef>()
84 : std::optional<StringRef>(name);
85 }
86
getODSOperation() const87 const ods::Operation *OperationType::getODSOperation() const {
88 return getImplAs<ImplTy>()->getValue().second;
89 }
90
91 //===----------------------------------------------------------------------===//
92 // RangeType
93 //===----------------------------------------------------------------------===//
94
get(Context & context,Type elementType)95 RangeType RangeType::get(Context &context, Type elementType) {
96 return context.getTypeUniquer().get<ImplTy>(
97 /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
98 }
99
getElementType() const100 Type RangeType::getElementType() const {
101 return getImplAs<ImplTy>()->getValue();
102 }
103
104 //===----------------------------------------------------------------------===//
105 // TypeRangeType
106
classof(Type type)107 bool TypeRangeType::classof(Type type) {
108 RangeType range = mlir::dyn_cast<RangeType>(type);
109 return range && mlir::isa<TypeType>(range.getElementType());
110 }
111
get(Context & context)112 TypeRangeType TypeRangeType::get(Context &context) {
113 return mlir::cast<TypeRangeType>(
114 RangeType::get(context, TypeType::get(context)));
115 }
116
117 //===----------------------------------------------------------------------===//
118 // ValueRangeType
119
classof(Type type)120 bool ValueRangeType::classof(Type type) {
121 RangeType range = mlir::dyn_cast<RangeType>(type);
122 return range && mlir::isa<ValueType>(range.getElementType());
123 }
124
get(Context & context)125 ValueRangeType ValueRangeType::get(Context &context) {
126 return mlir::cast<ValueRangeType>(
127 RangeType::get(context, ValueType::get(context)));
128 }
129
130 //===----------------------------------------------------------------------===//
131 // RewriteType
132 //===----------------------------------------------------------------------===//
133
get(Context & context)134 RewriteType RewriteType::get(Context &context) {
135 return context.getTypeUniquer().get<ImplTy>();
136 }
137
138 //===----------------------------------------------------------------------===//
139 // TupleType
140 //===----------------------------------------------------------------------===//
141
get(Context & context,ArrayRef<Type> elementTypes,ArrayRef<StringRef> elementNames)142 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes,
143 ArrayRef<StringRef> elementNames) {
144 assert(elementTypes.size() == elementNames.size());
145 return context.getTypeUniquer().get<ImplTy>(
146 /*initFn=*/function_ref<void(ImplTy *)>(), elementTypes, elementNames);
147 }
get(Context & context,ArrayRef<Type> elementTypes)148 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) {
149 SmallVector<StringRef> elementNames(elementTypes.size());
150 return get(context, elementTypes, elementNames);
151 }
152
getElementTypes() const153 ArrayRef<Type> TupleType::getElementTypes() const {
154 return getImplAs<ImplTy>()->getValue().first;
155 }
156
getElementNames() const157 ArrayRef<StringRef> TupleType::getElementNames() const {
158 return getImplAs<ImplTy>()->getValue().second;
159 }
160
161 //===----------------------------------------------------------------------===//
162 // TypeType
163 //===----------------------------------------------------------------------===//
164
get(Context & context)165 TypeType TypeType::get(Context &context) {
166 return context.getTypeUniquer().get<ImplTy>();
167 }
168
169 //===----------------------------------------------------------------------===//
170 // ValueType
171 //===----------------------------------------------------------------------===//
172
get(Context & context)173 ValueType ValueType::get(Context &context) {
174 return context.getTypeUniquer().get<ImplTy>();
175 }
176