xref: /llvm-project/mlir/lib/TableGen/Attribute.cpp (revision c0958b7b4c6a31b0b89462c3ee770e486d4eb535)
1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Attribute wrapper to simplify using TableGen Record defining a MLIR
10 // Attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/TableGen/Record.h"
17 
18 using namespace mlir;
19 using namespace mlir::tblgen;
20 
21 using llvm::CodeInit;
22 using llvm::DefInit;
23 using llvm::Init;
24 using llvm::Record;
25 using llvm::StringInit;
26 
27 // Returns the initializer's value as string if the given TableGen initializer
28 // is a code or string initializer. Returns the empty StringRef otherwise.
29 static StringRef getValueAsString(const Init *init) {
30   if (const auto *code = dyn_cast<CodeInit>(init))
31     return code->getValue().trim();
32   if (const auto *str = dyn_cast<StringInit>(init))
33     return str->getValue().trim();
34   return {};
35 }
36 
37 AttrConstraint::AttrConstraint(const Record *record)
38     : Constraint(Constraint::CK_Attr, record) {
39   assert(isSubClassOf("AttrConstraint") &&
40          "must be subclass of TableGen 'AttrConstraint' class");
41 }
42 
43 bool AttrConstraint::isSubClassOf(StringRef className) const {
44   return def->isSubClassOf(className);
45 }
46 
47 Attribute::Attribute(const Record *record) : AttrConstraint(record) {
48   assert(record->isSubClassOf("Attr") &&
49          "must be subclass of TableGen 'Attr' class");
50 }
51 
52 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
53 
54 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
55 
56 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
57 
58 bool Attribute::isSymbolRefAttr() const {
59   StringRef defName = def->getName();
60   if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
61     return true;
62   return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
63 }
64 
65 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
66 
67 StringRef Attribute::getStorageType() const {
68   const auto *init = def->getValueInit("storageType");
69   auto type = getValueAsString(init);
70   if (type.empty())
71     return "Attribute";
72   return type;
73 }
74 
75 StringRef Attribute::getReturnType() const {
76   const auto *init = def->getValueInit("returnType");
77   return getValueAsString(init);
78 }
79 
80 // Return the type constraint corresponding to the type of this attribute, or
81 // None if this is not a TypedAttr.
82 llvm::Optional<Type> Attribute::getValueType() const {
83   if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
84     return Type(defInit->getDef());
85   return llvm::None;
86 }
87 
88 StringRef Attribute::getConvertFromStorageCall() const {
89   const auto *init = def->getValueInit("convertFromStorage");
90   return getValueAsString(init);
91 }
92 
93 bool Attribute::isConstBuildable() const {
94   const auto *init = def->getValueInit("constBuilderCall");
95   return !getValueAsString(init).empty();
96 }
97 
98 StringRef Attribute::getConstBuilderTemplate() const {
99   const auto *init = def->getValueInit("constBuilderCall");
100   return getValueAsString(init);
101 }
102 
103 Attribute Attribute::getBaseAttr() const {
104   if (const auto *defInit =
105           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
106     return Attribute(defInit).getBaseAttr();
107   }
108   return *this;
109 }
110 
111 bool Attribute::hasDefaultValue() const {
112   const auto *init = def->getValueInit("defaultValue");
113   return !getValueAsString(init).empty();
114 }
115 
116 StringRef Attribute::getDefaultValue() const {
117   const auto *init = def->getValueInit("defaultValue");
118   return getValueAsString(init);
119 }
120 
121 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
122 
123 StringRef Attribute::getAttrDefName() const {
124   if (def->isAnonymous()) {
125     return getBaseAttr().def->getName();
126   }
127   return def->getName();
128 }
129 
130 StringRef Attribute::getDerivedCodeBody() const {
131   assert(isDerivedAttr() && "only derived attribute has 'body' field");
132   return def->getValueAsString("body");
133 }
134 
135 Dialect Attribute::getDialect() const {
136   const llvm::RecordVal *record = def->getValue("dialect");
137   if (record && record->getValue()) {
138     if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
139       return Dialect(init->getDef());
140   }
141   return Dialect(nullptr);
142 }
143 
144 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
145   assert(def->isSubClassOf("ConstantAttr") &&
146          "must be subclass of TableGen 'ConstantAttr' class");
147 }
148 
149 Attribute ConstantAttr::getAttribute() const {
150   return Attribute(def->getValueAsDef("attr"));
151 }
152 
153 StringRef ConstantAttr::getConstantValue() const {
154   return def->getValueAsString("value");
155 }
156 
157 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
158   assert(isSubClassOf("EnumAttrCaseInfo") &&
159          "must be subclass of TableGen 'EnumAttrInfo' class");
160 }
161 
162 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
163     : EnumAttrCase(init->getDef()) {}
164 
165 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
166 
167 StringRef EnumAttrCase::getSymbol() const {
168   return def->getValueAsString("symbol");
169 }
170 
171 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
172 
173 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
174 
175 const llvm::Record &EnumAttrCase::getDef() const { return *def; }
176 
177 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
178   assert(isSubClassOf("EnumAttrInfo") &&
179          "must be subclass of TableGen 'EnumAttr' class");
180 }
181 
182 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
183 
184 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
185 
186 bool EnumAttr::classof(const Attribute *attr) {
187   return attr->isSubClassOf("EnumAttrInfo");
188 }
189 
190 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
191 
192 StringRef EnumAttr::getEnumClassName() const {
193   return def->getValueAsString("className");
194 }
195 
196 StringRef EnumAttr::getCppNamespace() const {
197   return def->getValueAsString("cppNamespace");
198 }
199 
200 StringRef EnumAttr::getUnderlyingType() const {
201   return def->getValueAsString("underlyingType");
202 }
203 
204 StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
205   return def->getValueAsString("underlyingToSymbolFnName");
206 }
207 
208 StringRef EnumAttr::getStringToSymbolFnName() const {
209   return def->getValueAsString("stringToSymbolFnName");
210 }
211 
212 StringRef EnumAttr::getSymbolToStringFnName() const {
213   return def->getValueAsString("symbolToStringFnName");
214 }
215 
216 StringRef EnumAttr::getSymbolToStringFnRetType() const {
217   return def->getValueAsString("symbolToStringFnRetType");
218 }
219 
220 StringRef EnumAttr::getMaxEnumValFnName() const {
221   return def->getValueAsString("maxEnumValFnName");
222 }
223 
224 std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
225   const auto *inits = def->getValueAsListInit("enumerants");
226 
227   std::vector<EnumAttrCase> cases;
228   cases.reserve(inits->size());
229 
230   for (const llvm::Init *init : *inits) {
231     cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
232   }
233 
234   return cases;
235 }
236 
237 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
238   assert(def->isSubClassOf("StructFieldAttr") &&
239          "must be subclass of TableGen 'StructFieldAttr' class");
240 }
241 
242 StructFieldAttr::StructFieldAttr(const llvm::Record &record)
243     : StructFieldAttr(&record) {}
244 
245 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
246     : StructFieldAttr(init->getDef()) {}
247 
248 StringRef StructFieldAttr::getName() const {
249   return def->getValueAsString("name");
250 }
251 
252 Attribute StructFieldAttr::getType() const {
253   auto init = def->getValueInit("type");
254   return Attribute(cast<llvm::DefInit>(init));
255 }
256 
257 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
258   assert(isSubClassOf("StructAttr") &&
259          "must be subclass of TableGen 'StructAttr' class");
260 }
261 
262 StructAttr::StructAttr(const llvm::DefInit *init)
263     : StructAttr(init->getDef()) {}
264 
265 StringRef StructAttr::getStructClassName() const {
266   return def->getValueAsString("className");
267 }
268 
269 StringRef StructAttr::getCppNamespace() const {
270   Dialect dialect(def->getValueAsDef("dialect"));
271   return dialect.getCppNamespace();
272 }
273 
274 std::vector<StructFieldAttr> StructAttr::getAllFields() const {
275   std::vector<StructFieldAttr> attributes;
276 
277   const auto *inits = def->getValueAsListInit("fields");
278   attributes.reserve(inits->size());
279 
280   for (const llvm::Init *init : *inits) {
281     attributes.emplace_back(cast<llvm::DefInit>(init));
282   }
283 
284   return attributes;
285 }
286 
287 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
288