1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains types defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTypes.h" 15 #include "TestDialect.h" 16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/ExtensibleDialect.h" 20 #include "mlir/IR/Types.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/SetVector.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/TypeSize.h" 25 #include <optional> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Custom parser for SignednessSemantics. 31 static ParseResult 32 parseSignedness(AsmParser &parser, 33 TestIntegerType::SignednessSemantics &result) { 34 StringRef signStr; 35 auto loc = parser.getCurrentLocation(); 36 if (parser.parseKeyword(&signStr)) 37 return failure(); 38 if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned")) 39 result = TestIntegerType::SignednessSemantics::Unsigned; 40 else if (signStr.equals_insensitive("s") || 41 signStr.equals_insensitive("signed")) 42 result = TestIntegerType::SignednessSemantics::Signed; 43 else if (signStr.equals_insensitive("n") || 44 signStr.equals_insensitive("none")) 45 result = TestIntegerType::SignednessSemantics::Signless; 46 else 47 return parser.emitError(loc, "expected signed, unsigned, or none"); 48 return success(); 49 } 50 51 // Custom printer for SignednessSemantics. 
52 static void printSignedness(AsmPrinter &printer, 53 const TestIntegerType::SignednessSemantics &ss) { 54 switch (ss) { 55 case TestIntegerType::SignednessSemantics::Unsigned: 56 printer << "unsigned"; 57 break; 58 case TestIntegerType::SignednessSemantics::Signed: 59 printer << "signed"; 60 break; 61 case TestIntegerType::SignednessSemantics::Signless: 62 printer << "none"; 63 break; 64 } 65 } 66 67 // The functions don't need to be in the header file, but need to be in the mlir 68 // namespace. Declare them here, then define them immediately below. Separating 69 // the declaration and definition adheres to the LLVM coding standards. 70 namespace test { 71 // FieldInfo is used as part of a parameter, so equality comparison is 72 // compulsory. 73 static bool operator==(const FieldInfo &a, const FieldInfo &b); 74 // FieldInfo is used as part of a parameter, so a hash will be computed. 75 static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT 76 } // namespace test 77 78 // FieldInfo is used as part of a parameter, so equality comparison is 79 // compulsory. 80 static bool test::operator==(const FieldInfo &a, const FieldInfo &b) { 81 return a.name == b.name && a.type == b.type; 82 } 83 84 // FieldInfo is used as part of a parameter, so a hash will be computed. 
85 static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT 86 return llvm::hash_combine(fi.name, fi.type); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // TestCustomType 91 //===----------------------------------------------------------------------===// 92 93 static ParseResult parseCustomTypeA(AsmParser &parser, int &aResult) { 94 return parser.parseInteger(aResult); 95 } 96 97 static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } 98 99 static ParseResult parseCustomTypeB(AsmParser &parser, int a, 100 std::optional<int> &bResult) { 101 if (a < 0) 102 return success(); 103 for (int i : llvm::seq(0, a)) 104 if (failed(parser.parseInteger(i))) 105 return failure(); 106 bResult.emplace(0); 107 return parser.parseInteger(*bResult); 108 } 109 110 static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) { 111 if (a < 0) 112 return; 113 printer << ' '; 114 for (int i : llvm::seq(0, a)) 115 printer << i << ' '; 116 printer << *b; 117 } 118 119 static ParseResult parseFooString(AsmParser &parser, std::string &foo) { 120 std::string result; 121 if (parser.parseString(&result)) 122 return failure(); 123 foo = std::move(result); 124 return success(); 125 } 126 127 static void printFooString(AsmPrinter &printer, StringRef foo) { 128 printer << '"' << foo << '"'; 129 } 130 131 static ParseResult parseBarString(AsmParser &parser, StringRef foo) { 132 return parser.parseKeyword(foo); 133 } 134 135 static void printBarString(AsmPrinter &printer, StringRef foo) { 136 printer << foo; 137 } 138 //===----------------------------------------------------------------------===// 139 // Tablegen Generated Definitions 140 //===----------------------------------------------------------------------===// 141 142 #include "TestTypeInterfaces.cpp.inc" 143 #define GET_TYPEDEF_CLASSES 144 #include "TestTypeDefs.cpp.inc" 145 146 
//===----------------------------------------------------------------------===// 147 // CompoundAType 148 //===----------------------------------------------------------------------===// 149 150 Type CompoundAType::parse(AsmParser &parser) { 151 int widthOfSomething; 152 Type oneType; 153 SmallVector<int, 4> arrayOfInts; 154 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 155 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 156 parser.parseLSquare()) 157 return Type(); 158 159 int i; 160 while (!*parser.parseOptionalInteger(i)) { 161 arrayOfInts.push_back(i); 162 if (parser.parseOptionalComma()) 163 break; 164 } 165 166 if (parser.parseRSquare() || parser.parseGreater()) 167 return Type(); 168 169 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 170 } 171 void CompoundAType::print(AsmPrinter &printer) const { 172 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 173 auto intArray = getArrayOfInts(); 174 llvm::interleaveComma(intArray, printer); 175 printer << "]>"; 176 } 177 178 //===----------------------------------------------------------------------===// 179 // TestIntegerType 180 //===----------------------------------------------------------------------===// 181 182 // Example type validity checker. 
183 LogicalResult 184 TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 185 unsigned width, 186 TestIntegerType::SignednessSemantics ss) { 187 if (width > 8) 188 return failure(); 189 return success(); 190 } 191 192 Type TestIntegerType::parse(AsmParser &parser) { 193 SignednessSemantics signedness; 194 int width; 195 if (parser.parseLess() || parseSignedness(parser, signedness) || 196 parser.parseComma() || parser.parseInteger(width) || 197 parser.parseGreater()) 198 return Type(); 199 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); 200 return getChecked(loc, loc.getContext(), width, signedness); 201 } 202 203 void TestIntegerType::print(AsmPrinter &p) const { 204 p << "<"; 205 printSignedness(p, getSignedness()); 206 p << ", " << getWidth() << ">"; 207 } 208 209 //===----------------------------------------------------------------------===// 210 // TestStructType 211 //===----------------------------------------------------------------------===// 212 213 Type StructType::parse(AsmParser &p) { 214 SmallVector<FieldInfo, 4> parameters; 215 if (p.parseLess()) 216 return Type(); 217 while (succeeded(p.parseOptionalLBrace())) { 218 Type type; 219 StringRef name; 220 if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) || 221 p.parseRBrace()) 222 return Type(); 223 parameters.push_back(FieldInfo{name, type}); 224 if (p.parseOptionalComma()) 225 break; 226 } 227 if (p.parseGreater()) 228 return Type(); 229 return get(p.getContext(), parameters); 230 } 231 232 void StructType::print(AsmPrinter &p) const { 233 p << "<"; 234 llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) { 235 p << "{" << field.name << "," << field.type << "}"; 236 }); 237 p << ">"; 238 } 239 240 //===----------------------------------------------------------------------===// 241 // TestType 242 //===----------------------------------------------------------------------===// 243 244 void TestType::printTypeC(Location loc) const { 245 
emitRemark(loc) << *this << " - TestC"; 246 } 247 248 //===----------------------------------------------------------------------===// 249 // TestTypeWithLayout 250 //===----------------------------------------------------------------------===// 251 252 Type TestTypeWithLayoutType::parse(AsmParser &parser) { 253 unsigned val; 254 if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater()) 255 return Type(); 256 return TestTypeWithLayoutType::get(parser.getContext(), val); 257 } 258 259 void TestTypeWithLayoutType::print(AsmPrinter &printer) const { 260 printer << "<" << getKey() << ">"; 261 } 262 263 llvm::TypeSize 264 TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout, 265 DataLayoutEntryListRef params) const { 266 return llvm::TypeSize::getFixed(extractKind(params, "size")); 267 } 268 269 uint64_t 270 TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout, 271 DataLayoutEntryListRef params) const { 272 return extractKind(params, "alignment"); 273 } 274 275 uint64_t TestTypeWithLayoutType::getPreferredAlignment( 276 const DataLayout &dataLayout, DataLayoutEntryListRef params) const { 277 return extractKind(params, "preferred"); 278 } 279 280 std::optional<uint64_t> 281 TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout, 282 DataLayoutEntryListRef params) const { 283 return extractKind(params, "index"); 284 } 285 286 bool TestTypeWithLayoutType::areCompatible( 287 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { 288 unsigned old = extractKind(oldLayout, "alignment"); 289 return old == 1 || extractKind(newLayout, "alignment") <= old; 290 } 291 292 LogicalResult 293 TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, 294 Location loc) const { 295 for (DataLayoutEntryInterface entry : params) { 296 // This is for testing purposes only, so assert well-formedness. 
297 assert(entry.isTypeEntry() && "unexpected identifier entry"); 298 assert( 299 llvm::isa<TestTypeWithLayoutType>(llvm::cast<Type>(entry.getKey())) && 300 "wrong type passed in"); 301 auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue()); 302 assert(array && array.getValue().size() == 2 && 303 "expected array of two elements"); 304 auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front()); 305 (void)kind; 306 assert(kind && 307 (kind.getValue() == "size" || kind.getValue() == "alignment" || 308 kind.getValue() == "preferred" || kind.getValue() == "index") && 309 "unexpected kind"); 310 assert(llvm::isa<IntegerAttr>(array.getValue().back())); 311 } 312 return success(); 313 } 314 315 uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, 316 StringRef expectedKind) const { 317 for (DataLayoutEntryInterface entry : params) { 318 ArrayRef<Attribute> pair = 319 llvm::cast<ArrayAttr>(entry.getValue()).getValue(); 320 StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue(); 321 if (kind == expectedKind) 322 return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue(); 323 } 324 return 1; 325 } 326 327 //===----------------------------------------------------------------------===// 328 // Dynamic Types 329 //===----------------------------------------------------------------------===// 330 331 /// Define a singleton dynamic type. 332 static std::unique_ptr<DynamicTypeDefinition> 333 getSingletonDynamicType(TestDialect *testDialect) { 334 return DynamicTypeDefinition::get( 335 "dynamic_singleton", testDialect, 336 [](function_ref<InFlightDiagnostic()> emitError, 337 ArrayRef<Attribute> args) { 338 if (!args.empty()) { 339 emitError() << "expected 0 type arguments, but had " << args.size(); 340 return failure(); 341 } 342 return success(); 343 }); 344 } 345 346 /// Define a dynamic type representing a pair. 
static std::unique_ptr<DynamicTypeDefinition>
getPairDynamicType(TestDialect *testDialect) {
  return DynamicTypeDefinition::get(
      "dynamic_pair", testDialect,
      [](function_ref<InFlightDiagnostic()> emitError,
         ArrayRef<Attribute> args) {
        if (args.size() != 2) {
          emitError() << "expected 2 type arguments, but had " << args.size();
          return failure();
        }
        return success();
      });
}

