1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "TestTypes.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/ExtensibleDialect.h" 21 #include "mlir/IR/OpImplementation.h" 22 #include "mlir/IR/Types.h" 23 #include "llvm/ADT/APFloat.h" 24 #include "llvm/ADT/Hashing.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/ADT/bit.h" 28 #include "llvm/Support/ErrorHandling.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 using namespace mlir; 32 using namespace test; 33 34 //===----------------------------------------------------------------------===// 35 // CompoundAAttr 36 //===----------------------------------------------------------------------===// 37 38 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { 39 int widthOfSomething; 40 Type oneType; 41 SmallVector<int, 4> arrayOfInts; 42 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 43 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 44 parser.parseLSquare()) 45 return Attribute(); 46 47 int intVal; 48 while (!*parser.parseOptionalInteger(intVal)) { 49 arrayOfInts.push_back(intVal); 50 if (parser.parseOptionalComma()) 51 break; 52 } 53 54 if (parser.parseRSquare() || parser.parseGreater()) 55 return Attribute(); 56 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 57 } 58 59 void CompoundAAttr::print(AsmPrinter &printer) const { 60 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 61 llvm::interleaveComma(getArrayOfInts(), printer); 62 printer << "]>"; 63 } 64 65 //===----------------------------------------------------------------------===// 66 // CompoundAAttr 67 //===----------------------------------------------------------------------===// 68 69 Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) { 70 if (parser.parseLess()){ 71 return Attribute(); 72 } 73 SmallVector<int64_t> shape; 74 if (parser.parseOptionalGreater()) { 75 auto parseDecimal = [&]() { 76 shape.emplace_back(); 77 auto parseResult = parser.parseOptionalDecimalInteger(shape.back()); 78 if (!parseResult.has_value() || failed(*parseResult)) { 79 parser.emitError(parser.getCurrentLocation()) << "expected an integer"; 80 return failure(); 81 } 82 return success(); 83 }; 84 if (failed(parseDecimal())) { 85 return Attribute(); 86 } 87 while (failed(parser.parseOptionalGreater())) { 88 if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) { 89 return Attribute(); 90 } 91 } 92 } 93 return get(parser.getContext(), shape); 94 } 95 96 void TestDecimalShapeAttr::print(AsmPrinter &printer) const { 97 printer << "<"; 98 llvm::interleave(getShape(), printer, "x"); 99 printer << ">"; 100 } 101 102 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { 103 SmallVector<uint64_t> elements; 104 if (parser.parseLess() || parser.parseLSquare()) 105 return Attribute(); 106 uint64_t intVal; 107 while (succeeded(*parser.parseOptionalInteger(intVal))) { 108 elements.push_back(intVal); 109 if (parser.parseOptionalComma()) 110 break; 111 } 112 113 if (parser.parseRSquare() || parser.parseGreater()) 114 return Attribute(); 115 return parser.getChecked<TestI64ElementsAttr>( 116 parser.getContext(), llvm::cast<ShapedType>(type), elements); 117 } 118 119 void TestI64ElementsAttr::print(AsmPrinter &printer) const { 120 printer << "<["; 121 llvm::interleaveComma(getElements(), printer); 122 printer << "]>"; 123 } 124 125 LogicalResult 126 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 127 ShapedType type, ArrayRef<uint64_t> elements) { 128 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 129 return emitError() 130 << "number of elements does not match the provided shape type, got: " 131 << elements.size() << ", but expected: " << type.getNumElements(); 132 } 133 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 134 return emitError() << "expected single rank 64-bit shape type, but got: " 135 << type; 136 return success(); 137 } 138 139 LogicalResult TestAttrWithFormatAttr::verify( 140 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two, 141 IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six, 142 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) { 143 if (four.size() != static_cast<unsigned>(one)) 144 return emitError() << "expected 'one' to equal 'four.size()'"; 145 return success(); 146 } 147 148 //===----------------------------------------------------------------------===// 149 // Utility Functions for Generated Attributes 150 //===----------------------------------------------------------------------===// 151 152 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { 153 SmallVector<int> ints; 154 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 155 ints.push_back(0); 156 return parser.parseInteger(ints.back()); 157 }) || 158 parser.parseRSquare()) 159 return failure(); 160 return ints; 161 } 162 163 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { 164 printer << '['; 165 llvm::interleaveComma(ints, printer); 166 printer << ']'; 167 } 168 169 //===----------------------------------------------------------------------===// 170 // TestSubElementsAccessAttr 171 //===----------------------------------------------------------------------===// 172 173 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, 174 ::mlir::Type type) { 175 Attribute first, second, third; 176 if (parser.parseLess() || parser.parseAttribute(first) || 177 parser.parseComma() || parser.parseAttribute(second) || 178 parser.parseComma() || parser.parseAttribute(third) || 179 parser.parseGreater()) { 180 return {}; 181 } 182 return get(parser.getContext(), first, second, third); 183 } 184 185 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { 186 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 187 << ">"; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TestExtern1DI64ElementsAttr 192 //===----------------------------------------------------------------------===// 193 194 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const { 195 if (auto *blob = getHandle().getBlob()) 196 return blob->getDataAs<uint64_t>(); 197 return std::nullopt; 198 } 199 200 //===----------------------------------------------------------------------===// 201 // TestCustomAnchorAttr 202 //===----------------------------------------------------------------------===// 203 204 static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) { 205 bool b; 206 if (p.parseInteger(b)) 207 return failure(); 208 result = b; 209 return success(); 210 } 211 212 static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { 213 p << (*result ? "true" : "false"); 214 } 215 216 //===----------------------------------------------------------------------===// 217 // CopyCountAttr Implementation 218 //===----------------------------------------------------------------------===// 219 220 CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) { 221 CopyCount::counter++; 222 } 223 224 CopyCount &CopyCount::operator=(const CopyCount &rhs) { 225 CopyCount::counter++; 226 value = rhs.value; 227 return *this; 228 } 229 230 int CopyCount::counter; 231 232 static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) { 233 return lhs.value == rhs.value; 234 } 235 236 llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os, 237 const test::CopyCount &value) { 238 return os << value.value; 239 } 240 241 template <> 242 struct mlir::FieldParser<test::CopyCount> { 243 static FailureOr<test::CopyCount> parse(AsmParser &parser) { 244 std::string value; 245 if (parser.parseKeyword(value)) 246 return failure(); 247 return test::CopyCount(value); 248 } 249 }; 250 namespace test { 251 llvm::hash_code hash_value(const test::CopyCount ©Count) { 252 return llvm::hash_value(copyCount.value); 253 } 254 } // namespace test 255 256 //===----------------------------------------------------------------------===// 257 // TestConditionalAliasAttr 258 //===----------------------------------------------------------------------===// 259 260 /// Attempt to parse the conditionally-aliased string attribute as a keyword or 261 /// string, else try to parse an alias. 262 static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) { 263 std::string str; 264 if (succeeded(p.parseOptionalKeywordOrString(&str))) { 265 value = StringAttr::get(p.getContext(), str); 266 return success(); 267 } 268 return p.parseAttribute(value); 269 } 270 271 /// Print the string attribute as an alias if it has one, otherwise print it as 272 /// a keyword if possible. 273 static void printConditionalAlias(AsmPrinter &p, StringAttr value) { 274 if (succeeded(p.printAlias(value))) 275 return; 276 p.printKeywordOrString(value); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // Custom Float Attribute 281 //===----------------------------------------------------------------------===// 282 283 static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr, 284 APFloat value) { 285 p << typeStrAttr << " : " << value; 286 } 287 288 static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr, 289 FailureOr<APFloat> &value) { 290 291 std::string str; 292 if (p.parseString(&str)) 293 return failure(); 294 295 typeStrAttr = StringAttr::get(p.getContext(), str); 296 297 if (p.parseColon()) 298 return failure(); 299 300 const llvm::fltSemantics *semantics; 301 if (str == "float") 302 semantics = &llvm::APFloat::IEEEsingle(); 303 else if (str == "double") 304 semantics = &llvm::APFloat::IEEEdouble(); 305 else if (str == "fp80") 306 semantics = &llvm::APFloat::x87DoubleExtended(); 307 else 308 return p.emitError(p.getCurrentLocation(), "unknown float type, expected " 309 "'float', 'double' or 'fp80'"); 310 311 APFloat parsedValue(0.0); 312 if (p.parseFloat(*semantics, parsedValue)) 313 return failure(); 314 315 value.emplace(parsedValue); 316 return success(); 317 } 318 319 //===----------------------------------------------------------------------===// 320 // Tablegen Generated Definitions 321 //===----------------------------------------------------------------------===// 322 323 #include "TestAttrInterfaces.cpp.inc" 324 #include "TestOpEnums.cpp.inc" 325 #define GET_ATTRDEF_CLASSES 326 #include "TestAttrDefs.cpp.inc" 327 328 //===----------------------------------------------------------------------===// 329 // Dynamic Attributes 330 //===----------------------------------------------------------------------===// 331 332 /// Define a singleton dynamic attribute. 333 static std::unique_ptr<DynamicAttrDefinition> 334 getDynamicSingletonAttr(TestDialect *testDialect) { 335 return DynamicAttrDefinition::get( 336 "dynamic_singleton", testDialect, 337 [](function_ref<InFlightDiagnostic()> emitError, 338 ArrayRef<Attribute> args) { 339 if (!args.empty()) { 340 emitError() << "expected 0 attribute arguments, but had " 341 << args.size(); 342 return failure(); 343 } 344 return success(); 345 }); 346 } 347 348 /// Define a dynamic attribute representing a pair or attributes. 349 static std::unique_ptr<DynamicAttrDefinition> 350 getDynamicPairAttr(TestDialect *testDialect) { 351 return DynamicAttrDefinition::get( 352 "dynamic_pair", testDialect, 353 [](function_ref<InFlightDiagnostic()> emitError, 354 ArrayRef<Attribute> args) { 355 if (args.size() != 2) { 356 emitError() << "expected 2 attribute arguments, but had " 357 << args.size(); 358 return failure(); 359 } 360 return success(); 361 }); 362 } 363 364 static std::unique_ptr<DynamicAttrDefinition> 365 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { 366 auto verifier = [](function_ref<InFlightDiagnostic()> emitError, 367 ArrayRef<Attribute> args) { 368 if (args.size() != 2) { 369 emitError() << "expected 2 attribute arguments, but had " << args.size(); 370 return failure(); 371 } 372 return success(); 373 }; 374 375 auto parser = [](AsmParser &parser, 376 llvm::SmallVectorImpl<Attribute> &parsedParams) { 377 Attribute leftAttr, rightAttr; 378 if (parser.parseLess() || parser.parseAttribute(leftAttr) || 379 parser.parseColon() || parser.parseAttribute(rightAttr) || 380 parser.parseGreater()) 381 return failure(); 382 parsedParams.push_back(leftAttr); 383 parsedParams.push_back(rightAttr); 384 return success(); 385 }; 386 387 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { 388 printer << "<" << params[0] << ":" << params[1] << ">"; 389 }; 390 391 return DynamicAttrDefinition::get("dynamic_custom_assembly_format", 392 testDialect, std::move(verifier), 393 std::move(parser), std::move(printer)); 394 } 395 396 //===----------------------------------------------------------------------===// 397 // TestDialect 398 //===----------------------------------------------------------------------===// 399 400 void TestDialect::registerAttributes() { 401 addAttributes< 402 #define GET_ATTRDEF_LIST 403 #include "TestAttrDefs.cpp.inc" 404 >(); 405 registerDynamicAttr(getDynamicSingletonAttr(this)); 406 registerDynamicAttr(getDynamicPairAttr(this)); 407 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this)); 408 } 409