1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Attribute wrapper to simplify using TableGen Record defining a MLIR 10 // Attribute. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/TableGen/Record.h" 17 18 using namespace mlir; 19 using namespace mlir::tblgen; 20 21 using llvm::CodeInit; 22 using llvm::DefInit; 23 using llvm::Init; 24 using llvm::Record; 25 using llvm::StringInit; 26 27 // Returns the initializer's value as string if the given TableGen initializer 28 // is a code or string initializer. Returns the empty StringRef otherwise. 29 static StringRef getValueAsString(const Init *init) { 30 if (const auto *code = dyn_cast<CodeInit>(init)) 31 return code->getValue().trim(); 32 if (const auto *str = dyn_cast<StringInit>(init)) 33 return str->getValue().trim(); 34 return {}; 35 } 36 37 AttrConstraint::AttrConstraint(const Record *record) 38 : Constraint(Constraint::CK_Attr, record) { 39 assert(isSubClassOf("AttrConstraint") && 40 "must be subclass of TableGen 'AttrConstraint' class"); 41 } 42 43 bool AttrConstraint::isSubClassOf(StringRef className) const { 44 return def->isSubClassOf(className); 45 } 46 47 Attribute::Attribute(const Record *record) : AttrConstraint(record) { 48 assert(record->isSubClassOf("Attr") && 49 "must be subclass of TableGen 'Attr' class"); 50 } 51 52 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} 53 54 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } 55 56 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } 57 58 bool Attribute::isSymbolRefAttr() const { 59 StringRef defName = def->getName(); 60 if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") 61 return true; 62 return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); 63 } 64 65 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } 66 67 StringRef Attribute::getStorageType() const { 68 const auto *init = def->getValueInit("storageType"); 69 auto type = getValueAsString(init); 70 if (type.empty()) 71 return "Attribute"; 72 return type; 73 } 74 75 StringRef Attribute::getReturnType() const { 76 const auto *init = def->getValueInit("returnType"); 77 return getValueAsString(init); 78 } 79 80 // Return the type constraint corresponding to the type of this attribute, or 81 // None if this is not a TypedAttr. 82 llvm::Optional<Type> Attribute::getValueType() const { 83 if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType"))) 84 return Type(defInit->getDef()); 85 return llvm::None; 86 } 87 88 StringRef Attribute::getConvertFromStorageCall() const { 89 const auto *init = def->getValueInit("convertFromStorage"); 90 return getValueAsString(init); 91 } 92 93 bool Attribute::isConstBuildable() const { 94 const auto *init = def->getValueInit("constBuilderCall"); 95 return !getValueAsString(init).empty(); 96 } 97 98 StringRef Attribute::getConstBuilderTemplate() const { 99 const auto *init = def->getValueInit("constBuilderCall"); 100 return getValueAsString(init); 101 } 102 103 Attribute Attribute::getBaseAttr() const { 104 if (const auto *defInit = 105 llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) { 106 return Attribute(defInit).getBaseAttr(); 107 } 108 return *this; 109 } 110 111 bool Attribute::hasDefaultValue() const { 112 const auto *init = def->getValueInit("defaultValue"); 113 return !getValueAsString(init).empty(); 114 } 115 116 StringRef Attribute::getDefaultValue() const { 117 const auto *init = def->getValueInit("defaultValue"); 118 return getValueAsString(init); 119 } 120 121 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } 122 123 StringRef Attribute::getAttrDefName() const { 124 if (def->isAnonymous()) { 125 return getBaseAttr().def->getName(); 126 } 127 return def->getName(); 128 } 129 130 StringRef Attribute::getDerivedCodeBody() const { 131 assert(isDerivedAttr() && "only derived attribute has 'body' field"); 132 return def->getValueAsString("body"); 133 } 134 135 Dialect Attribute::getDialect() const { 136 const llvm::RecordVal *record = def->getValue("dialect"); 137 if (record && record->getValue()) { 138 if (DefInit *init = dyn_cast<DefInit>(record->getValue())) 139 return Dialect(init->getDef()); 140 } 141 return Dialect(nullptr); 142 } 143 144 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { 145 assert(def->isSubClassOf("ConstantAttr") && 146 "must be subclass of TableGen 'ConstantAttr' class"); 147 } 148 149 Attribute ConstantAttr::getAttribute() const { 150 return Attribute(def->getValueAsDef("attr")); 151 } 152 153 StringRef ConstantAttr::getConstantValue() const { 154 return def->getValueAsString("value"); 155 } 156 157 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { 158 assert(isSubClassOf("EnumAttrCaseInfo") && 159 "must be subclass of TableGen 'EnumAttrInfo' class"); 160 } 161 162 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) 163 : EnumAttrCase(init->getDef()) {} 164 165 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } 166 167 StringRef EnumAttrCase::getSymbol() const { 168 return def->getValueAsString("symbol"); 169 } 170 171 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } 172 173 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } 174 175 const llvm::Record &EnumAttrCase::getDef() const { return *def; } 176 177 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { 178 assert(isSubClassOf("EnumAttrInfo") && 179 "must be subclass of TableGen 'EnumAttr' class"); 180 } 181 182 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} 183 184 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} 185 186 bool EnumAttr::classof(const Attribute *attr) { 187 return attr->isSubClassOf("EnumAttrInfo"); 188 } 189 190 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } 191 192 StringRef EnumAttr::getEnumClassName() const { 193 return def->getValueAsString("className"); 194 } 195 196 StringRef EnumAttr::getCppNamespace() const { 197 return def->getValueAsString("cppNamespace"); 198 } 199 200 StringRef EnumAttr::getUnderlyingType() const { 201 return def->getValueAsString("underlyingType"); 202 } 203 204 StringRef EnumAttr::getUnderlyingToSymbolFnName() const { 205 return def->getValueAsString("underlyingToSymbolFnName"); 206 } 207 208 StringRef EnumAttr::getStringToSymbolFnName() const { 209 return def->getValueAsString("stringToSymbolFnName"); 210 } 211 212 StringRef EnumAttr::getSymbolToStringFnName() const { 213 return def->getValueAsString("symbolToStringFnName"); 214 } 215 216 StringRef EnumAttr::getSymbolToStringFnRetType() const { 217 return def->getValueAsString("symbolToStringFnRetType"); 218 } 219 220 StringRef EnumAttr::getMaxEnumValFnName() const { 221 return def->getValueAsString("maxEnumValFnName"); 222 } 223 224 std::vector<EnumAttrCase> EnumAttr::getAllCases() const { 225 const auto *inits = def->getValueAsListInit("enumerants"); 226 227 std::vector<EnumAttrCase> cases; 228 cases.reserve(inits->size()); 229 230 for (const llvm::Init *init : *inits) { 231 cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init))); 232 } 233 234 return cases; 235 } 236 237 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { 238 assert(def->isSubClassOf("StructFieldAttr") && 239 "must be subclass of TableGen 'StructFieldAttr' class"); 240 } 241 242 StructFieldAttr::StructFieldAttr(const llvm::Record &record) 243 : StructFieldAttr(&record) {} 244 245 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) 246 : StructFieldAttr(init->getDef()) {} 247 248 StringRef StructFieldAttr::getName() const { 249 return def->getValueAsString("name"); 250 } 251 252 Attribute StructFieldAttr::getType() const { 253 auto init = def->getValueInit("type"); 254 return Attribute(cast<llvm::DefInit>(init)); 255 } 256 257 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { 258 assert(isSubClassOf("StructAttr") && 259 "must be subclass of TableGen 'StructAttr' class"); 260 } 261 262 StructAttr::StructAttr(const llvm::DefInit *init) 263 : StructAttr(init->getDef()) {} 264 265 StringRef StructAttr::getStructClassName() const { 266 return def->getValueAsString("className"); 267 } 268 269 StringRef StructAttr::getCppNamespace() const { 270 Dialect dialect(def->getValueAsDef("dialect")); 271 return dialect.getCppNamespace(); 272 } 273 274 std::vector<StructFieldAttr> StructAttr::getAllFields() const { 275 std::vector<StructFieldAttr> attributes; 276 277 const auto *inits = def->getValueAsListInit("fields"); 278 attributes.reserve(inits->size()); 279 280 for (const llvm::Init *init : *inits) { 281 attributes.emplace_back(cast<llvm::DefInit>(init)); 282 } 283 284 return attributes; 285 } 286 287 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; 288