/// Define a two-parameter dynamic type with the custom assembly format
/// `<` attr `:` attr `>`.
static std::unique_ptr<DynamicTypeDefinition>
getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
  // Exactly two parameters are required.
  auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
                     ArrayRef<Attribute> args) {
    if (args.size() != 2) {
      emitError() << "expected 2 type arguments, but had " << args.size();
      return failure();
    }
    return success();
  };

  // Named parseFn/printFn so that the lambda parameters `parser`/`printer` do
  // not shadow the enclosing variables.
  auto parseFn = [](AsmParser &parser,
                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
    Attribute leftAttr, rightAttr;
    if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
        parser.parseColon() || parser.parseAttribute(rightAttr) ||
        parser.parseGreater())
      return failure();
    parsedParams.push_back(leftAttr);
    parsedParams.push_back(rightAttr);
    return success();
  };

  // Indexing params[0] and params[1] is safe because the verifier enforces
  // exactly two parameters.
  auto printFn = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
    printer << "<" << params[0] << ":" << params[1] << ">";
  };

  return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
                                    testDialect, std::move(verifier),
                                    std::move(parseFn), std::move(printFn));
}

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

namespace {

/// External model attaching LLVM::PointerElementTypeInterface to SimpleAType.
/// It provides no overrides of its own.
struct PtrElementModel
    : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
                                                              SimpleAType> {};
} // namespace

