1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Attribute wrapper to simplify using TableGen Record defining a MLIR 10 // Attribute. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/TableGen/Record.h" 17 18 using namespace mlir; 19 using namespace mlir::tblgen; 20 21 using llvm::DefInit; 22 using llvm::Init; 23 using llvm::Record; 24 using llvm::StringInit; 25 26 // Returns the initializer's value as string if the given TableGen initializer 27 // is a code or string initializer. Returns the empty StringRef otherwise. 28 static StringRef getValueAsString(const Init *init) { 29 if (const auto *str = dyn_cast<StringInit>(init)) 30 return str->getValue().trim(); 31 return {}; 32 } 33 34 bool AttrConstraint::isSubClassOf(StringRef className) const { 35 return def->isSubClassOf(className); 36 } 37 38 Attribute::Attribute(const Record *record) : AttrConstraint(record) { 39 assert(record->isSubClassOf("Attr") && 40 "must be subclass of TableGen 'Attr' class"); 41 } 42 43 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} 44 45 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } 46 47 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } 48 49 bool Attribute::isSymbolRefAttr() const { 50 StringRef defName = def->getName(); 51 if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") 52 return true; 53 return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); 54 } 55 56 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } 57 58 StringRef Attribute::getStorageType() const { 59 const auto *init = def->getValueInit("storageType"); 60 auto type = getValueAsString(init); 61 if (type.empty()) 62 return "::mlir::Attribute"; 63 return type; 64 } 65 66 StringRef Attribute::getReturnType() const { 67 const auto *init = def->getValueInit("returnType"); 68 return getValueAsString(init); 69 } 70 71 // Return the type constraint corresponding to the type of this attribute, or 72 // std::nullopt if this is not a TypedAttr. 73 std::optional<Type> Attribute::getValueType() const { 74 if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("valueType"))) 75 return Type(defInit->getDef()); 76 return std::nullopt; 77 } 78 79 StringRef Attribute::getConvertFromStorageCall() const { 80 const auto *init = def->getValueInit("convertFromStorage"); 81 return getValueAsString(init); 82 } 83 84 bool Attribute::isConstBuildable() const { 85 const auto *init = def->getValueInit("constBuilderCall"); 86 return !getValueAsString(init).empty(); 87 } 88 89 StringRef Attribute::getConstBuilderTemplate() const { 90 const auto *init = def->getValueInit("constBuilderCall"); 91 return getValueAsString(init); 92 } 93 94 Attribute Attribute::getBaseAttr() const { 95 if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("baseAttr"))) { 96 return Attribute(defInit).getBaseAttr(); 97 } 98 return *this; 99 } 100 101 bool Attribute::hasDefaultValue() const { 102 const auto *init = def->getValueInit("defaultValue"); 103 return !getValueAsString(init).empty(); 104 } 105 106 StringRef Attribute::getDefaultValue() const { 107 const auto *init = def->getValueInit("defaultValue"); 108 return getValueAsString(init); 109 } 110 111 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } 112 113 StringRef Attribute::getAttrDefName() const { 114 if (def->isAnonymous()) { 115 return getBaseAttr().def->getName(); 116 } 117 return def->getName(); 118 } 119 120 StringRef Attribute::getDerivedCodeBody() const { 121 assert(isDerivedAttr() && "only derived attribute has 'body' field"); 122 return def->getValueAsString("body"); 123 } 124 125 Dialect Attribute::getDialect() const { 126 const llvm::RecordVal *record = def->getValue("dialect"); 127 if (record && record->getValue()) { 128 if (const DefInit *init = dyn_cast<DefInit>(record->getValue())) 129 return Dialect(init->getDef()); 130 } 131 return Dialect(nullptr); 132 } 133 134 const Record &Attribute::getDef() const { return *def; } 135 136 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { 137 assert(def->isSubClassOf("ConstantAttr") && 138 "must be subclass of TableGen 'ConstantAttr' class"); 139 } 140 141 Attribute ConstantAttr::getAttribute() const { 142 return Attribute(def->getValueAsDef("attr")); 143 } 144 145 StringRef ConstantAttr::getConstantValue() const { 146 return def->getValueAsString("value"); 147 } 148 149 EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) { 150 assert(isSubClassOf("EnumAttrCaseInfo") && 151 "must be subclass of TableGen 'EnumAttrInfo' class"); 152 } 153 154 EnumAttrCase::EnumAttrCase(const DefInit *init) 155 : EnumAttrCase(init->getDef()) {} 156 157 StringRef EnumAttrCase::getSymbol() const { 158 return def->getValueAsString("symbol"); 159 } 160 161 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } 162 163 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } 164 165 const Record &EnumAttrCase::getDef() const { return *def; } 166 167 EnumAttr::EnumAttr(const Record *record) : Attribute(record) { 168 assert(isSubClassOf("EnumAttrInfo") && 169 "must be subclass of TableGen 'EnumAttr' class"); 170 } 171 172 EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {} 173 174 EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {} 175 176 bool EnumAttr::classof(const Attribute *attr) { 177 return attr->isSubClassOf("EnumAttrInfo"); 178 } 179 180 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } 181 182 StringRef EnumAttr::getEnumClassName() const { 183 return def->getValueAsString("className"); 184 } 185 186 StringRef EnumAttr::getCppNamespace() const { 187 return def->getValueAsString("cppNamespace"); 188 } 189 190 StringRef EnumAttr::getUnderlyingType() const { 191 return def->getValueAsString("underlyingType"); 192 } 193 194 StringRef EnumAttr::getUnderlyingToSymbolFnName() const { 195 return def->getValueAsString("underlyingToSymbolFnName"); 196 } 197 198 StringRef EnumAttr::getStringToSymbolFnName() const { 199 return def->getValueAsString("stringToSymbolFnName"); 200 } 201 202 StringRef EnumAttr::getSymbolToStringFnName() const { 203 return def->getValueAsString("symbolToStringFnName"); 204 } 205 206 StringRef EnumAttr::getSymbolToStringFnRetType() const { 207 return def->getValueAsString("symbolToStringFnRetType"); 208 } 209 210 StringRef EnumAttr::getMaxEnumValFnName() const { 211 return def->getValueAsString("maxEnumValFnName"); 212 } 213 214 std::vector<EnumAttrCase> EnumAttr::getAllCases() const { 215 const auto *inits = def->getValueAsListInit("enumerants"); 216 217 std::vector<EnumAttrCase> cases; 218 cases.reserve(inits->size()); 219 220 for (const Init *init : *inits) { 221 cases.emplace_back(cast<DefInit>(init)); 222 } 223 224 return cases; 225 } 226 227 bool EnumAttr::genSpecializedAttr() const { 228 return def->getValueAsBit("genSpecializedAttr"); 229 } 230 231 const Record *EnumAttr::getBaseAttrClass() const { 232 return def->getValueAsDef("baseAttrClass"); 233 } 234 235 StringRef EnumAttr::getSpecializedAttrClassName() const { 236 return def->getValueAsString("specializedAttrClassName"); 237 } 238 239 bool EnumAttr::printBitEnumPrimaryGroups() const { 240 return def->getValueAsBit("printBitEnumPrimaryGroups"); 241 } 242 243 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; 244