/// Registers TestRecursiveType and the tablegen-generated types, attaches
/// PtrElementModel to SimpleAType, and registers the dynamic types.
void TestDialect::registerTypes() {
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  SimpleAType::attachInterface<PtrElementModel>(*getContext());

  registerDynamicType(getSingletonDynamicType(this));
  registerDynamicType(getPairDynamicType(this));
  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}

/// Parses a test dialect type. The tablegen-generated parser is tried first,
/// then the registered dynamic types, and finally the hand-written
/// `test_rec<name[, body]>` recursive type.
Type TestDialect::parseType(DialectAsmParser &parser) const {
  StringRef typeTag;
  {
    Type genType;
    auto parseResult = generatedTypeParser(parser, &typeTag, genType);
    if (parseResult.has_value())
      return genType;
  }

  {
    Type dynType;
    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
    if (parseResult.has_value()) {
      if (succeeded(parseResult.value()))
        return dynType;
      return Type();
    }
  }

  if (typeTag != "test_rec") {
    parser.emitError(parser.getNameLoc()) << "unknown type!";
    return Type();
  }

  StringRef name;
  if (parser.parseLess() || parser.parseKeyword(&name))
    return Type();
  auto rec = TestRecursiveType::get(parser.getContext(), name);

  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
      parser.tryStartCyclicParse(rec);

  // If this type already has been parsed above in the stack, expect just the
  // name.
  if (failed(cyclicParse)) {
    if (failed(parser.parseGreater()))
      return Type();
    return rec;
  }

  // Otherwise, parse the body and update the type.
  if (failed(parser.parseComma()))
    return Type();
  Type subtype = parseType(parser);
  if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
    return Type();

  return rec;
}

/// Prints a test dialect type. The generated printer is tried first, then the
/// dynamic-type printer; anything else must be a TestRecursiveType. Its body
/// is printed only on the outermost occurrence, to break cycles.
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
  if (succeeded(generatedTypePrinter(type, printer)))
    return;

  if (succeeded(printIfDynamicType(type, printer)))
    return;

  auto rec = llvm::cast<TestRecursiveType>(type);

  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
      printer.tryStartCyclicPrint(rec);

  printer << "test_rec<" << rec.getName();
  if (succeeded(cyclicPrint)) {
    printer << ", ";
    printType(rec.getBody(), printer);
  }
  printer << ">";
}

Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }

// NOTE(review): the LogicalResult from mutate() is discarded, so a rejected
// mutation is not reported to callers. Confirm that the storage's mutate
// cannot fail in practice.
void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }

StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }

/// Parses `<name>` for a back-reference to a type already being parsed, or
/// `<name, body-type>` otherwise, and then sets the body.
Type TestRecursiveAliasType::parse(AsmParser &parser) {
  StringRef name;
  if (parser.parseLess() || parser.parseKeyword(&name))
    return Type();
  auto rec = TestRecursiveAliasType::get(parser.getContext(), name);

  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
      parser.tryStartCyclicParse(rec);

  // If this type already has been parsed above in the stack, expect just the
  // name.
  if (failed(cyclicParse)) {
    if (failed(parser.parseGreater()))
      return Type();
    return rec;
  }

  // Otherwise, parse the body and update the type. Every failure path returns
  // a null Type().
  Type subtype;
  if (failed(parser.parseComma()) || failed(parser.parseType(subtype)) ||
      !subtype || failed(parser.parseGreater()))
    return Type();

  rec.setBody(subtype);

  return rec;
}

/// Prints `<name, body>`, or just `<name>` when this type is already being
/// printed further up the stack (cycle break).
void TestRecursiveAliasType::print(AsmPrinter &printer) const {
  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
      printer.tryStartCyclicPrint(*this);

  printer << "<" << getName();
  if (succeeded(cyclicPrint)) {
    printer << ", ";
    printer << getBody();
  }
  printer << ">";
}

/// Suggests "op_asm_type_interface" as the asm name for values of this type.
void TestTypeOpAsmTypeInterfaceType::getAsmName(
    OpAsmSetNameFn setNameFn) const {
  setNameFn("op_asm_type_interface");
}