1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/BuiltinDialect.h" 21 #include "mlir/IR/BuiltinTypeInterfaces.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/Dialect.h" 24 #include "mlir/IR/DialectImplementation.h" 25 #include "mlir/IR/DialectResourceBlobManager.h" 26 #include "mlir/IR/IntegerSet.h" 27 #include "mlir/IR/MLIRContext.h" 28 #include "mlir/IR/OpImplementation.h" 29 #include "mlir/IR/Operation.h" 30 #include "mlir/IR/Verifier.h" 31 #include "llvm/ADT/APFloat.h" 32 #include "llvm/ADT/ArrayRef.h" 33 #include "llvm/ADT/DenseMap.h" 34 #include "llvm/ADT/MapVector.h" 35 #include "llvm/ADT/STLExtras.h" 36 #include "llvm/ADT/ScopeExit.h" 37 #include "llvm/ADT/ScopedHashTable.h" 38 #include "llvm/ADT/SetVector.h" 39 #include "llvm/ADT/SmallString.h" 40 #include "llvm/ADT/StringExtras.h" 41 #include "llvm/ADT/StringSet.h" 42 #include "llvm/ADT/TypeSwitch.h" 43 #include "llvm/Support/CommandLine.h" 44 #include "llvm/Support/Debug.h" 45 #include "llvm/Support/Endian.h" 46 #include "llvm/Support/ManagedStatic.h" 47 #include "llvm/Support/Regex.h" 48 #include "llvm/Support/SaveAndRestore.h" 49 #include "llvm/Support/Threading.h" 50 #include 
"llvm/Support/raw_ostream.h" 51 #include <type_traits> 52 53 #include <optional> 54 #include <tuple> 55 56 using namespace mlir; 57 using namespace mlir::detail; 58 59 #define DEBUG_TYPE "mlir-asm-printer" 60 61 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 62 63 void OperationName::dump() const { print(llvm::errs()); } 64 65 //===--------------------------------------------------------------------===// 66 // AsmParser 67 //===--------------------------------------------------------------------===// 68 69 AsmParser::~AsmParser() = default; 70 DialectAsmParser::~DialectAsmParser() = default; 71 OpAsmParser::~OpAsmParser() = default; 72 73 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } 74 75 /// Parse a type list. 76 /// This is out-of-line to work-around 77 /// https://github.com/llvm/llvm-project/issues/62918 78 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) { 79 return parseCommaSeparatedList( 80 [&]() { return parseType(result.emplace_back()); }); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // DialectAsmPrinter 85 //===----------------------------------------------------------------------===// 86 87 DialectAsmPrinter::~DialectAsmPrinter() = default; 88 89 //===----------------------------------------------------------------------===// 90 // OpAsmPrinter 91 //===----------------------------------------------------------------------===// 92 93 OpAsmPrinter::~OpAsmPrinter() = default; 94 95 void OpAsmPrinter::printFunctionalType(Operation *op) { 96 auto &os = getStream(); 97 os << '('; 98 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 99 // Print the types of null values as <<NULL TYPE>>. 100 *this << (operand ? operand.getType() : Type()); 101 }); 102 os << ") -> "; 103 104 // Print the result list. We don't parenthesize single result types unless 105 // it is a function (avoiding a grammar ambiguity). 
106 bool wrapped = op->getNumResults() != 1; 107 if (!wrapped && op->getResult(0).getType() && 108 llvm::isa<FunctionType>(op->getResult(0).getType())) 109 wrapped = true; 110 111 if (wrapped) 112 os << '('; 113 114 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 115 // Print the types of null values as <<NULL TYPE>>. 116 *this << (result ? result.getType() : Type()); 117 }); 118 119 if (wrapped) 120 os << ')'; 121 } 122 123 //===----------------------------------------------------------------------===// 124 // Operation OpAsm interface. 125 //===----------------------------------------------------------------------===// 126 127 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 128 #include "mlir/IR/OpAsmOpInterface.cpp.inc" 129 #include "mlir/IR/OpAsmTypeInterface.cpp.inc" 130 131 LogicalResult 132 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { 133 return entry.emitError() << "unknown 'resource' key '" << entry.getKey() 134 << "' for dialect '" << getDialect()->getNamespace() 135 << "'"; 136 } 137 138 //===----------------------------------------------------------------------===// 139 // OpPrintingFlags 140 //===----------------------------------------------------------------------===// 141 142 namespace { 143 /// This struct contains command line options that can be used to initialize 144 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 145 /// for global command line options. 
146 struct AsmPrinterOptions { 147 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 148 "mlir-print-elementsattrs-with-hex-if-larger", 149 llvm::cl::desc( 150 "Print DenseElementsAttrs with a hex string that have " 151 "more elements than the given upper limit (use -1 to disable)")}; 152 153 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 154 "mlir-elide-elementsattrs-if-larger", 155 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 156 "more elements than the given upper limit")}; 157 158 llvm::cl::opt<unsigned> elideResourceStringsIfLarger{ 159 "mlir-elide-resource-strings-if-larger", 160 llvm::cl::desc( 161 "Elide printing value of resources if string is too long in chars.")}; 162 163 llvm::cl::opt<bool> printDebugInfoOpt{ 164 "mlir-print-debuginfo", llvm::cl::init(false), 165 llvm::cl::desc("Print debug info in MLIR output")}; 166 167 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 168 "mlir-pretty-debuginfo", llvm::cl::init(false), 169 llvm::cl::desc("Print pretty debug info in MLIR output")}; 170 171 // Use the generic op output form in the operation printer even if the custom 172 // form is defined. 
173 llvm::cl::opt<bool> printGenericOpFormOpt{ 174 "mlir-print-op-generic", llvm::cl::init(false), 175 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 176 177 llvm::cl::opt<bool> assumeVerifiedOpt{ 178 "mlir-print-assume-verified", llvm::cl::init(false), 179 llvm::cl::desc("Skip op verification when using custom printers"), 180 llvm::cl::Hidden}; 181 182 llvm::cl::opt<bool> printLocalScopeOpt{ 183 "mlir-print-local-scope", llvm::cl::init(false), 184 llvm::cl::desc("Print with local scope and inline information (eliding " 185 "aliases for attributes, types, and locations")}; 186 187 llvm::cl::opt<bool> skipRegionsOpt{ 188 "mlir-print-skip-regions", llvm::cl::init(false), 189 llvm::cl::desc("Skip regions when printing ops.")}; 190 191 llvm::cl::opt<bool> printValueUsers{ 192 "mlir-print-value-users", llvm::cl::init(false), 193 llvm::cl::desc( 194 "Print users of operation results and block arguments as a comment")}; 195 196 llvm::cl::opt<bool> printUniqueSSAIDs{ 197 "mlir-print-unique-ssa-ids", llvm::cl::init(false), 198 llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " 199 "and naming conflicts across all regions")}; 200 201 llvm::cl::opt<bool> useNameLocAsPrefix{ 202 "mlir-use-nameloc-as-prefix", llvm::cl::init(false), 203 llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")}; 204 }; 205 } // namespace 206 207 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 208 209 /// Register a set of useful command-line options that can be used to configure 210 /// various flags within the AsmPrinter. 211 void mlir::registerAsmPrinterCLOptions() { 212 // Make sure that the options struct has been initialized. 213 *clOptions; 214 } 215 216 /// Initialize the printing flags with default supplied by the cl::opts above. 
217 OpPrintingFlags::OpPrintingFlags() 218 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 219 printGenericOpFormFlag(false), skipRegionsFlag(false), 220 assumeVerifiedFlag(false), printLocalScope(false), 221 printValueUsersFlag(false), printUniqueSSAIDsFlag(false), 222 useNameLocAsPrefix(false) { 223 // Initialize based upon command line options, if they are available. 224 if (!clOptions.isConstructed()) 225 return; 226 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 227 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 228 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) 229 elementsAttrHexElementLimit = 230 clOptions->printElementsAttrWithHexIfLarger.getValue(); 231 if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) 232 resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; 233 printDebugInfoFlag = clOptions->printDebugInfoOpt; 234 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 235 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 236 assumeVerifiedFlag = clOptions->assumeVerifiedOpt; 237 printLocalScope = clOptions->printLocalScopeOpt; 238 skipRegionsFlag = clOptions->skipRegionsOpt; 239 printValueUsersFlag = clOptions->printValueUsers; 240 printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; 241 useNameLocAsPrefix = clOptions->useNameLocAsPrefix; 242 } 243 244 /// Enable the elision of large elements attributes, by printing a '...' 245 /// instead of the element data, when the number of elements is greater than 246 /// `largeElementLimit`. Note: The IR generated with this option is not 247 /// parsable. 
248 OpPrintingFlags & 249 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 250 elementsAttrElementLimit = largeElementLimit; 251 return *this; 252 } 253 254 OpPrintingFlags & 255 OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { 256 elementsAttrHexElementLimit = largeElementLimit; 257 return *this; 258 } 259 260 OpPrintingFlags & 261 OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { 262 resourceStringCharLimit = largeResourceLimit; 263 return *this; 264 } 265 266 /// Enable printing of debug information. If 'prettyForm' is set to true, 267 /// debug information is printed in a more readable 'pretty' form. 268 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, 269 bool prettyForm) { 270 printDebugInfoFlag = enable; 271 printDebugInfoPrettyFormFlag = prettyForm; 272 return *this; 273 } 274 275 /// Always print operations in the generic form. 276 OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { 277 printGenericOpFormFlag = enable; 278 return *this; 279 } 280 281 /// Always skip Regions. 282 OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { 283 skipRegionsFlag = skip; 284 return *this; 285 } 286 287 /// Do not verify the operation when using custom operation printers. 288 OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { 289 assumeVerifiedFlag = enable; 290 return *this; 291 } 292 293 /// Use local scope when printing the operation. This allows for using the 294 /// printer in a more localized and thread-safe setting, but may not necessarily 295 /// be identical of what the IR will look like when dumping the full module. 296 OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { 297 printLocalScope = enable; 298 return *this; 299 } 300 301 /// Print users of values as comments. 
302 OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { 303 printValueUsersFlag = enable; 304 return *this; 305 } 306 307 /// Print unique SSA ID numbers for values, block arguments and naming conflicts 308 /// across all regions 309 OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { 310 printUniqueSSAIDsFlag = enable; 311 return *this; 312 } 313 314 /// Return if the given ElementsAttr should be elided. 315 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 316 return elementsAttrElementLimit && 317 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 318 !llvm::isa<SplatElementsAttr>(attr); 319 } 320 321 /// Return if the given ElementsAttr should be printed as hex string. 322 bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { 323 // -1 is used to disable hex printing. 324 return (elementsAttrHexElementLimit != -1) && 325 (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && 326 !llvm::isa<SplatElementsAttr>(attr); 327 } 328 329 /// Return the size limit for printing large ElementsAttr. 330 std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 331 return elementsAttrElementLimit; 332 } 333 334 /// Return the size limit for printing large ElementsAttr as hex string. 335 int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { 336 return elementsAttrHexElementLimit; 337 } 338 339 /// Return the size limit for printing large ElementsAttr. 340 std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const { 341 return resourceStringCharLimit; 342 } 343 344 /// Return if debug information should be printed. 345 bool OpPrintingFlags::shouldPrintDebugInfo() const { 346 return printDebugInfoFlag; 347 } 348 349 /// Return if debug information should be printed in the pretty form. 
350 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 351 return printDebugInfoPrettyFormFlag; 352 } 353 354 /// Return if operations should be printed in the generic form. 355 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 356 return printGenericOpFormFlag; 357 } 358 359 /// Return if Region should be skipped. 360 bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } 361 362 /// Return if operation verification should be skipped. 363 bool OpPrintingFlags::shouldAssumeVerified() const { 364 return assumeVerifiedFlag; 365 } 366 367 /// Return if the printer should use local scope when dumping the IR. 368 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 369 370 /// Return if the printer should print users of values. 371 bool OpPrintingFlags::shouldPrintValueUsers() const { 372 return printValueUsersFlag; 373 } 374 375 /// Return if the printer should use unique IDs. 376 bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { 377 return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); 378 } 379 380 /// Return if the printer should use NameLocs as prefixes when printing SSA IDs. 381 bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { 382 return useNameLocAsPrefix; 383 } 384 385 //===----------------------------------------------------------------------===// 386 // NewLineCounter 387 //===----------------------------------------------------------------------===// 388 389 namespace { 390 /// This class is a simple formatter that emits a new line when inputted into a 391 /// stream, that enables counting the number of newlines emitted. This class 392 /// should be used whenever emitting newlines in the printer. 
393 struct NewLineCounter { 394 unsigned curLine = 1; 395 }; 396 397 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 398 ++newLine.curLine; 399 return os << '\n'; 400 } 401 } // namespace 402 403 //===----------------------------------------------------------------------===// 404 // AsmPrinter::Impl 405 //===----------------------------------------------------------------------===// 406 407 namespace mlir { 408 class AsmPrinter::Impl { 409 public: 410 Impl(raw_ostream &os, AsmStateImpl &state); 411 explicit Impl(Impl &other) : Impl(other.os, other.state) {} 412 413 /// Returns the output stream of the printer. 414 raw_ostream &getStream() { return os; } 415 416 template <typename Container, typename UnaryFunctor> 417 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { 418 llvm::interleaveComma(c, os, eachFn); 419 } 420 421 /// This enum describes the different kinds of elision for the type of an 422 /// attribute when printing it. 423 enum class AttrTypeElision { 424 /// The type must not be elided, 425 Never, 426 /// The type may be elided when it matches the default used in the parser 427 /// (for example i64 is the default for integer attributes). 428 May, 429 /// The type must be elided. 430 Must 431 }; 432 433 /// Print the given attribute or an alias. 434 void printAttribute(Attribute attr, 435 AttrTypeElision typeElision = AttrTypeElision::Never); 436 /// Print the given attribute without considering an alias. 437 void printAttributeImpl(Attribute attr, 438 AttrTypeElision typeElision = AttrTypeElision::Never); 439 440 /// Print the alias for the given attribute, return failure if no alias could 441 /// be printed. 442 LogicalResult printAlias(Attribute attr); 443 444 /// Print the given type or an alias. 445 void printType(Type type); 446 /// Print the given type. 447 void printTypeImpl(Type type); 448 449 /// Print the alias for the given type, return failure if no alias could 450 /// be printed. 
451 LogicalResult printAlias(Type type); 452 453 /// Print the given location to the stream. If `allowAlias` is true, this 454 /// allows for the internal location to use an attribute alias. 455 void printLocation(LocationAttr loc, bool allowAlias = false); 456 457 /// Print a reference to the given resource that is owned by the given 458 /// dialect. 459 void printResourceHandle(const AsmDialectResourceHandle &resource); 460 461 void printAffineMap(AffineMap map); 462 void 463 printAffineExpr(AffineExpr expr, 464 function_ref<void(unsigned, bool)> printValueName = nullptr); 465 void printAffineConstraint(AffineExpr expr, bool isEq); 466 void printIntegerSet(IntegerSet set); 467 468 LogicalResult pushCyclicPrinting(const void *opaquePointer); 469 470 void popCyclicPrinting(); 471 472 void printDimensionList(ArrayRef<int64_t> shape); 473 474 protected: 475 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 476 ArrayRef<StringRef> elidedAttrs = {}, 477 bool withKeyword = false); 478 void printNamedAttribute(NamedAttribute attr); 479 void printTrailingLocation(Location loc, bool allowAlias = true); 480 void printLocationInternal(LocationAttr loc, bool pretty = false, 481 bool isTopLevel = false); 482 483 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 484 /// used instead of individual elements when the elements attr is large. 485 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 486 487 /// Print a dense string elements attribute. 488 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 489 490 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 491 /// used instead of individual elements when the elements attr is large. 492 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 493 bool allowHex); 494 495 /// Print a dense array attribute. 
496 void printDenseArrayAttr(DenseArrayAttr attr); 497 498 void printDialectAttribute(Attribute attr); 499 void printDialectType(Type type); 500 501 /// Print an escaped string, wrapped with "". 502 void printEscapedString(StringRef str); 503 504 /// Print a hex string, wrapped with "". 505 void printHexString(StringRef str); 506 void printHexString(ArrayRef<char> data); 507 508 /// This enum is used to represent the binding strength of the enclosing 509 /// context that an AffineExprStorage is being printed in, so we can 510 /// intelligently produce parens. 511 enum class BindingStrength { 512 Weak, // + and - 513 Strong, // All other binary operators. 514 }; 515 void printAffineExprInternal( 516 AffineExpr expr, BindingStrength enclosingTightness, 517 function_ref<void(unsigned, bool)> printValueName = nullptr); 518 519 /// The output stream for the printer. 520 raw_ostream &os; 521 522 /// An underlying assembly printer state. 523 AsmStateImpl &state; 524 525 /// A set of flags to control the printer's behavior. 526 OpPrintingFlags printerFlags; 527 528 /// A tracker for the number of new lines emitted during printing. 529 NewLineCounter newLine; 530 }; 531 } // namespace mlir 532 533 //===----------------------------------------------------------------------===// 534 // AliasInitializer 535 //===----------------------------------------------------------------------===// 536 537 namespace { 538 /// This class represents a specific instance of a symbol Alias. 539 class SymbolAlias { 540 public: 541 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, 542 bool isDeferrable) 543 : name(name), suffixIndex(suffixIndex), isType(isType), 544 isDeferrable(isDeferrable) {} 545 546 /// Print this alias to the given stream. 547 void print(raw_ostream &os) const { 548 os << (isType ? "!" : "#") << name; 549 if (suffixIndex) 550 os << suffixIndex; 551 } 552 553 /// Returns true if this is a type alias. 
554 bool isTypeAlias() const { return isType; } 555 556 /// Returns true if this alias supports deferred resolution when parsing. 557 bool canBeDeferred() const { return isDeferrable; } 558 559 private: 560 /// The main name of the alias. 561 StringRef name; 562 /// The suffix index of the alias. 563 uint32_t suffixIndex : 30; 564 /// A flag indicating whether this alias is for a type. 565 bool isType : 1; 566 /// A flag indicating whether this alias may be deferred or not. 567 bool isDeferrable : 1; 568 569 public: 570 /// Used to avoid printing incomplete aliases for recursive types. 571 bool isPrinted = false; 572 }; 573 574 /// This class represents a utility that initializes the set of attribute and 575 /// type aliases, without the need to store the extra information within the 576 /// main AliasState class or pass it around via function arguments. 577 class AliasInitializer { 578 public: 579 AliasInitializer( 580 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, 581 llvm::BumpPtrAllocator &aliasAllocator) 582 : interfaces(interfaces), aliasAllocator(aliasAllocator), 583 aliasOS(aliasBuffer) {} 584 585 void initialize(Operation *op, const OpPrintingFlags &printerFlags, 586 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias); 587 588 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is 589 /// set to true if the originator of this attribute can resolve the alias 590 /// after parsing has completed (e.g. in the case of operation locations). 591 /// `elideType` indicates if the type of the attribute should be skipped when 592 /// looking for nested aliases. Returns the maximum alias depth of the 593 /// attribute, and the alias index of this attribute. 594 std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false, 595 bool elideType = false) { 596 return visitImpl(attr, aliases, canBeDeferred, elideType); 597 } 598 599 /// Visit the given type to see if it has an alias. 
`canBeDeferred` is 600 /// set to true if the originator of this attribute can resolve the alias 601 /// after parsing has completed. Returns the maximum alias depth of the type, 602 /// and the alias index of this type. 603 std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) { 604 return visitImpl(type, aliases, canBeDeferred); 605 } 606 607 private: 608 struct InProgressAliasInfo { 609 InProgressAliasInfo() 610 : aliasDepth(0), isType(false), canBeDeferred(false) {} 611 InProgressAliasInfo(StringRef alias) 612 : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {} 613 614 bool operator<(const InProgressAliasInfo &rhs) const { 615 // Order first by depth, then by attr/type kind, and then by name. 616 if (aliasDepth != rhs.aliasDepth) 617 return aliasDepth < rhs.aliasDepth; 618 if (isType != rhs.isType) 619 return isType; 620 return alias < rhs.alias; 621 } 622 623 /// The alias for the attribute or type, or std::nullopt if the value has no 624 /// alias. 625 std::optional<StringRef> alias; 626 /// The alias depth of this attribute or type, i.e. an indication of the 627 /// relative ordering of when to print this alias. 628 unsigned aliasDepth : 30; 629 /// If this alias represents a type or an attribute. 630 bool isType : 1; 631 /// If this alias can be deferred or not. 632 bool canBeDeferred : 1; 633 /// Indices for child aliases. 634 SmallVector<size_t> childIndices; 635 }; 636 637 /// Visit the given attribute or type to see if it has an alias. 638 /// `canBeDeferred` is set to true if the originator of this value can resolve 639 /// the alias after parsing has completed (e.g. in the case of operation 640 /// locations). Returns the maximum alias depth of the value, and its alias 641 /// index. 642 template <typename T, typename... 
PrintArgs> 643 std::pair<size_t, size_t> 644 visitImpl(T value, 645 llvm::MapVector<const void *, InProgressAliasInfo> &aliases, 646 bool canBeDeferred, PrintArgs &&...printArgs); 647 648 /// Mark the given alias as non-deferrable. 649 void markAliasNonDeferrable(size_t aliasIndex); 650 651 /// Try to generate an alias for the provided symbol. If an alias is 652 /// generated, the provided alias mapping and reverse mapping are updated. 653 template <typename T> 654 void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); 655 656 /// Given a collection of aliases and symbols, initialize a mapping from a 657 /// symbol to a given alias. 658 static void initializeAliases( 659 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, 660 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias); 661 662 /// The set of asm interfaces within the context. 663 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; 664 665 /// An allocator used for alias names. 666 llvm::BumpPtrAllocator &aliasAllocator; 667 668 /// The set of built aliases. 669 llvm::MapVector<const void *, InProgressAliasInfo> aliases; 670 671 /// Storage and stream used when generating an alias. 672 SmallString<32> aliasBuffer; 673 llvm::raw_svector_ostream aliasOS; 674 }; 675 676 /// This class implements a dummy OpAsmPrinter that doesn't print any output, 677 /// and merely collects the attributes and types that *would* be printed in a 678 /// normal print invocation so that we can generate proper aliases. This allows 679 /// for us to generate aliases only for the attributes and types that would be 680 /// in the output, and trims down unnecessary output. 
681 class DummyAliasOperationPrinter : private OpAsmPrinter { 682 public: 683 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, 684 AliasInitializer &initializer) 685 : printerFlags(printerFlags), initializer(initializer) {} 686 687 /// Prints the entire operation with the custom assembly form, if available, 688 /// or the generic assembly form, otherwise. 689 void printCustomOrGenericOp(Operation *op) override { 690 // Visit the operation location. 691 if (printerFlags.shouldPrintDebugInfo()) 692 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 693 694 // If requested, always print the generic form. 695 if (!printerFlags.shouldPrintGenericOpForm()) { 696 op->getName().printAssembly(op, *this, /*defaultDialect=*/""); 697 return; 698 } 699 700 // Otherwise print with the generic assembly form. 701 printGenericOp(op); 702 } 703 704 private: 705 /// Print the given operation in the generic form. 706 void printGenericOp(Operation *op, bool printOpName = true) override { 707 // Consider nested operations for aliases. 708 if (!printerFlags.shouldSkipRegions()) { 709 for (Region ®ion : op->getRegions()) 710 printRegion(region, /*printEntryBlockArgs=*/true, 711 /*printBlockTerminators=*/true); 712 } 713 714 // Visit all the types used in the operation. 715 for (Type type : op->getOperandTypes()) 716 printType(type); 717 for (Type type : op->getResultTypes()) 718 printType(type); 719 720 // Consider the attributes of the operation for aliases. 721 for (const NamedAttribute &attr : op->getAttrs()) 722 printAttribute(attr.getValue()); 723 } 724 725 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 726 /// block are not printed. If 'printBlockTerminator' is false, the terminator 727 /// operation of the block is not printed. 
728 void print(Block *block, bool printBlockArgs = true, 729 bool printBlockTerminator = true) { 730 // Consider the types of the block arguments for aliases if 'printBlockArgs' 731 // is set to true. 732 if (printBlockArgs) { 733 for (BlockArgument arg : block->getArguments()) { 734 printType(arg.getType()); 735 736 // Visit the argument location. 737 if (printerFlags.shouldPrintDebugInfo()) 738 // TODO: Allow deferring argument locations. 739 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 740 } 741 } 742 743 // Consider the operations within this block, ignoring the terminator if 744 // requested. 745 bool hasTerminator = 746 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 747 auto range = llvm::make_range( 748 block->begin(), 749 std::prev(block->end(), 750 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 751 for (Operation &op : range) 752 printCustomOrGenericOp(&op); 753 } 754 755 /// Print the given region. 756 void printRegion(Region ®ion, bool printEntryBlockArgs, 757 bool printBlockTerminators, 758 bool printEmptyBlock = false) override { 759 if (region.empty()) 760 return; 761 if (printerFlags.shouldSkipRegions()) { 762 os << "{...}"; 763 return; 764 } 765 766 auto *entryBlock = ®ion.front(); 767 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 768 for (Block &b : llvm::drop_begin(region, 1)) 769 print(&b); 770 } 771 772 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, 773 bool omitType) override { 774 printType(arg.getType()); 775 // Visit the argument location. 776 if (printerFlags.shouldPrintDebugInfo()) 777 // TODO: Allow deferring argument locations. 778 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 779 } 780 781 /// Consider the given type to be printed for an alias. 782 void printType(Type type) override { initializer.visit(type); } 783 784 /// Consider the given attribute to be printed for an alias. 
785 void printAttribute(Attribute attr) override { initializer.visit(attr); } 786 void printAttributeWithoutType(Attribute attr) override { 787 printAttribute(attr); 788 } 789 LogicalResult printAlias(Attribute attr) override { 790 initializer.visit(attr); 791 return success(); 792 } 793 LogicalResult printAlias(Type type) override { 794 initializer.visit(type); 795 return success(); 796 } 797 798 /// Consider the given location to be printed for an alias. 799 void printOptionalLocationSpecifier(Location loc) override { 800 printAttribute(loc); 801 } 802 803 /// Print the given set of attributes with names not included within 804 /// 'elidedAttrs'. 805 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 806 ArrayRef<StringRef> elidedAttrs = {}) override { 807 if (attrs.empty()) 808 return; 809 if (elidedAttrs.empty()) { 810 for (const NamedAttribute &attr : attrs) 811 printAttribute(attr.getValue()); 812 return; 813 } 814 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 815 elidedAttrs.end()); 816 for (const NamedAttribute &attr : attrs) 817 if (!elidedAttrsSet.contains(attr.getName().strref())) 818 printAttribute(attr.getValue()); 819 } 820 void printOptionalAttrDictWithKeyword( 821 ArrayRef<NamedAttribute> attrs, 822 ArrayRef<StringRef> elidedAttrs = {}) override { 823 printOptionalAttrDict(attrs, elidedAttrs); 824 } 825 826 /// Return a null stream as the output stream, this will ignore any data fed 827 /// to it. 828 raw_ostream &getStream() const override { return os; } 829 830 /// The following are hooks of `OpAsmPrinter` that are not necessary for 831 /// determining potential aliases. 
832 void printFloat(const APFloat &) override {} 833 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 834 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} 835 void printNewline() override {} 836 void increaseIndent() override {} 837 void decreaseIndent() override {} 838 void printOperand(Value) override {} 839 void printOperand(Value, raw_ostream &os) override { 840 // Users expect the output string to have at least the prefixed % to signal 841 // a value name. To maintain this invariant, emit a name even if it is 842 // guaranteed to go unused. 843 os << "%"; 844 } 845 void printKeywordOrString(StringRef) override {} 846 void printString(StringRef) override {} 847 void printResourceHandle(const AsmDialectResourceHandle &) override {} 848 void printSymbolName(StringRef) override {} 849 void printSuccessor(Block *) override {} 850 void printSuccessorAndUseList(Block *, ValueRange) override {} 851 void shadowRegionArgs(Region &, ValueRange) override {} 852 853 /// The printer flags to use when determining potential aliases. 854 const OpPrintingFlags &printerFlags; 855 856 /// The initializer to use when identifying aliases. 857 AliasInitializer &initializer; 858 859 /// A dummy output stream. 860 mutable llvm::raw_null_ostream os; 861 }; 862 863 class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { 864 public: 865 explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, 866 bool canBeDeferred, 867 SmallVectorImpl<size_t> &childIndices) 868 : initializer(initializer), canBeDeferred(canBeDeferred), 869 childIndices(childIndices) {} 870 871 /// Print the given attribute/type, visiting any nested aliases that would be 872 /// generated as part of printing. Returns the maximum alias depth found while 873 /// printing the given value. 874 template <typename T, typename... 
PrintArgs> 875 size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { 876 printAndVisitNestedAliasesImpl(value, printArgs...); 877 return maxAliasDepth; 878 } 879 880 private: 881 /// Print the given attribute/type, visiting any nested aliases that would be 882 /// generated as part of printing. 883 void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { 884 if (!isa<BuiltinDialect>(attr.getDialect())) { 885 attr.getDialect().printAttribute(attr, *this); 886 887 // Process the builtin attributes. 888 } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr, 889 IntegerSetAttr, UnitAttr>(attr)) { 890 return; 891 } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) { 892 printAttribute(distinctAttr.getReferencedAttr()); 893 } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) { 894 for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { 895 printAttribute(nestedAttr.getName()); 896 printAttribute(nestedAttr.getValue()); 897 } 898 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { 899 for (Attribute nestedAttr : arrayAttr.getValue()) 900 printAttribute(nestedAttr); 901 } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) { 902 printType(typeAttr.getValue()); 903 } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) { 904 printAttribute(locAttr.getFallbackLocation()); 905 } else if (auto locAttr = dyn_cast<NameLoc>(attr)) { 906 if (!isa<UnknownLoc>(locAttr.getChildLoc())) 907 printAttribute(locAttr.getChildLoc()); 908 } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) { 909 printAttribute(locAttr.getCallee()); 910 printAttribute(locAttr.getCaller()); 911 } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) { 912 if (Attribute metadata = locAttr.getMetadata()) 913 printAttribute(metadata); 914 for (Location nestedLoc : locAttr.getLocations()) 915 printAttribute(nestedLoc); 916 } 917 918 // Don't print the type if we must elide it, or if it is a None type. 
919 if (!elideType) { 920 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { 921 Type attrType = typedAttr.getType(); 922 if (!llvm::isa<NoneType>(attrType)) 923 printType(attrType); 924 } 925 } 926 } 927 void printAndVisitNestedAliasesImpl(Type type) { 928 if (!isa<BuiltinDialect>(type.getDialect())) 929 return type.getDialect().printType(type, *this); 930 931 // Only visit the layout of memref if it isn't the identity. 932 if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) { 933 printType(memrefTy.getElementType()); 934 MemRefLayoutAttrInterface layout = memrefTy.getLayout(); 935 if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) 936 printAttribute(memrefTy.getLayout()); 937 if (memrefTy.getMemorySpace()) 938 printAttribute(memrefTy.getMemorySpace()); 939 return; 940 } 941 942 // For most builtin types, we can simply walk the sub elements. 943 auto visitFn = [&](auto element) { 944 if (element) 945 (void)printAlias(element); 946 }; 947 type.walkImmediateSubElements(visitFn, visitFn); 948 } 949 950 /// Consider the given type to be printed for an alias. 951 void printType(Type type) override { 952 recordAliasResult(initializer.visit(type, canBeDeferred)); 953 } 954 955 /// Consider the given attribute to be printed for an alias. 956 void printAttribute(Attribute attr) override { 957 recordAliasResult(initializer.visit(attr, canBeDeferred)); 958 } 959 void printAttributeWithoutType(Attribute attr) override { 960 recordAliasResult( 961 initializer.visit(attr, canBeDeferred, /*elideType=*/true)); 962 } 963 LogicalResult printAlias(Attribute attr) override { 964 printAttribute(attr); 965 return success(); 966 } 967 LogicalResult printAlias(Type type) override { 968 printType(type); 969 return success(); 970 } 971 972 /// Record the alias result of a child element. 
973 void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) { 974 childIndices.push_back(aliasDepthAndIndex.second); 975 if (aliasDepthAndIndex.first > maxAliasDepth) 976 maxAliasDepth = aliasDepthAndIndex.first; 977 } 978 979 /// Return a null stream as the output stream, this will ignore any data fed 980 /// to it. 981 raw_ostream &getStream() const override { return os; } 982 983 /// The following are hooks of `DialectAsmPrinter` that are not necessary for 984 /// determining potential aliases. 985 void printFloat(const APFloat &) override {} 986 void printKeywordOrString(StringRef) override {} 987 void printString(StringRef) override {} 988 void printSymbolName(StringRef) override {} 989 void printResourceHandle(const AsmDialectResourceHandle &) override {} 990 991 LogicalResult pushCyclicPrinting(const void *opaquePointer) override { 992 return success(cyclicPrintingStack.insert(opaquePointer)); 993 } 994 995 void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } 996 997 /// Stack of potentially cyclic mutable attributes or type currently being 998 /// printed. 999 SetVector<const void *> cyclicPrintingStack; 1000 1001 /// The initializer to use when identifying aliases. 1002 AliasInitializer &initializer; 1003 1004 /// If the aliases visited by this printer can be deferred. 1005 bool canBeDeferred; 1006 1007 /// The indices of child aliases. 1008 SmallVectorImpl<size_t> &childIndices; 1009 1010 /// The maximum alias depth found by the printer. 1011 size_t maxAliasDepth = 0; 1012 1013 /// A dummy output stream. 1014 mutable llvm::raw_null_ostream os; 1015 }; 1016 } // namespace 1017 1018 /// Sanitize the given name such that it can be used as a valid identifier. 
If 1019 /// the string needs to be modified in any way, the provided buffer is used to 1020 /// store the new copy, 1021 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 1022 StringRef allowedPunctChars = "$._-", 1023 bool allowTrailingDigit = true) { 1024 assert(!name.empty() && "Shouldn't have an empty name here"); 1025 1026 auto validChar = [&](char ch) { 1027 return llvm::isAlnum(ch) || allowedPunctChars.contains(ch); 1028 }; 1029 1030 auto copyNameToBuffer = [&] { 1031 for (char ch : name) { 1032 if (validChar(ch)) 1033 buffer.push_back(ch); 1034 else if (ch == ' ') 1035 buffer.push_back('_'); 1036 else 1037 buffer.append(llvm::utohexstr((unsigned char)ch)); 1038 } 1039 }; 1040 1041 // Check to see if this name is valid. If it starts with a digit, then it 1042 // could conflict with the autogenerated numeric ID's, so add an underscore 1043 // prefix to avoid problems. 1044 if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { 1045 buffer.push_back('_'); 1046 copyNameToBuffer(); 1047 return buffer; 1048 } 1049 1050 // If the name ends with a trailing digit, add a '_' to avoid potential 1051 // conflicts with autogenerated ID's. 1052 if (!allowTrailingDigit && isdigit(name.back())) { 1053 copyNameToBuffer(); 1054 buffer.push_back('_'); 1055 return buffer; 1056 } 1057 1058 // Check to see that the name consists of only valid identifier characters. 1059 for (char ch : name) { 1060 if (!validChar(ch)) { 1061 copyNameToBuffer(); 1062 return buffer; 1063 } 1064 } 1065 1066 // If there are no invalid characters, return the original name. 1067 return name; 1068 } 1069 1070 /// Given a collection of aliases and symbols, initialize a mapping from a 1071 /// symbol to a given alias. 
1072 void AliasInitializer::initializeAliases( 1073 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, 1074 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) { 1075 SmallVector<std::pair<const void *, InProgressAliasInfo>, 0> 1076 unprocessedAliases = visitedSymbols.takeVector(); 1077 llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) { 1078 return lhs.second < rhs.second; 1079 }); 1080 1081 llvm::StringMap<unsigned> nameCounts; 1082 for (auto &[symbol, aliasInfo] : unprocessedAliases) { 1083 if (!aliasInfo.alias) 1084 continue; 1085 StringRef alias = *aliasInfo.alias; 1086 unsigned nameIndex = nameCounts[alias]++; 1087 symbolToAlias.insert( 1088 {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, 1089 aliasInfo.canBeDeferred)}); 1090 } 1091 } 1092 1093 void AliasInitializer::initialize( 1094 Operation *op, const OpPrintingFlags &printerFlags, 1095 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) { 1096 // Use a dummy printer when walking the IR so that we can collect the 1097 // attributes/types that will actually be used during printing when 1098 // considering aliases. 1099 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); 1100 aliasPrinter.printCustomOrGenericOp(op); 1101 1102 // Initialize the aliases. 1103 initializeAliases(aliases, attrTypeToAlias); 1104 } 1105 1106 template <typename T, typename... PrintArgs> 1107 std::pair<size_t, size_t> AliasInitializer::visitImpl( 1108 T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases, 1109 bool canBeDeferred, PrintArgs &&...printArgs) { 1110 auto [it, inserted] = 1111 aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()}); 1112 size_t aliasIndex = std::distance(aliases.begin(), it); 1113 if (!inserted) { 1114 // Make sure that the alias isn't deferred if we don't permit it. 
1115 if (!canBeDeferred) 1116 markAliasNonDeferrable(aliasIndex); 1117 return {static_cast<size_t>(it->second.aliasDepth), aliasIndex}; 1118 } 1119 1120 // Try to generate an alias for this value. 1121 generateAlias(value, it->second, canBeDeferred); 1122 it->second.isType = std::is_base_of_v<Type, T>; 1123 it->second.canBeDeferred = canBeDeferred; 1124 1125 // Print the value, capturing any nested elements that require aliases. 1126 SmallVector<size_t> childAliases; 1127 DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); 1128 size_t maxAliasDepth = 1129 printer.printAndVisitNestedAliases(value, printArgs...); 1130 1131 // Make sure to recompute `it` in case the map was reallocated. 1132 it = std::next(aliases.begin(), aliasIndex); 1133 1134 // If we had sub elements, update to account for the depth. 1135 it->second.childIndices = std::move(childAliases); 1136 if (maxAliasDepth) 1137 it->second.aliasDepth = maxAliasDepth + 1; 1138 1139 // Propagate the alias depth of the value. 1140 return {(size_t)it->second.aliasDepth, aliasIndex}; 1141 } 1142 1143 void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { 1144 auto *it = std::next(aliases.begin(), aliasIndex); 1145 1146 // If already marked non-deferrable stop the recursion. 1147 // All children should already be marked non-deferrable as well. 1148 if (!it->second.canBeDeferred) 1149 return; 1150 1151 it->second.canBeDeferred = false; 1152 1153 // Propagate the non-deferrable flag to any child aliases. 
1154 for (size_t childIndex : it->second.childIndices) 1155 markAliasNonDeferrable(childIndex); 1156 } 1157 1158 template <typename T> 1159 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, 1160 bool canBeDeferred) { 1161 SmallString<32> nameBuffer; 1162 for (const auto &interface : interfaces) { 1163 OpAsmDialectInterface::AliasResult result = 1164 interface.getAlias(symbol, aliasOS); 1165 if (result == OpAsmDialectInterface::AliasResult::NoAlias) 1166 continue; 1167 nameBuffer = std::move(aliasBuffer); 1168 assert(!nameBuffer.empty() && "expected valid alias name"); 1169 if (result == OpAsmDialectInterface::AliasResult::FinalAlias) 1170 break; 1171 } 1172 1173 if (nameBuffer.empty()) 1174 return; 1175 1176 SmallString<16> tempBuffer; 1177 StringRef name = 1178 sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", 1179 /*allowTrailingDigit=*/false); 1180 name = name.copy(aliasAllocator); 1181 alias = InProgressAliasInfo(name); 1182 } 1183 1184 //===----------------------------------------------------------------------===// 1185 // AliasState 1186 //===----------------------------------------------------------------------===// 1187 1188 namespace { 1189 /// This class manages the state for type and attribute aliases. 1190 class AliasState { 1191 public: 1192 // Initialize the internal aliases. 1193 void 1194 initialize(Operation *op, const OpPrintingFlags &printerFlags, 1195 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 1196 1197 /// Get an alias for the given attribute if it has one and print it in `os`. 1198 /// Returns success if an alias was printed, failure otherwise. 1199 LogicalResult getAlias(Attribute attr, raw_ostream &os) const; 1200 1201 /// Get an alias for the given type if it has one and print it in `os`. 1202 /// Returns success if an alias was printed, failure otherwise. 
1203 LogicalResult getAlias(Type ty, raw_ostream &os) const; 1204 1205 /// Print all of the referenced aliases that can not be resolved in a deferred 1206 /// manner. 1207 void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { 1208 printAliases(p, newLine, /*isDeferred=*/false); 1209 } 1210 1211 /// Print all of the referenced aliases that support deferred resolution. 1212 void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { 1213 printAliases(p, newLine, /*isDeferred=*/true); 1214 } 1215 1216 private: 1217 /// Print all of the referenced aliases that support the provided resolution 1218 /// behavior. 1219 void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, 1220 bool isDeferred); 1221 1222 /// Mapping between attribute/type and alias. 1223 llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias; 1224 1225 /// An allocator used for alias names. 1226 llvm::BumpPtrAllocator aliasAllocator; 1227 }; 1228 } // namespace 1229 1230 void AliasState::initialize( 1231 Operation *op, const OpPrintingFlags &printerFlags, 1232 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 1233 AliasInitializer initializer(interfaces, aliasAllocator); 1234 initializer.initialize(op, printerFlags, attrTypeToAlias); 1235 } 1236 1237 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { 1238 const auto *it = attrTypeToAlias.find(attr.getAsOpaquePointer()); 1239 if (it == attrTypeToAlias.end()) 1240 return failure(); 1241 it->second.print(os); 1242 return success(); 1243 } 1244 1245 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { 1246 const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); 1247 if (it == attrTypeToAlias.end()) 1248 return failure(); 1249 if (!it->second.isPrinted) 1250 return failure(); 1251 1252 it->second.print(os); 1253 return success(); 1254 } 1255 1256 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, 1257 bool isDeferred) { 
1258 auto filterFn = [=](const auto &aliasIt) { 1259 return aliasIt.second.canBeDeferred() == isDeferred; 1260 }; 1261 for (auto &[opaqueSymbol, alias] : 1262 llvm::make_filter_range(attrTypeToAlias, filterFn)) { 1263 alias.print(p.getStream()); 1264 p.getStream() << " = "; 1265 1266 if (alias.isTypeAlias()) { 1267 Type type = Type::getFromOpaquePointer(opaqueSymbol); 1268 p.printTypeImpl(type); 1269 alias.isPrinted = true; 1270 } else { 1271 // TODO: Support nested aliases in mutable attributes. 1272 Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); 1273 if (attr.hasTrait<AttributeTrait::IsMutable>()) 1274 p.getStream() << attr; 1275 else 1276 p.printAttributeImpl(attr); 1277 } 1278 1279 p.getStream() << newLine; 1280 } 1281 } 1282 1283 //===----------------------------------------------------------------------===// 1284 // SSANameState 1285 //===----------------------------------------------------------------------===// 1286 1287 namespace { 1288 /// Info about block printing: a number which is its position in the visitation 1289 /// order, and a name that is used to print reference to it, e.g. ^bb42. 1290 struct BlockInfo { 1291 int ordering; 1292 StringRef name; 1293 }; 1294 1295 /// This class manages the state of SSA value names. 1296 class SSANameState { 1297 public: 1298 /// A sentinel value used for values with names set. 1299 enum : unsigned { NameSentinel = ~0U }; 1300 1301 SSANameState(Operation *op, const OpPrintingFlags &printerFlags); 1302 SSANameState() = default; 1303 1304 /// Print the SSA identifier for the given value to 'stream'. If 1305 /// 'printResultNo' is true, it also presents the result number ('#' number) 1306 /// of this value. 1307 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; 1308 1309 /// Print the operation identifier. 
1310 void printOperationID(Operation *op, raw_ostream &stream) const; 1311 1312 /// Return the result indices for each of the result groups registered by this 1313 /// operation, or empty if none exist. 1314 ArrayRef<int> getOpResultGroups(Operation *op); 1315 1316 /// Get the info for the given block. 1317 BlockInfo getBlockInfo(Block *block); 1318 1319 /// Renumber the arguments for the specified region to the same names as the 1320 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for 1321 /// details. 1322 void shadowRegionArgs(Region ®ion, ValueRange namesToUse); 1323 1324 private: 1325 /// Number the SSA values within the given IR unit. 1326 void numberValuesInRegion(Region ®ion); 1327 void numberValuesInBlock(Block &block); 1328 void numberValuesInOp(Operation &op); 1329 1330 /// Given a result of an operation 'result', find the result group head 1331 /// 'lookupValue' and the result of 'result' within that group in 1332 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group 1333 /// has more than 1 result. 1334 void getResultIDAndNumber(OpResult result, Value &lookupValue, 1335 std::optional<int> &lookupResultNo) const; 1336 1337 /// Set a special value name for the given value. 1338 void setValueName(Value value, StringRef name); 1339 1340 /// Uniques the given value name within the printer. If the given name 1341 /// conflicts, it is automatically renamed. 1342 StringRef uniqueValueName(StringRef name); 1343 1344 /// This is the value ID for each SSA value. If this returns NameSentinel, 1345 /// then the valueID has an entry in valueNames. 1346 DenseMap<Value, unsigned> valueIDs; 1347 DenseMap<Value, StringRef> valueNames; 1348 1349 /// When printing users of values, an operation without a result might 1350 /// be the user. This map holds ids for such operations. 1351 DenseMap<Operation *, unsigned> operationIDs; 1352 1353 /// This is a map of operations that contain multiple named result groups, 1354 /// i.e. 
there may be multiple names for the results of the operation. The 1355 /// value of this map are the result numbers that start a result group. 1356 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; 1357 1358 /// This maps blocks to there visitation number in the current region as well 1359 /// as the string representing their name. 1360 DenseMap<Block *, BlockInfo> blockNames; 1361 1362 /// This keeps track of all of the non-numeric names that are in flight, 1363 /// allowing us to check for duplicates. 1364 /// Note: the value of the map is unused. 1365 llvm::ScopedHashTable<StringRef, char> usedNames; 1366 llvm::BumpPtrAllocator usedNameAllocator; 1367 1368 /// This is the next value ID to assign in numbering. 1369 unsigned nextValueID = 0; 1370 /// This is the next ID to assign to a region entry block argument. 1371 unsigned nextArgumentID = 0; 1372 /// This is the next ID to assign when a name conflict is detected. 1373 unsigned nextConflictID = 0; 1374 1375 /// These are the printing flags. They control, eg., whether to print in 1376 /// generic form. 1377 OpPrintingFlags printerFlags; 1378 }; 1379 } // namespace 1380 1381 SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) 1382 : printerFlags(printerFlags) { 1383 llvm::SaveAndRestore valueIDSaver(nextValueID); 1384 llvm::SaveAndRestore argumentIDSaver(nextArgumentID); 1385 llvm::SaveAndRestore conflictIDSaver(nextConflictID); 1386 1387 // The naming context includes `nextValueID`, `nextArgumentID`, 1388 // `nextConflictID` and `usedNames` scoped HashTable. This information is 1389 // carried from the parent region. 1390 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; 1391 using NamingContext = 1392 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; 1393 1394 // Allocator for UsedNamesScopeTy 1395 llvm::BumpPtrAllocator allocator; 1396 1397 // Add a scope for the top level operation. 
1398 auto *topLevelNamesScope = 1399 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); 1400 1401 SmallVector<NamingContext, 8> nameContext; 1402 for (Region ®ion : op->getRegions()) 1403 nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, 1404 nextConflictID, topLevelNamesScope)); 1405 1406 numberValuesInOp(*op); 1407 1408 while (!nameContext.empty()) { 1409 Region *region; 1410 UsedNamesScopeTy *parentScope; 1411 1412 if (printerFlags.shouldPrintUniqueSSAIDs()) 1413 // To print unique SSA IDs, ignore saved ID counts from parent regions 1414 std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) = 1415 nameContext.pop_back_val(); 1416 else 1417 std::tie(region, nextValueID, nextArgumentID, nextConflictID, 1418 parentScope) = nameContext.pop_back_val(); 1419 1420 // When we switch from one subtree to another, pop the scopes(needless) 1421 // until the parent scope. 1422 while (usedNames.getCurScope() != parentScope) { 1423 usedNames.getCurScope()->~UsedNamesScopeTy(); 1424 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && 1425 "top level parentScope must be a nullptr"); 1426 } 1427 1428 // Add a scope for the current region. 1429 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) 1430 UsedNamesScopeTy(usedNames); 1431 1432 numberValuesInRegion(*region); 1433 1434 for (Operation &op : region->getOps()) 1435 for (Region ®ion : op.getRegions()) 1436 nameContext.push_back(std::make_tuple(®ion, nextValueID, 1437 nextArgumentID, nextConflictID, 1438 curNamesScope)); 1439 } 1440 1441 // Manually remove all the scopes. 
1442 while (usedNames.getCurScope() != nullptr) 1443 usedNames.getCurScope()->~UsedNamesScopeTy(); 1444 } 1445 1446 void SSANameState::printValueID(Value value, bool printResultNo, 1447 raw_ostream &stream) const { 1448 if (!value) { 1449 stream << "<<NULL VALUE>>"; 1450 return; 1451 } 1452 1453 std::optional<int> resultNo; 1454 auto lookupValue = value; 1455 1456 // If this is an operation result, collect the head lookup value of the result 1457 // group and the result number of 'result' within that group. 1458 if (OpResult result = dyn_cast<OpResult>(value)) 1459 getResultIDAndNumber(result, lookupValue, resultNo); 1460 1461 auto it = valueIDs.find(lookupValue); 1462 if (it == valueIDs.end()) { 1463 stream << "<<UNKNOWN SSA VALUE>>"; 1464 return; 1465 } 1466 1467 stream << '%'; 1468 if (it->second != NameSentinel) { 1469 stream << it->second; 1470 } else { 1471 auto nameIt = valueNames.find(lookupValue); 1472 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 1473 stream << nameIt->second; 1474 } 1475 1476 if (resultNo && printResultNo) 1477 stream << '#' << *resultNo; 1478 } 1479 1480 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { 1481 auto it = operationIDs.find(op); 1482 if (it == operationIDs.end()) { 1483 stream << "<<UNKNOWN OPERATION>>"; 1484 } else { 1485 stream << '%' << it->second; 1486 } 1487 } 1488 1489 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 1490 auto it = opResultGroups.find(op); 1491 return it == opResultGroups.end() ? ArrayRef<int>() : it->second; 1492 } 1493 1494 BlockInfo SSANameState::getBlockInfo(Block *block) { 1495 auto it = blockNames.find(block); 1496 BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; 1497 return it != blockNames.end() ? 
it->second : invalidBlock; 1498 } 1499 1500 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 1501 assert(!region.empty() && "cannot shadow arguments of an empty region"); 1502 assert(region.getNumArguments() == namesToUse.size() && 1503 "incorrect number of names passed in"); 1504 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1505 "only KnownIsolatedFromAbove ops can shadow names"); 1506 1507 SmallVector<char, 16> nameStr; 1508 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 1509 auto nameToUse = namesToUse[i]; 1510 if (nameToUse == nullptr) 1511 continue; 1512 auto nameToReplace = region.getArgument(i); 1513 1514 nameStr.clear(); 1515 llvm::raw_svector_ostream nameStream(nameStr); 1516 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 1517 1518 // Entry block arguments should already have a pretty "arg" name. 1519 assert(valueIDs[nameToReplace] == NameSentinel); 1520 1521 // Use the name without the leading %. 1522 auto name = StringRef(nameStream.str()).drop_front(); 1523 1524 // Overwrite the name. 1525 valueNames[nameToReplace] = name.copy(usedNameAllocator); 1526 } 1527 } 1528 1529 namespace { 1530 /// Try to get value name from value's location, fallback to `name`. 
1531 StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { 1532 if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>()) 1533 return maybeNameLoc.getName(); 1534 return name; 1535 } 1536 } // namespace 1537 1538 void SSANameState::numberValuesInRegion(Region ®ion) { 1539 auto setBlockArgNameFn = [&](Value arg, StringRef name) { 1540 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 1541 assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion && 1542 "arg not defined in current region"); 1543 if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) 1544 name = maybeGetValueNameFromLoc(arg, name); 1545 setValueName(arg, name); 1546 }; 1547 1548 if (!printerFlags.shouldPrintGenericOpForm()) { 1549 if (Operation *op = region.getParentOp()) { 1550 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) 1551 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); 1552 } 1553 } 1554 1555 // Number the values within this region in a breadth-first order. 1556 unsigned nextBlockID = 0; 1557 for (auto &block : region) { 1558 // Each block gets a unique ID, and all of the operations within it get 1559 // numbered as well. 1560 auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); 1561 if (blockInfoIt.second) { 1562 // This block hasn't been named through `getAsmBlockArgumentNames`, use 1563 // default `^bbNNN` format. 1564 std::string name; 1565 llvm::raw_string_ostream(name) << "^bb" << nextBlockID; 1566 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); 1567 } 1568 blockInfoIt.first->second.ordering = nextBlockID++; 1569 1570 numberValuesInBlock(block); 1571 } 1572 } 1573 1574 void SSANameState::numberValuesInBlock(Block &block) { 1575 // Number the block arguments. We give entry block arguments a special name 1576 // 'arg'. 1577 bool isEntryBlock = block.isEntryBlock(); 1578 SmallString<32> specialNameBuffer(isEntryBlock ? 
"arg" : ""); 1579 llvm::raw_svector_ostream specialName(specialNameBuffer); 1580 for (auto arg : block.getArguments()) { 1581 if (valueIDs.count(arg)) 1582 continue; 1583 if (isEntryBlock) { 1584 specialNameBuffer.resize(strlen("arg")); 1585 specialName << nextArgumentID++; 1586 } 1587 StringRef specialNameStr = specialName.str(); 1588 if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) 1589 specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr); 1590 setValueName(arg, specialNameStr); 1591 } 1592 1593 // Number the operations in this block. 1594 for (auto &op : block) 1595 numberValuesInOp(op); 1596 } 1597 1598 void SSANameState::numberValuesInOp(Operation &op) { 1599 // Function used to set the special result names for the operation. 1600 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1601 auto setResultNameFn = [&](Value result, StringRef name) { 1602 assert(!valueIDs.count(result) && "result numbered multiple times"); 1603 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1604 if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) 1605 name = maybeGetValueNameFromLoc(result, name); 1606 setValueName(result, name); 1607 1608 // Record the result number for groups not anchored at 0. 1609 if (int resultNo = llvm::cast<OpResult>(result).getResultNumber()) 1610 resultGroups.push_back(resultNo); 1611 }; 1612 // Operations can customize the printing of block names in OpAsmOpInterface. 
1613 auto setBlockNameFn = [&](Block *block, StringRef name) { 1614 assert(block->getParentOp() == &op && 1615 "getAsmBlockArgumentNames callback invoked on a block not directly " 1616 "nested under the current operation"); 1617 assert(!blockNames.count(block) && "block numbered multiple times"); 1618 SmallString<16> tmpBuffer{"^"}; 1619 name = sanitizeIdentifier(name, tmpBuffer); 1620 if (name.data() != tmpBuffer.data()) { 1621 tmpBuffer.append(name); 1622 name = tmpBuffer.str(); 1623 } 1624 name = name.copy(usedNameAllocator); 1625 blockNames[block] = {-1, name}; 1626 }; 1627 1628 if (!printerFlags.shouldPrintGenericOpForm()) { 1629 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { 1630 asmInterface.getAsmBlockNames(setBlockNameFn); 1631 asmInterface.getAsmResultNames(setResultNameFn); 1632 } 1633 } 1634 1635 unsigned numResults = op.getNumResults(); 1636 if (numResults == 0) { 1637 // If value users should be printed, operations with no result need an id. 1638 if (printerFlags.shouldPrintValueUsers()) { 1639 if (operationIDs.try_emplace(&op, nextValueID).second) 1640 ++nextValueID; 1641 } 1642 return; 1643 } 1644 Value resultBegin = op.getResult(0); 1645 1646 if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) { 1647 if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) { 1648 setValueName(resultBegin, nameLoc.getName()); 1649 } 1650 } 1651 1652 // If the first result wasn't numbered, give it a default number. 1653 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1654 ++nextValueID; 1655 1656 // If this operation has multiple result groups, mark it. 
1657 if (resultGroups.size() != 1) { 1658 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1659 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1660 } 1661 } 1662 1663 void SSANameState::getResultIDAndNumber( 1664 OpResult result, Value &lookupValue, 1665 std::optional<int> &lookupResultNo) const { 1666 Operation *owner = result.getOwner(); 1667 if (owner->getNumResults() == 1) 1668 return; 1669 int resultNo = result.getResultNumber(); 1670 1671 // If this operation has multiple result groups, we will need to find the 1672 // one corresponding to this result. 1673 auto resultGroupIt = opResultGroups.find(owner); 1674 if (resultGroupIt == opResultGroups.end()) { 1675 // If not, just use the first result. 1676 lookupResultNo = resultNo; 1677 lookupValue = owner->getResult(0); 1678 return; 1679 } 1680 1681 // Find the correct index using a binary search, as the groups are ordered. 1682 ArrayRef<int> resultGroups = resultGroupIt->second; 1683 const auto *it = llvm::upper_bound(resultGroups, resultNo); 1684 int groupResultNo = 0, groupSize = 0; 1685 1686 // If there are no smaller elements, the last result group is the lookup. 1687 if (it == resultGroups.end()) { 1688 groupResultNo = resultGroups.back(); 1689 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1690 } else { 1691 // Otherwise, the previous element is the lookup. 1692 groupResultNo = *std::prev(it); 1693 groupSize = *it - groupResultNo; 1694 } 1695 1696 // We only record the result number for a group of size greater than 1. 1697 if (groupSize != 1) 1698 lookupResultNo = resultNo - groupResultNo; 1699 lookupValue = owner->getResult(groupResultNo); 1700 } 1701 1702 void SSANameState::setValueName(Value value, StringRef name) { 1703 // If the name is empty, the value uses the default numbering. 
1704 if (name.empty()) { 1705 valueIDs[value] = nextValueID++; 1706 return; 1707 } 1708 1709 valueIDs[value] = NameSentinel; 1710 valueNames[value] = uniqueValueName(name); 1711 } 1712 1713 StringRef SSANameState::uniqueValueName(StringRef name) { 1714 SmallString<16> tmpBuffer; 1715 name = sanitizeIdentifier(name, tmpBuffer); 1716 1717 // Check to see if this name is already unique. 1718 if (!usedNames.count(name)) { 1719 name = name.copy(usedNameAllocator); 1720 } else { 1721 // Otherwise, we had a conflict - probe until we find a unique name. This 1722 // is guaranteed to terminate (and usually in a single iteration) because it 1723 // generates new names by incrementing nextConflictID. 1724 SmallString<64> probeName(name); 1725 probeName.push_back('_'); 1726 while (true) { 1727 probeName += llvm::utostr(nextConflictID++); 1728 if (!usedNames.count(probeName)) { 1729 name = probeName.str().copy(usedNameAllocator); 1730 break; 1731 } 1732 probeName.resize(name.size() + 1); 1733 } 1734 } 1735 1736 usedNames.insert(name, char()); 1737 return name; 1738 } 1739 1740 //===----------------------------------------------------------------------===// 1741 // DistinctState 1742 //===----------------------------------------------------------------------===// 1743 1744 namespace { 1745 /// This class manages the state for distinct attributes. 1746 class DistinctState { 1747 public: 1748 /// Returns a unique identifier for the given distinct attribute. 
1749 uint64_t getId(DistinctAttr distinctAttr); 1750 1751 private: 1752 uint64_t distinctCounter = 0; 1753 DenseMap<DistinctAttr, uint64_t> distinctAttrMap; 1754 }; 1755 } // namespace 1756 1757 uint64_t DistinctState::getId(DistinctAttr distinctAttr) { 1758 auto [it, inserted] = 1759 distinctAttrMap.try_emplace(distinctAttr, distinctCounter); 1760 if (inserted) 1761 distinctCounter++; 1762 return it->getSecond(); 1763 } 1764 1765 //===----------------------------------------------------------------------===// 1766 // Resources 1767 //===----------------------------------------------------------------------===// 1768 1769 AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; 1770 AsmResourceBuilder::~AsmResourceBuilder() = default; 1771 AsmResourceParser::~AsmResourceParser() = default; 1772 AsmResourcePrinter::~AsmResourcePrinter() = default; 1773 1774 StringRef mlir::toString(AsmResourceEntryKind kind) { 1775 switch (kind) { 1776 case AsmResourceEntryKind::Blob: 1777 return "blob"; 1778 case AsmResourceEntryKind::Bool: 1779 return "bool"; 1780 case AsmResourceEntryKind::String: 1781 return "string"; 1782 } 1783 llvm_unreachable("unknown AsmResourceEntryKind"); 1784 } 1785 1786 AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { 1787 std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()]; 1788 if (!collection) 1789 collection = std::make_unique<ResourceCollection>(key); 1790 return *collection; 1791 } 1792 1793 std::vector<std::unique_ptr<AsmResourcePrinter>> 1794 FallbackAsmResourceMap::getPrinters() { 1795 std::vector<std::unique_ptr<AsmResourcePrinter>> printers; 1796 for (auto &it : keyToResources) { 1797 ResourceCollection *collection = it.second.get(); 1798 auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { 1799 return collection->buildResources(op, builder); 1800 }; 1801 printers.emplace_back( 1802 AsmResourcePrinter::fromCallable(collection->getName(), buildValues)); 1803 } 1804 return 
printers; 1805 } 1806 1807 LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( 1808 AsmParsedResourceEntry &entry) { 1809 switch (entry.getKind()) { 1810 case AsmResourceEntryKind::Blob: { 1811 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(); 1812 if (failed(blob)) 1813 return failure(); 1814 resources.emplace_back(entry.getKey(), std::move(*blob)); 1815 return success(); 1816 } 1817 case AsmResourceEntryKind::Bool: { 1818 FailureOr<bool> value = entry.parseAsBool(); 1819 if (failed(value)) 1820 return failure(); 1821 resources.emplace_back(entry.getKey(), *value); 1822 break; 1823 } 1824 case AsmResourceEntryKind::String: { 1825 FailureOr<std::string> str = entry.parseAsString(); 1826 if (failed(str)) 1827 return failure(); 1828 resources.emplace_back(entry.getKey(), std::move(*str)); 1829 break; 1830 } 1831 } 1832 return success(); 1833 } 1834 1835 void FallbackAsmResourceMap::ResourceCollection::buildResources( 1836 Operation *op, AsmResourceBuilder &builder) const { 1837 for (const auto &entry : resources) { 1838 if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value)) 1839 builder.buildBlob(entry.key, *value); 1840 else if (const auto *value = std::get_if<bool>(&entry.value)) 1841 builder.buildBool(entry.key, *value); 1842 else if (const auto *value = std::get_if<std::string>(&entry.value)) 1843 builder.buildString(entry.key, *value); 1844 else 1845 llvm_unreachable("unknown AsmResourceEntryKind"); 1846 } 1847 } 1848 1849 //===----------------------------------------------------------------------===// 1850 // AsmState 1851 //===----------------------------------------------------------------------===// 1852 1853 namespace mlir { 1854 namespace detail { 1855 class AsmStateImpl { 1856 public: 1857 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, 1858 AsmState::LocationMap *locationMap) 1859 : interfaces(op->getContext()), nameState(op, printerFlags), 1860 printerFlags(printerFlags), 
locationMap(locationMap) {} 1861 explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, 1862 AsmState::LocationMap *locationMap) 1863 : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} 1864 1865 /// Initialize the alias state to enable the printing of aliases. 1866 void initializeAliases(Operation *op) { 1867 aliasState.initialize(op, printerFlags, interfaces); 1868 } 1869 1870 /// Get the state used for aliases. 1871 AliasState &getAliasState() { return aliasState; } 1872 1873 /// Get the state used for SSA names. 1874 SSANameState &getSSANameState() { return nameState; } 1875 1876 /// Get the state used for distinct attribute identifiers. 1877 DistinctState &getDistinctState() { return distinctState; } 1878 1879 /// Return the dialects within the context that implement 1880 /// OpAsmDialectInterface. 1881 DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { 1882 return interfaces; 1883 } 1884 1885 /// Return the non-dialect resource printers. 1886 auto getResourcePrinters() { 1887 return llvm::make_pointee_range(externalResourcePrinters); 1888 } 1889 1890 /// Get the printer flags. 1891 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } 1892 1893 /// Register the location, line and column, within the buffer that the given 1894 /// operation was printed at. 1895 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1896 if (locationMap) 1897 (*locationMap)[op] = std::make_pair(line, col); 1898 } 1899 1900 /// Return the referenced dialect resources within the printer. 
1901 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & 1902 getDialectResources() { 1903 return dialectResources; 1904 } 1905 1906 LogicalResult pushCyclicPrinting(const void *opaquePointer) { 1907 return success(cyclicPrintingStack.insert(opaquePointer)); 1908 } 1909 1910 void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } 1911 1912 private: 1913 /// Collection of OpAsm interfaces implemented in the context. 1914 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1915 1916 /// A collection of non-dialect resource printers. 1917 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; 1918 1919 /// A set of dialect resources that were referenced during printing. 1920 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; 1921 1922 /// The state used for attribute and type aliases. 1923 AliasState aliasState; 1924 1925 /// The state used for SSA value names. 1926 SSANameState nameState; 1927 1928 /// The state used for distinct attribute identifiers. 1929 DistinctState distinctState; 1930 1931 /// Flags that control op output. 1932 OpPrintingFlags printerFlags; 1933 1934 /// An optional location map to be populated. 1935 AsmState::LocationMap *locationMap; 1936 1937 /// Stack of potentially cyclic mutable attributes or type currently being 1938 /// printed. 1939 SetVector<const void *> cyclicPrintingStack; 1940 1941 // Allow direct access to the impl fields. 1942 friend AsmState; 1943 }; 1944 1945 template <typename Range> 1946 void printDimensionList(raw_ostream &stream, Range &&shape) { 1947 llvm::interleave( 1948 shape, stream, 1949 [&stream](const auto &dimSize) { 1950 if (ShapedType::isDynamic(dimSize)) 1951 stream << "?"; 1952 else 1953 stream << dimSize; 1954 }, 1955 "x"); 1956 } 1957 1958 } // namespace detail 1959 } // namespace mlir 1960 1961 /// Verifies the operation and switches to generic op printing if verification 1962 /// fails. 
We need to do this because custom print functions may fail for 1963 /// invalid ops. 1964 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, 1965 OpPrintingFlags printerFlags) { 1966 if (printerFlags.shouldPrintGenericOpForm() || 1967 printerFlags.shouldAssumeVerified()) 1968 return printerFlags; 1969 1970 // Ignore errors emitted by the verifier. We check the thread id to avoid 1971 // consuming other threads' errors. 1972 auto parentThreadId = llvm::get_threadid(); 1973 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { 1974 if (parentThreadId == llvm::get_threadid()) { 1975 LLVM_DEBUG({ 1976 diag.print(llvm::dbgs()); 1977 llvm::dbgs() << "\n"; 1978 }); 1979 return success(); 1980 } 1981 return failure(); 1982 }); 1983 if (failed(verify(op))) { 1984 LLVM_DEBUG(llvm::dbgs() 1985 << DEBUG_TYPE << ": '" << op->getName() 1986 << "' failed to verify and will be printed in generic form\n"); 1987 printerFlags.printGenericOpForm(); 1988 } 1989 1990 return printerFlags; 1991 } 1992 1993 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, 1994 LocationMap *locationMap, FallbackAsmResourceMap *map) 1995 : impl(std::make_unique<AsmStateImpl>( 1996 op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) { 1997 if (map) 1998 attachFallbackResourcePrinter(*map); 1999 } 2000 AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, 2001 LocationMap *locationMap, FallbackAsmResourceMap *map) 2002 : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) { 2003 if (map) 2004 attachFallbackResourcePrinter(*map); 2005 } 2006 AsmState::~AsmState() = default; 2007 2008 const OpPrintingFlags &AsmState::getPrinterFlags() const { 2009 return impl->getPrinterFlags(); 2010 } 2011 2012 void AsmState::attachResourcePrinter( 2013 std::unique_ptr<AsmResourcePrinter> printer) { 2014 impl->externalResourcePrinters.emplace_back(std::move(printer)); 2015 } 2016 2017 DenseMap<Dialect *, 
SetVector<AsmDialectResourceHandle>> & 2018 AsmState::getDialectResources() const { 2019 return impl->getDialectResources(); 2020 } 2021 2022 //===----------------------------------------------------------------------===// 2023 // AsmPrinter::Impl 2024 //===----------------------------------------------------------------------===// 2025 2026 AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state) 2027 : os(os), state(state), printerFlags(state.getPrinterFlags()) {} 2028 2029 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { 2030 // Check to see if we are printing debug information. 2031 if (!printerFlags.shouldPrintDebugInfo()) 2032 return; 2033 2034 os << " "; 2035 printLocation(loc, /*allowAlias=*/allowAlias); 2036 } 2037 2038 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, 2039 bool isTopLevel) { 2040 // If this isn't a top-level location, check for an alias. 2041 if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os))) 2042 return; 2043 2044 TypeSwitch<LocationAttr>(loc) 2045 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 2046 printLocationInternal(loc.getFallbackLocation(), pretty); 2047 }) 2048 .Case<UnknownLoc>([&](UnknownLoc loc) { 2049 if (pretty) 2050 os << "[unknown]"; 2051 else 2052 os << "unknown"; 2053 }) 2054 .Case<FileLineColRange>([&](FileLineColRange loc) { 2055 if (pretty) 2056 os << loc.getFilename().getValue(); 2057 else 2058 printEscapedString(loc.getFilename()); 2059 if (loc.getEndColumn() == loc.getStartColumn() && 2060 loc.getStartLine() == loc.getEndLine()) { 2061 os << ':' << loc.getStartLine() << ':' << loc.getStartColumn(); 2062 return; 2063 } 2064 if (loc.getStartLine() == loc.getEndLine()) { 2065 os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() 2066 << " to :" << loc.getEndColumn(); 2067 return; 2068 } 2069 os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " 2070 << loc.getEndLine() << ':' << loc.getEndColumn(); 2071 }) 2072 
.Case<NameLoc>([&](NameLoc loc) { 2073 printEscapedString(loc.getName()); 2074 2075 // Print the child if it isn't unknown. 2076 auto childLoc = loc.getChildLoc(); 2077 if (!llvm::isa<UnknownLoc>(childLoc)) { 2078 os << '('; 2079 printLocationInternal(childLoc, pretty); 2080 os << ')'; 2081 } 2082 }) 2083 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 2084 Location caller = loc.getCaller(); 2085 Location callee = loc.getCallee(); 2086 if (!pretty) 2087 os << "callsite("; 2088 printLocationInternal(callee, pretty); 2089 if (pretty) { 2090 if (llvm::isa<NameLoc>(callee)) { 2091 if (llvm::isa<FileLineColLoc>(caller)) { 2092 os << " at "; 2093 } else { 2094 os << newLine << " at "; 2095 } 2096 } else { 2097 os << newLine << " at "; 2098 } 2099 } else { 2100 os << " at "; 2101 } 2102 printLocationInternal(caller, pretty); 2103 if (!pretty) 2104 os << ")"; 2105 }) 2106 .Case<FusedLoc>([&](FusedLoc loc) { 2107 if (!pretty) 2108 os << "fused"; 2109 if (Attribute metadata = loc.getMetadata()) { 2110 os << '<'; 2111 printAttribute(metadata); 2112 os << '>'; 2113 } 2114 os << '['; 2115 interleave( 2116 loc.getLocations(), 2117 [&](Location loc) { printLocationInternal(loc, pretty); }, 2118 [&]() { os << ", "; }); 2119 os << ']'; 2120 }) 2121 .Default([&](LocationAttr loc) { 2122 // Assumes that this is a dialect-specific attribute and prints it 2123 // directly. 2124 printAttribute(loc); 2125 }); 2126 } 2127 2128 /// Print a floating point value in a way that the parser will be able to 2129 /// round-trip losslessly. 2130 static void printFloatValue(const APFloat &apValue, raw_ostream &os, 2131 bool *printedHex = nullptr) { 2132 // We would like to output the FP constant value in exponential notation, 2133 // but we cannot do this if doing so will lose precision. Check here to 2134 // make sure that we only output it in exponential format if we can parse 2135 // the value back and get the same value. 
2136 bool isInf = apValue.isInfinity(); 2137 bool isNaN = apValue.isNaN(); 2138 if (!isInf && !isNaN) { 2139 SmallString<128> strValue; 2140 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 2141 /*TruncateZero=*/false); 2142 2143 // Check to make sure that the stringized number is not some string like 2144 // "Inf" or NaN, that atof will accept, but the lexer will not. Check 2145 // that the string matches the "[-+]?[0-9]" regex. 2146 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 2147 ((strValue[0] == '-' || strValue[0] == '+') && 2148 (strValue[1] >= '0' && strValue[1] <= '9'))) && 2149 "[-+]?[0-9] regex does not match!"); 2150 2151 // Parse back the stringized version and check that the value is equal 2152 // (i.e., there is no precision loss). 2153 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 2154 os << strValue; 2155 return; 2156 } 2157 2158 // If it is not, use the default format of APFloat instead of the 2159 // exponential notation. 2160 strValue.clear(); 2161 apValue.toString(strValue); 2162 2163 // Make sure that we can parse the default form as a float. 2164 if (strValue.str().contains('.')) { 2165 os << strValue; 2166 return; 2167 } 2168 } 2169 2170 // Print special values in hexadecimal format. The sign bit should be included 2171 // in the literal. 
2172 if (printedHex) 2173 *printedHex = true; 2174 SmallVector<char, 16> str; 2175 APInt apInt = apValue.bitcastToAPInt(); 2176 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 2177 /*formatAsCLiteral=*/true); 2178 os << str; 2179 } 2180 2181 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { 2182 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 2183 return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); 2184 2185 os << "loc("; 2186 if (!allowAlias || failed(printAlias(loc))) 2187 printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); 2188 os << ')'; 2189 } 2190 2191 void AsmPrinter::Impl::printResourceHandle( 2192 const AsmDialectResourceHandle &resource) { 2193 auto *interface = cast<OpAsmDialectInterface>(resource.getDialect()); 2194 os << interface->getResourceKey(resource); 2195 state.getDialectResources()[resource.getDialect()].insert(resource); 2196 } 2197 2198 /// Returns true if the given dialect symbol data is simple enough to print in 2199 /// the pretty form. This is essentially when the symbol takes the form: 2200 /// identifier (`<` body `>`)? 2201 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 2202 // The name must start with an identifier. 2203 if (symName.empty() || !isalpha(symName.front())) 2204 return false; 2205 2206 // Ignore all the characters that are valid in an identifier in the symbol 2207 // name. 2208 symName = symName.drop_while( 2209 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 2210 if (symName.empty()) 2211 return true; 2212 2213 // If we got to an unexpected character, then it must be a <>. Check that the 2214 // rest of the symbol is wrapped within <>. 2215 return symName.front() == '<' && symName.back() == '>'; 2216 } 2217 2218 /// Print the given dialect symbol to the stream. 
2219 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 2220 StringRef dialectName, StringRef symString) { 2221 os << symPrefix << dialectName; 2222 2223 // If this symbol name is simple enough, print it directly in pretty form, 2224 // otherwise, we print it as an escaped string. 2225 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 2226 os << '.' << symString; 2227 return; 2228 } 2229 2230 os << '<' << symString << '>'; 2231 } 2232 2233 /// Returns true if the given string can be represented as a bare identifier. 2234 static bool isBareIdentifier(StringRef name) { 2235 // By making this unsigned, the value passed in to isalnum will always be 2236 // in the range 0-255. This is important when building with MSVC because 2237 // its implementation will assert. This situation can arise when dealing 2238 // with UTF-8 multibyte characters. 2239 if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) 2240 return false; 2241 return llvm::all_of(name.drop_front(), [](unsigned char c) { 2242 return isalnum(c) || c == '_' || c == '$' || c == '.'; 2243 }); 2244 } 2245 2246 /// Print the given string as a keyword, or a quoted and escaped string if it 2247 /// has any special or non-printable characters in it. 2248 static void printKeywordOrString(StringRef keyword, raw_ostream &os) { 2249 // If it can be represented as a bare identifier, write it directly. 2250 if (isBareIdentifier(keyword)) { 2251 os << keyword; 2252 return; 2253 } 2254 2255 // Otherwise, output the keyword wrapped in quotes with proper escaping. 2256 os << "\""; 2257 printEscapedString(keyword, os); 2258 os << '"'; 2259 } 2260 2261 /// Print the given string as a symbol reference. A symbol reference is 2262 /// represented as a string prefixed with '@'. The reference is surrounded with 2263 /// ""'s and escaped if it has any special or non-printable characters in it. 
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  // An empty reference has no valid spelling; emit a conspicuous marker.
  if (symbolRef.empty()) {
    os << "@<<INVALID EMPTY SYMBOL>>";
    return;
  }
  os << '@';
  printKeywordOrString(symbolRef, os);
}

// Print out a valid ElementsAttr that is succinct and can represent any
// potential shape/type, for use when eliding a large ElementsAttr.
//
// We choose to use a dense resource ElementsAttr literal with conspicuous
// content to hopefully alert readers to the fact that this has been elided.
static void printElidedElementsAttr(raw_ostream &os) {
  os << R"(dense_resource<__elided__>)";
}

/// Try to print an alias for `attr`; the result reports whether one was
/// printed.
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
  return state.getAliasState().getAlias(attr, os);
}

/// Try to print an alias for `type`; the result reports whether one was
/// printed.
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
  return state.getAliasState().getAlias(type, os);
}

/// Print `attr`, preferring an alias when one can be printed. A null
/// attribute is printed as `<<NULL ATTRIBUTE>>`.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;
  return printAttributeImpl(attr, typeElision);
}

/// Print `attr` without consulting aliases for the attribute itself.
///
/// Attributes that do not belong to the builtin dialect are forwarded to
/// printDialectAttribute; builtin attributes are handled case by case below.
/// Note the control flow: a branch that `return`s is finished, while a branch
/// that falls out of the if/else chain continues to the common tail, which
/// appends ` : <type>` for typed attributes (unless `typeElision` is `Must` or
/// the type is NoneType).
void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
                                          AttrTypeElision typeElision) {
  if (!isa<BuiltinDialect>(attr.getDialect())) {
    printDialectAttribute(attr);
  } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (llvm::isa<UnitAttr>(attr)) {
    os << "unit";
    return;
  } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
    os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
    // A referenced unit attribute is left implicit.
    if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) {
      printAttribute(distinctAttr.getReferencedAttr());
    }
    os << '>';
    return;
  } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
    Type intType = intAttr.getType();
    if (intType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed. (Signless 1-bit values have already been handled
    // by the early return above.)
    bool isUnsigned =
        intType.isUnsignedInteger() || intType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
    bool printedHex = false;
    printFloatValue(floatAttr.getValue(), os, &printedHex);

    // FloatAttr elides the type if F64, but only when the value was not
    // printed as a hexadecimal literal.
    if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() &&
        !printedHex)
      return;

  } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
    printEscapedString(strAttr.getValue());

  } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
    // Root reference followed by `::`-separated nested references.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto intOrFpEltAttr =
                 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
    // The whole attribute is elided if either its indices or its values
    // should be elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, nothing is printed between the angle brackets.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }
  } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
    stridedLayoutAttr.print(os);
  } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
    os << "array<";
    printType(denseArrayAttr.getElementType());
    if (!denseArrayAttr.empty()) {
      os << ": ";
      printDenseArrayAttr(denseArrayAttr);
    }
    os << ">";
    return;
  } else if (auto resourceAttr =
                 llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
    os << "dense_resource<";
    printResourceHandle(resourceAttr.getRawHandle());
    os << ">";
  } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
    printLocation(locAttr);
  } else {
    llvm::report_fatal_error("Unknown builtin attribute");
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must) {
    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
      Type attrType = typedAttr.getType();
      if (!llvm::isa<NoneType>(attrType)) {
        os << " : ";
        printType(attrType);
      }
    }
  }
}

/// Print the integer element of a DenseElementsAttr. 1-bit integers print as
/// `true`/`false`; everything else prints as signed unless `type` is an
/// unsigned integer type.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 Type type) {
  if (type.isInteger(1))
    os << (value.getBoolValue() ? "true" : "false");
  else
    value.print(os, !type.isUnsignedInteger());
}

/// Print the elements of a dense attribute of shape `type` as nested,
/// comma-separated bracketed lists, calling `printEltFn(i)` to print the
/// element with linear index `i`. A splat prints only element 0, with no
/// brackets; an attribute with zero elements prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket that is currently closed. The post-increment in
    // the condition overshoots by one on exit, hence the reset to `rank`.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever is still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}

/// Print the elements of `attr`, dispatching on whether it holds strings or
/// integer/floating-point data. `allowHex` is only forwarded to the int/FP
/// path.
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
                                              bool allowHex) {
  if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
    return printDenseStringElementsAttr(stringAttr);

  printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
                                allowHex);
}

/// Print the elements of a dense integer / floating-point (or complex thereof)
/// attribute. When `allowHex` is set and the printer flags ask for it, the raw
/// data is printed as a single hex string instead of element by element.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::endianness::native == llvm::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(convRawData);
    } else {
      printHexString(rawData);
    }

    return;
  }

  if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    // Complex elements print as `(real,imag)`.
    if (llvm::isa<IntegerType>(complexElementType)) {
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, complexElementType);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, complexElementType);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, elementType);
    });
  } else {
    assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}

void AsmPrinter::Impl::printDenseStringElementsAttr( 2589 DenseStringElementsAttr attr) { 2590 ArrayRef<StringRef> data = attr.getRawStringData(); 2591 auto printFn = [&](unsigned index) { printEscapedString(data[index]); }; 2592 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); 2593 } 2594 2595 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { 2596 Type type = attr.getElementType(); 2597 unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); 2598 unsigned byteSize = bitwidth / 8; 2599 ArrayRef<char> data = attr.getRawData(); 2600 2601 auto printElementAt = [&](unsigned i) { 2602 APInt value(bitwidth, 0); 2603 if (bitwidth) { 2604 llvm::LoadIntFromMemory( 2605 value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i), 2606 byteSize); 2607 } 2608 // Print the data as-is or as a float. 2609 if (type.isIntOrIndex()) { 2610 printDenseIntElement(value, getStream(), type); 2611 } else { 2612 APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value); 2613 printFloatValue(fltVal, getStream()); 2614 } 2615 }; 2616 llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(), 2617 printElementAt); 2618 } 2619 2620 void AsmPrinter::Impl::printType(Type type) { 2621 if (!type) { 2622 os << "<<NULL TYPE>>"; 2623 return; 2624 } 2625 2626 // Try to print an alias for this type. 
2627 if (succeeded(printAlias(type))) 2628 return; 2629 return printTypeImpl(type); 2630 } 2631 2632 void AsmPrinter::Impl::printTypeImpl(Type type) { 2633 TypeSwitch<Type>(type) 2634 .Case<OpaqueType>([&](OpaqueType opaqueTy) { 2635 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 2636 opaqueTy.getTypeData()); 2637 }) 2638 .Case<IndexType>([&](Type) { os << "index"; }) 2639 .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; }) 2640 .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; }) 2641 .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; }) 2642 .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; }) 2643 .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; }) 2644 .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; }) 2645 .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; }) 2646 .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; }) 2647 .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; }) 2648 .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; }) 2649 .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; }) 2650 .Case<BFloat16Type>([&](Type) { os << "bf16"; }) 2651 .Case<Float16Type>([&](Type) { os << "f16"; }) 2652 .Case<FloatTF32Type>([&](Type) { os << "tf32"; }) 2653 .Case<Float32Type>([&](Type) { os << "f32"; }) 2654 .Case<Float64Type>([&](Type) { os << "f64"; }) 2655 .Case<Float80Type>([&](Type) { os << "f80"; }) 2656 .Case<Float128Type>([&](Type) { os << "f128"; }) 2657 .Case<IntegerType>([&](IntegerType integerTy) { 2658 if (integerTy.isSigned()) 2659 os << 's'; 2660 else if (integerTy.isUnsigned()) 2661 os << 'u'; 2662 os << 'i' << integerTy.getWidth(); 2663 }) 2664 .Case<FunctionType>([&](FunctionType funcTy) { 2665 os << '('; 2666 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); 2667 os << ") -> "; 2668 ArrayRef<Type> results = funcTy.getResults(); 2669 if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) { 2670 printType(results[0]); 2671 } else { 2672 os << '('; 2673 
interleaveComma(results, [&](Type ty) { printType(ty); }); 2674 os << ')'; 2675 } 2676 }) 2677 .Case<VectorType>([&](VectorType vectorTy) { 2678 auto scalableDims = vectorTy.getScalableDims(); 2679 os << "vector<"; 2680 auto vShape = vectorTy.getShape(); 2681 unsigned lastDim = vShape.size(); 2682 unsigned dimIdx = 0; 2683 for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { 2684 if (!scalableDims.empty() && scalableDims[dimIdx]) 2685 os << '['; 2686 os << vShape[dimIdx]; 2687 if (!scalableDims.empty() && scalableDims[dimIdx]) 2688 os << ']'; 2689 os << 'x'; 2690 } 2691 printType(vectorTy.getElementType()); 2692 os << '>'; 2693 }) 2694 .Case<RankedTensorType>([&](RankedTensorType tensorTy) { 2695 os << "tensor<"; 2696 printDimensionList(tensorTy.getShape()); 2697 if (!tensorTy.getShape().empty()) 2698 os << 'x'; 2699 printType(tensorTy.getElementType()); 2700 // Only print the encoding attribute value if set. 2701 if (tensorTy.getEncoding()) { 2702 os << ", "; 2703 printAttribute(tensorTy.getEncoding()); 2704 } 2705 os << '>'; 2706 }) 2707 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { 2708 os << "tensor<*x"; 2709 printType(tensorTy.getElementType()); 2710 os << '>'; 2711 }) 2712 .Case<MemRefType>([&](MemRefType memrefTy) { 2713 os << "memref<"; 2714 printDimensionList(memrefTy.getShape()); 2715 if (!memrefTy.getShape().empty()) 2716 os << 'x'; 2717 printType(memrefTy.getElementType()); 2718 MemRefLayoutAttrInterface layout = memrefTy.getLayout(); 2719 if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { 2720 os << ", "; 2721 printAttribute(memrefTy.getLayout(), AttrTypeElision::May); 2722 } 2723 // Only print the memory space if it is the non-default one. 
2724 if (memrefTy.getMemorySpace()) { 2725 os << ", "; 2726 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 2727 } 2728 os << '>'; 2729 }) 2730 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { 2731 os << "memref<*x"; 2732 printType(memrefTy.getElementType()); 2733 // Only print the memory space if it is the non-default one. 2734 if (memrefTy.getMemorySpace()) { 2735 os << ", "; 2736 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 2737 } 2738 os << '>'; 2739 }) 2740 .Case<ComplexType>([&](ComplexType complexTy) { 2741 os << "complex<"; 2742 printType(complexTy.getElementType()); 2743 os << '>'; 2744 }) 2745 .Case<TupleType>([&](TupleType tupleTy) { 2746 os << "tuple<"; 2747 interleaveComma(tupleTy.getTypes(), 2748 [&](Type type) { printType(type); }); 2749 os << '>'; 2750 }) 2751 .Case<NoneType>([&](Type) { os << "none"; }) 2752 .Default([&](Type type) { return printDialectType(type); }); 2753 } 2754 2755 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2756 ArrayRef<StringRef> elidedAttrs, 2757 bool withKeyword) { 2758 // If there are no attributes, then there is nothing to be done. 2759 if (attrs.empty()) 2760 return; 2761 2762 // Functor used to print a filtered attribute list. 2763 auto printFilteredAttributesFn = [&](auto filteredAttrs) { 2764 // Print the 'attributes' keyword if necessary. 2765 if (withKeyword) 2766 os << " attributes"; 2767 2768 // Otherwise, print them all out in braces. 2769 os << " {"; 2770 interleaveComma(filteredAttrs, 2771 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 2772 os << '}'; 2773 }; 2774 2775 // If no attributes are elided, we can directly print with no filtering. 2776 if (elidedAttrs.empty()) 2777 return printFilteredAttributesFn(attrs); 2778 2779 // Otherwise, filter out any attributes that shouldn't be included. 
2780 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 2781 elidedAttrs.end()); 2782 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 2783 return !elidedAttrsSet.contains(attr.getName().strref()); 2784 }); 2785 if (!filteredAttrs.empty()) 2786 printFilteredAttributesFn(filteredAttrs); 2787 } 2788 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { 2789 // Print the name without quotes if possible. 2790 ::printKeywordOrString(attr.getName().strref(), os); 2791 2792 // Pretty printing elides the attribute value for unit attributes. 2793 if (llvm::isa<UnitAttr>(attr.getValue())) 2794 return; 2795 2796 os << " = "; 2797 printAttribute(attr.getValue()); 2798 } 2799 2800 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { 2801 auto &dialect = attr.getDialect(); 2802 2803 // Ask the dialect to serialize the attribute to a string. 2804 std::string attrName; 2805 { 2806 llvm::raw_string_ostream attrNameStr(attrName); 2807 Impl subPrinter(attrNameStr, state); 2808 DialectAsmPrinter printer(subPrinter); 2809 dialect.printAttribute(attr, printer); 2810 } 2811 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 2812 } 2813 2814 void AsmPrinter::Impl::printDialectType(Type type) { 2815 auto &dialect = type.getDialect(); 2816 2817 // Ask the dialect to serialize the type to a string. 
2818 std::string typeName; 2819 { 2820 llvm::raw_string_ostream typeNameStr(typeName); 2821 Impl subPrinter(typeNameStr, state); 2822 DialectAsmPrinter printer(subPrinter); 2823 dialect.printType(type, printer); 2824 } 2825 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 2826 } 2827 2828 void AsmPrinter::Impl::printEscapedString(StringRef str) { 2829 os << "\""; 2830 llvm::printEscapedString(str, os); 2831 os << "\""; 2832 } 2833 2834 void AsmPrinter::Impl::printHexString(StringRef str) { 2835 os << "\"0x" << llvm::toHex(str) << "\""; 2836 } 2837 void AsmPrinter::Impl::printHexString(ArrayRef<char> data) { 2838 printHexString(StringRef(data.data(), data.size())); 2839 } 2840 2841 LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { 2842 return state.pushCyclicPrinting(opaquePointer); 2843 } 2844 2845 void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } 2846 2847 void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { 2848 detail::printDimensionList(os, shape); 2849 } 2850 2851 //===--------------------------------------------------------------------===// 2852 // AsmPrinter 2853 //===--------------------------------------------------------------------===// 2854 2855 AsmPrinter::~AsmPrinter() = default; 2856 2857 raw_ostream &AsmPrinter::getStream() const { 2858 assert(impl && "expected AsmPrinter::getStream to be overriden"); 2859 return impl->getStream(); 2860 } 2861 2862 /// Print the given floating point value in a stablized form. 
2863 void AsmPrinter::printFloat(const APFloat &value) { 2864 assert(impl && "expected AsmPrinter::printFloat to be overriden"); 2865 printFloatValue(value, impl->getStream()); 2866 } 2867 2868 void AsmPrinter::printType(Type type) { 2869 assert(impl && "expected AsmPrinter::printType to be overriden"); 2870 impl->printType(type); 2871 } 2872 2873 void AsmPrinter::printAttribute(Attribute attr) { 2874 assert(impl && "expected AsmPrinter::printAttribute to be overriden"); 2875 impl->printAttribute(attr); 2876 } 2877 2878 LogicalResult AsmPrinter::printAlias(Attribute attr) { 2879 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2880 return impl->printAlias(attr); 2881 } 2882 2883 LogicalResult AsmPrinter::printAlias(Type type) { 2884 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2885 return impl->printAlias(type); 2886 } 2887 2888 void AsmPrinter::printAttributeWithoutType(Attribute attr) { 2889 assert(impl && 2890 "expected AsmPrinter::printAttributeWithoutType to be overriden"); 2891 impl->printAttribute(attr, Impl::AttrTypeElision::Must); 2892 } 2893 2894 void AsmPrinter::printKeywordOrString(StringRef keyword) { 2895 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); 2896 ::printKeywordOrString(keyword, impl->getStream()); 2897 } 2898 2899 void AsmPrinter::printString(StringRef keyword) { 2900 assert(impl && "expected AsmPrinter::printString to be overriden"); 2901 *this << '"'; 2902 printEscapedString(keyword, getStream()); 2903 *this << '"'; 2904 } 2905 2906 void AsmPrinter::printSymbolName(StringRef symbolRef) { 2907 assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); 2908 ::printSymbolReference(symbolRef, impl->getStream()); 2909 } 2910 2911 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { 2912 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden"); 2913 impl->printResourceHandle(resource); 2914 } 2915 2916 void 
AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { 2917 detail::printDimensionList(getStream(), shape); 2918 } 2919 2920 LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { 2921 return impl->pushCyclicPrinting(opaquePointer); 2922 } 2923 2924 void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } 2925 2926 //===----------------------------------------------------------------------===// 2927 // Affine expressions and maps 2928 //===----------------------------------------------------------------------===// 2929 2930 void AsmPrinter::Impl::printAffineExpr( 2931 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2932 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2933 } 2934 2935 void AsmPrinter::Impl::printAffineExprInternal( 2936 AffineExpr expr, BindingStrength enclosingTightness, 2937 function_ref<void(unsigned, bool)> printValueName) { 2938 const char *binopSpelling = nullptr; 2939 switch (expr.getKind()) { 2940 case AffineExprKind::SymbolId: { 2941 unsigned pos = cast<AffineSymbolExpr>(expr).getPosition(); 2942 if (printValueName) 2943 printValueName(pos, /*isSymbol=*/true); 2944 else 2945 os << 's' << pos; 2946 return; 2947 } 2948 case AffineExprKind::DimId: { 2949 unsigned pos = cast<AffineDimExpr>(expr).getPosition(); 2950 if (printValueName) 2951 printValueName(pos, /*isSymbol=*/false); 2952 else 2953 os << 'd' << pos; 2954 return; 2955 } 2956 case AffineExprKind::Constant: 2957 os << cast<AffineConstantExpr>(expr).getValue(); 2958 return; 2959 case AffineExprKind::Add: 2960 binopSpelling = " + "; 2961 break; 2962 case AffineExprKind::Mul: 2963 binopSpelling = " * "; 2964 break; 2965 case AffineExprKind::FloorDiv: 2966 binopSpelling = " floordiv "; 2967 break; 2968 case AffineExprKind::CeilDiv: 2969 binopSpelling = " ceildiv "; 2970 break; 2971 case AffineExprKind::Mod: 2972 binopSpelling = " mod "; 2973 break; 2974 } 2975 2976 auto binOp = cast<AffineBinaryOpExpr>(expr); 
2977 AffineExpr lhsExpr = binOp.getLHS(); 2978 AffineExpr rhsExpr = binOp.getRHS(); 2979 2980 // Handle tightly binding binary operators. 2981 if (binOp.getKind() != AffineExprKind::Add) { 2982 if (enclosingTightness == BindingStrength::Strong) 2983 os << '('; 2984 2985 // Pretty print multiplication with -1. 2986 auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr); 2987 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2988 rhsConst.getValue() == -1) { 2989 os << "-"; 2990 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2991 if (enclosingTightness == BindingStrength::Strong) 2992 os << ')'; 2993 return; 2994 } 2995 2996 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2997 2998 os << binopSpelling; 2999 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 3000 3001 if (enclosingTightness == BindingStrength::Strong) 3002 os << ')'; 3003 return; 3004 } 3005 3006 // Print out special "pretty" forms for add. 3007 if (enclosingTightness == BindingStrength::Strong) 3008 os << '('; 3009 3010 // Pretty print addition to a product that has a negative operand as a 3011 // subtraction. 
3012 if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) { 3013 if (rhs.getKind() == AffineExprKind::Mul) { 3014 AffineExpr rrhsExpr = rhs.getRHS(); 3015 if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) { 3016 if (rrhs.getValue() == -1) { 3017 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 3018 printValueName); 3019 os << " - "; 3020 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 3021 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 3022 printValueName); 3023 } else { 3024 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 3025 printValueName); 3026 } 3027 3028 if (enclosingTightness == BindingStrength::Strong) 3029 os << ')'; 3030 return; 3031 } 3032 3033 if (rrhs.getValue() < -1) { 3034 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 3035 printValueName); 3036 os << " - "; 3037 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 3038 printValueName); 3039 os << " * " << -rrhs.getValue(); 3040 if (enclosingTightness == BindingStrength::Strong) 3041 os << ')'; 3042 return; 3043 } 3044 } 3045 } 3046 } 3047 3048 // Pretty print addition to a negative number as a subtraction. 3049 if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) { 3050 if (rhsConst.getValue() < 0) { 3051 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 3052 os << " - " << -rhsConst.getValue(); 3053 if (enclosingTightness == BindingStrength::Strong) 3054 os << ')'; 3055 return; 3056 } 3057 } 3058 3059 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 3060 3061 os << " + "; 3062 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 3063 3064 if (enclosingTightness == BindingStrength::Strong) 3065 os << ')'; 3066 } 3067 3068 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { 3069 printAffineExprInternal(expr, BindingStrength::Weak); 3070 isEq ? 
os << " == 0" : os << " >= 0"; 3071 } 3072 3073 void AsmPrinter::Impl::printAffineMap(AffineMap map) { 3074 // Dimension identifiers. 3075 os << '('; 3076 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 3077 os << 'd' << i << ", "; 3078 if (map.getNumDims() >= 1) 3079 os << 'd' << map.getNumDims() - 1; 3080 os << ')'; 3081 3082 // Symbolic identifiers. 3083 if (map.getNumSymbols() != 0) { 3084 os << '['; 3085 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 3086 os << 's' << i << ", "; 3087 if (map.getNumSymbols() >= 1) 3088 os << 's' << map.getNumSymbols() - 1; 3089 os << ']'; 3090 } 3091 3092 // Result affine expressions. 3093 os << " -> ("; 3094 interleaveComma(map.getResults(), 3095 [&](AffineExpr expr) { printAffineExpr(expr); }); 3096 os << ')'; 3097 } 3098 3099 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { 3100 // Dimension identifiers. 3101 os << '('; 3102 for (unsigned i = 1; i < set.getNumDims(); ++i) 3103 os << 'd' << i - 1 << ", "; 3104 if (set.getNumDims() >= 1) 3105 os << 'd' << set.getNumDims() - 1; 3106 os << ')'; 3107 3108 // Symbolic identifiers. 3109 if (set.getNumSymbols() != 0) { 3110 os << '['; 3111 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 3112 os << 's' << i << ", "; 3113 if (set.getNumSymbols() >= 1) 3114 os << 's' << set.getNumSymbols() - 1; 3115 os << ']'; 3116 } 3117 3118 // Print constraints. 
3119 os << " : ("; 3120 int numConstraints = set.getNumConstraints(); 3121 for (int i = 1; i < numConstraints; ++i) { 3122 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 3123 os << ", "; 3124 } 3125 if (numConstraints >= 1) 3126 printAffineConstraint(set.getConstraint(numConstraints - 1), 3127 set.isEq(numConstraints - 1)); 3128 os << ')'; 3129 } 3130 3131 //===----------------------------------------------------------------------===// 3132 // OperationPrinter 3133 //===----------------------------------------------------------------------===// 3134 3135 namespace { 3136 /// This class contains the logic for printing operations, regions, and blocks. 3137 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { 3138 public: 3139 using Impl = AsmPrinter::Impl; 3140 using Impl::printType; 3141 3142 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) 3143 : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {} 3144 3145 /// Print the given top-level operation. 3146 void printTopLevelOperation(Operation *op); 3147 3148 /// Print the given operation, including its left-hand side and its right-hand 3149 /// side, with its indent and location. 3150 void printFullOpWithIndentAndLoc(Operation *op); 3151 /// Print the given operation, including its left-hand side and its right-hand 3152 /// side, but not including indentation and location. 3153 void printFullOp(Operation *op); 3154 /// Print the right-hand size of the given operation in the custom or generic 3155 /// form. 3156 void printCustomOrGenericOp(Operation *op) override; 3157 /// Print the right-hand side of the given operation in the generic form. 3158 void printGenericOp(Operation *op, bool printOpName) override; 3159 3160 /// Print the name of the given block. 3161 void printBlockName(Block *block); 3162 3163 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 3164 /// block are not printed. 
If 'printBlockTerminator' is false, the terminator 3165 /// operation of the block is not printed. 3166 void print(Block *block, bool printBlockArgs = true, 3167 bool printBlockTerminator = true); 3168 3169 /// Print the ID of the given value, optionally with its result number. 3170 void printValueID(Value value, bool printResultNo = true, 3171 raw_ostream *streamOverride = nullptr) const; 3172 3173 /// Print the ID of the given operation. 3174 void printOperationID(Operation *op, 3175 raw_ostream *streamOverride = nullptr) const; 3176 3177 //===--------------------------------------------------------------------===// 3178 // OpAsmPrinter methods 3179 //===--------------------------------------------------------------------===// 3180 3181 /// Print a loc(...) specifier if printing debug info is enabled. Locations 3182 /// may be deferred with an alias. 3183 void printOptionalLocationSpecifier(Location loc) override { 3184 printTrailingLocation(loc); 3185 } 3186 3187 /// Print a newline and indent the printer to the start of the current 3188 /// operation. 3189 void printNewline() override { 3190 os << newLine; 3191 os.indent(currentIndent); 3192 } 3193 3194 /// Increase indentation. 3195 void increaseIndent() override { currentIndent += indentWidth; } 3196 3197 /// Decrease indentation. 3198 void decreaseIndent() override { currentIndent -= indentWidth; } 3199 3200 /// Print a block argument in the usual format of: 3201 /// %ssaName : type {attr1=42} loc("here") 3202 /// where location printing is controlled by the standard internal option. 3203 /// You may pass omitType=true to not print a type, and pass an empty 3204 /// attribute list if you don't care for attributes. 3205 void printRegionArgument(BlockArgument arg, 3206 ArrayRef<NamedAttribute> argAttrs = {}, 3207 bool omitType = false) override; 3208 3209 /// Print the ID for the given value. 
3210 void printOperand(Value value) override { printValueID(value); } 3211 void printOperand(Value value, raw_ostream &os) override { 3212 printValueID(value, /*printResultNo=*/true, &os); 3213 } 3214 3215 /// Print an optional attribute dictionary with a given set of elided values. 3216 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 3217 ArrayRef<StringRef> elidedAttrs = {}) override { 3218 Impl::printOptionalAttrDict(attrs, elidedAttrs); 3219 } 3220 void printOptionalAttrDictWithKeyword( 3221 ArrayRef<NamedAttribute> attrs, 3222 ArrayRef<StringRef> elidedAttrs = {}) override { 3223 Impl::printOptionalAttrDict(attrs, elidedAttrs, 3224 /*withKeyword=*/true); 3225 } 3226 3227 /// Print the given successor. 3228 void printSuccessor(Block *successor) override; 3229 3230 /// Print an operation successor with the operands used for the block 3231 /// arguments. 3232 void printSuccessorAndUseList(Block *successor, 3233 ValueRange succOperands) override; 3234 3235 /// Print the given region. 3236 void printRegion(Region ®ion, bool printEntryBlockArgs, 3237 bool printBlockTerminators, bool printEmptyBlock) override; 3238 3239 /// Renumber the arguments for the specified region to the same names as the 3240 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 3241 /// operations. If any entry in namesToUse is null, the corresponding 3242 /// argument name is left alone. 3243 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 3244 state.getSSANameState().shadowRegionArgs(region, namesToUse); 3245 } 3246 3247 /// Print the given affine map with the symbol and dimension operands printed 3248 /// inline with the map. 3249 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 3250 ValueRange operands) override; 3251 3252 /// Print the given affine expression with the symbol and dimension operands 3253 /// printed inline with the expression. 
3254 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 3255 ValueRange symOperands) override; 3256 3257 /// Print users of this operation or id of this operation if it has no result. 3258 void printUsersComment(Operation *op); 3259 3260 /// Print users of this block arg. 3261 void printUsersComment(BlockArgument arg); 3262 3263 /// Print the users of a value. 3264 void printValueUsers(Value value); 3265 3266 /// Print either the ids of the result values or the id of the operation if 3267 /// the operation has no results. 3268 void printUserIDs(Operation *user, bool prefixComma = false); 3269 3270 private: 3271 /// This class represents a resource builder implementation for the MLIR 3272 /// textual assembly format. 3273 class ResourceBuilder : public AsmResourceBuilder { 3274 public: 3275 using ValueFn = function_ref<void(raw_ostream &)>; 3276 using PrintFn = function_ref<void(StringRef, ValueFn)>; 3277 3278 ResourceBuilder(PrintFn printFn) : printFn(printFn) {} 3279 ~ResourceBuilder() override = default; 3280 3281 void buildBool(StringRef key, bool data) final { 3282 printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); }); 3283 } 3284 3285 void buildString(StringRef key, StringRef data) final { 3286 printFn(key, [&](raw_ostream &os) { 3287 os << "\""; 3288 llvm::printEscapedString(data, os); 3289 os << "\""; 3290 }); 3291 } 3292 3293 void buildBlob(StringRef key, ArrayRef<char> data, 3294 uint32_t dataAlignment) final { 3295 printFn(key, [&](raw_ostream &os) { 3296 // Store the blob in a hex string containing the alignment and the data. 
3297 llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); 3298 os << "\"0x" 3299 << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE), 3300 sizeof(dataAlignment))) 3301 << llvm::toHex(StringRef(data.data(), data.size())) << "\""; 3302 }); 3303 } 3304 3305 private: 3306 PrintFn printFn; 3307 }; 3308 3309 /// Print the metadata dictionary for the file, eliding it if it is empty. 3310 void printFileMetadataDictionary(Operation *op); 3311 3312 /// Print the resource sections for the file metadata dictionary. 3313 /// `checkAddMetadataDict` is used to indicate that metadata is going to be 3314 /// added, and the file metadata dictionary should be started if it hasn't 3315 /// yet. 3316 void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, 3317 Operation *op); 3318 3319 // Contains the stack of default dialects to use when printing regions. 3320 // A new dialect is pushed to the stack before parsing regions nested under an 3321 // operation implementing `OpAsmOpInterface`, and popped when done. At the 3322 // top-level we start with "builtin" as the default, so that the top-level 3323 // `module` operation prints as-is. 3324 SmallVector<StringRef> defaultDialectStack{"builtin"}; 3325 3326 /// The number of spaces used for indenting nested operations. 3327 const static unsigned indentWidth = 2; 3328 3329 // This is the current indentation level for nested structures. 3330 unsigned currentIndent = 0; 3331 }; 3332 } // namespace 3333 3334 void OperationPrinter::printTopLevelOperation(Operation *op) { 3335 // Output the aliases at the top level that can't be deferred. 3336 state.getAliasState().printNonDeferredAliases(*this, newLine); 3337 3338 // Print the module. 3339 printFullOpWithIndentAndLoc(op); 3340 os << newLine; 3341 3342 // Output the aliases at the top level that can be deferred. 3343 state.getAliasState().printDeferredAliases(*this, newLine); 3344 3345 // Output any file level metadata. 
3346 printFileMetadataDictionary(op); 3347 } 3348 3349 void OperationPrinter::printFileMetadataDictionary(Operation *op) { 3350 bool sawMetadataEntry = false; 3351 auto checkAddMetadataDict = [&] { 3352 if (!std::exchange(sawMetadataEntry, true)) 3353 os << newLine << "{-#" << newLine; 3354 }; 3355 3356 // Add the various types of metadata. 3357 printResourceFileMetadata(checkAddMetadataDict, op); 3358 3359 // If the file dictionary exists, close it. 3360 if (sawMetadataEntry) 3361 os << newLine << "#-}" << newLine; 3362 } 3363 3364 void OperationPrinter::printResourceFileMetadata( 3365 function_ref<void()> checkAddMetadataDict, Operation *op) { 3366 // Functor used to add data entries to the file metadata dictionary. 3367 bool hadResource = false; 3368 bool needResourceComma = false; 3369 bool needEntryComma = false; 3370 auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, 3371 auto &&...providerArgs) { 3372 bool hadEntry = false; 3373 auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { 3374 checkAddMetadataDict(); 3375 3376 auto printFormatting = [&]() { 3377 // Emit the top-level resource entry if we haven't yet. 3378 if (!std::exchange(hadResource, true)) { 3379 if (needResourceComma) 3380 os << "," << newLine; 3381 os << " " << dictName << "_resources: {" << newLine; 3382 } 3383 // Emit the parent resource entry if we haven't yet. 
3384 if (!std::exchange(hadEntry, true)) { 3385 if (needEntryComma) 3386 os << "," << newLine; 3387 os << " " << name << ": {" << newLine; 3388 } else { 3389 os << "," << newLine; 3390 } 3391 }; 3392 3393 std::optional<uint64_t> charLimit = 3394 printerFlags.getLargeResourceStringLimit(); 3395 if (charLimit.has_value()) { 3396 std::string resourceStr; 3397 llvm::raw_string_ostream ss(resourceStr); 3398 valueFn(ss); 3399 3400 // Only print entry if it's string is small enough 3401 if (resourceStr.size() > charLimit.value()) 3402 return; 3403 3404 printFormatting(); 3405 os << " " << key << ": " << resourceStr; 3406 } else { 3407 printFormatting(); 3408 os << " " << key << ": "; 3409 valueFn(os); 3410 } 3411 }; 3412 ResourceBuilder entryBuilder(printFn); 3413 provider.buildResources(op, providerArgs..., entryBuilder); 3414 3415 needEntryComma |= hadEntry; 3416 if (hadEntry) 3417 os << newLine << " }"; 3418 }; 3419 3420 // Print the `dialect_resources` section if we have any dialects with 3421 // resources. 3422 for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { 3423 auto &dialectResources = state.getDialectResources(); 3424 StringRef name = interface.getDialect()->getNamespace(); 3425 auto it = dialectResources.find(interface.getDialect()); 3426 if (it != dialectResources.end()) 3427 processProvider("dialect", name, interface, it->second); 3428 else 3429 processProvider("dialect", name, interface, 3430 SetVector<AsmDialectResourceHandle>()); 3431 } 3432 if (hadResource) 3433 os << newLine << " }"; 3434 3435 // Print the `external_resources` section if we have any external clients with 3436 // resources. 
3437 needEntryComma = false; 3438 needResourceComma = hadResource; 3439 hadResource = false; 3440 for (const auto &printer : state.getResourcePrinters()) 3441 processProvider("external", printer.getName(), printer); 3442 if (hadResource) 3443 os << newLine << " }"; 3444 } 3445 3446 /// Print a block argument in the usual format of: 3447 /// %ssaName : type {attr1=42} loc("here") 3448 /// where location printing is controlled by the standard internal option. 3449 /// You may pass omitType=true to not print a type, and pass an empty 3450 /// attribute list if you don't care for attributes. 3451 void OperationPrinter::printRegionArgument(BlockArgument arg, 3452 ArrayRef<NamedAttribute> argAttrs, 3453 bool omitType) { 3454 printOperand(arg); 3455 if (!omitType) { 3456 os << ": "; 3457 printType(arg.getType()); 3458 } 3459 printOptionalAttrDict(argAttrs); 3460 // TODO: We should allow location aliases on block arguments. 3461 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 3462 } 3463 3464 void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { 3465 // Track the location of this operation. 3466 state.registerOperationLocation(op, newLine.curLine, currentIndent); 3467 3468 os.indent(currentIndent); 3469 printFullOp(op); 3470 printTrailingLocation(op->getLoc()); 3471 if (printerFlags.shouldPrintValueUsers()) 3472 printUsersComment(op); 3473 } 3474 3475 void OperationPrinter::printFullOp(Operation *op) { 3476 if (size_t numResults = op->getNumResults()) { 3477 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 3478 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 3479 if (resultCount > 1) 3480 os << ':' << resultCount; 3481 }; 3482 3483 // Check to see if this operation has multiple result groups. 3484 ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op); 3485 if (!resultGroups.empty()) { 3486 // Interleave the groups excluding the last one, this one will be handled 3487 // separately. 
3488 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 3489 printResultGroup(resultGroups[i], 3490 resultGroups[i + 1] - resultGroups[i]); 3491 }); 3492 os << ", "; 3493 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 3494 3495 } else { 3496 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 3497 } 3498 3499 os << " = "; 3500 } 3501 3502 printCustomOrGenericOp(op); 3503 } 3504 3505 void OperationPrinter::printUsersComment(Operation *op) { 3506 unsigned numResults = op->getNumResults(); 3507 if (!numResults && op->getNumOperands()) { 3508 os << " // id: "; 3509 printOperationID(op); 3510 } else if (numResults && op->use_empty()) { 3511 os << " // unused"; 3512 } else if (numResults && !op->use_empty()) { 3513 // Print "user" if the operation has one result used to compute one other 3514 // result, or is used in one operation with no result. 3515 unsigned usedInNResults = 0; 3516 unsigned usedInNOperations = 0; 3517 SmallPtrSet<Operation *, 1> userSet; 3518 for (Operation *user : op->getUsers()) { 3519 if (userSet.insert(user).second) { 3520 ++usedInNOperations; 3521 usedInNResults += user->getNumResults(); 3522 } 3523 } 3524 3525 // We already know that users is not empty. 3526 bool exactlyOneUniqueUse = 3527 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; 3528 os << " // " << (exactlyOneUniqueUse ? 
"user" : "users") << ": "; 3529 bool shouldPrintBrackets = numResults > 1; 3530 auto printOpResult = [&](OpResult opResult) { 3531 if (shouldPrintBrackets) 3532 os << "("; 3533 printValueUsers(opResult); 3534 if (shouldPrintBrackets) 3535 os << ")"; 3536 }; 3537 3538 interleaveComma(op->getResults(), printOpResult); 3539 } 3540 } 3541 3542 void OperationPrinter::printUsersComment(BlockArgument arg) { 3543 os << "// "; 3544 printValueID(arg); 3545 if (arg.use_empty()) { 3546 os << " is unused"; 3547 } else { 3548 os << " is used by "; 3549 printValueUsers(arg); 3550 } 3551 os << newLine; 3552 } 3553 3554 void OperationPrinter::printValueUsers(Value value) { 3555 if (value.use_empty()) 3556 os << "unused"; 3557 3558 // One value might be used as the operand of an operation more than once. 3559 // Only print the operations results once in that case. 3560 SmallPtrSet<Operation *, 1> userSet; 3561 for (auto [index, user] : enumerate(value.getUsers())) { 3562 if (userSet.insert(user).second) 3563 printUserIDs(user, index); 3564 } 3565 } 3566 3567 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { 3568 if (prefixComma) 3569 os << ", "; 3570 3571 if (!user->getNumResults()) { 3572 printOperationID(user); 3573 } else { 3574 interleaveComma(user->getResults(), 3575 [this](Value result) { printValueID(result); }); 3576 } 3577 } 3578 3579 void OperationPrinter::printCustomOrGenericOp(Operation *op) { 3580 // If requested, always print the generic form. 3581 if (!printerFlags.shouldPrintGenericOpForm()) { 3582 // Check to see if this is a known operation. If so, use the registered 3583 // custom printer hook. 3584 if (auto opInfo = op->getRegisteredInfo()) { 3585 opInfo->printAssembly(op, *this, defaultDialectStack.back()); 3586 return; 3587 } 3588 // Otherwise try to dispatch to the dialect, if available. 3589 if (Dialect *dialect = op->getDialect()) { 3590 if (auto opPrinter = dialect->getOperationPrinter(op)) { 3591 // Print the op name first. 
3592 StringRef name = op->getName().getStringRef(); 3593 // Only drop the default dialect prefix when it cannot lead to 3594 // ambiguities. 3595 if (name.count('.') == 1) 3596 name.consume_front((defaultDialectStack.back() + ".").str()); 3597 os << name; 3598 3599 // Print the rest of the op now. 3600 opPrinter(op, *this); 3601 return; 3602 } 3603 } 3604 } 3605 3606 // Otherwise print with the generic assembly form. 3607 printGenericOp(op, /*printOpName=*/true); 3608 } 3609 3610 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { 3611 if (printOpName) 3612 printEscapedString(op->getName().getStringRef()); 3613 os << '('; 3614 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 3615 os << ')'; 3616 3617 // For terminators, print the list of successors and their operands. 3618 if (op->getNumSuccessors() != 0) { 3619 os << '['; 3620 interleaveComma(op->getSuccessors(), 3621 [&](Block *successor) { printBlockName(successor); }); 3622 os << ']'; 3623 } 3624 3625 // Print the properties. 3626 if (Attribute prop = op->getPropertiesAsAttribute()) { 3627 os << " <"; 3628 Impl::printAttribute(prop); 3629 os << '>'; 3630 } 3631 3632 // Print regions. 3633 if (op->getNumRegions() != 0) { 3634 os << " ("; 3635 interleaveComma(op->getRegions(), [&](Region ®ion) { 3636 printRegion(region, /*printEntryBlockArgs=*/true, 3637 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 3638 }); 3639 os << ')'; 3640 } 3641 3642 printOptionalAttrDict(op->getPropertiesStorage() 3643 ? llvm::to_vector(op->getDiscardableAttrs()) 3644 : op->getAttrs()); 3645 3646 // Print the type signature of the operation. 
3647 os << " : "; 3648 printFunctionalType(op); 3649 } 3650 3651 void OperationPrinter::printBlockName(Block *block) { 3652 os << state.getSSANameState().getBlockInfo(block).name; 3653 } 3654 3655 void OperationPrinter::print(Block *block, bool printBlockArgs, 3656 bool printBlockTerminator) { 3657 // Print the block label and argument list if requested. 3658 if (printBlockArgs) { 3659 os.indent(currentIndent); 3660 printBlockName(block); 3661 3662 // Print the argument list if non-empty. 3663 if (!block->args_empty()) { 3664 os << '('; 3665 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 3666 printValueID(arg); 3667 os << ": "; 3668 printType(arg.getType()); 3669 // TODO: We should allow location aliases on block arguments. 3670 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 3671 }); 3672 os << ')'; 3673 } 3674 os << ':'; 3675 3676 // Print out some context information about the predecessors of this block. 3677 if (!block->getParent()) { 3678 os << " // block is not in a region!"; 3679 } else if (block->hasNoPredecessors()) { 3680 if (!block->isEntryBlock()) 3681 os << " // no predecessors"; 3682 } else if (auto *pred = block->getSinglePredecessor()) { 3683 os << " // pred: "; 3684 printBlockName(pred); 3685 } else { 3686 // We want to print the predecessors in a stable order, not in 3687 // whatever order the use-list is in, so gather and sort them. 
3688 SmallVector<BlockInfo, 4> predIDs; 3689 for (auto *pred : block->getPredecessors()) 3690 predIDs.push_back(state.getSSANameState().getBlockInfo(pred)); 3691 llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { 3692 return lhs.ordering < rhs.ordering; 3693 }); 3694 3695 os << " // " << predIDs.size() << " preds: "; 3696 3697 interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); 3698 } 3699 os << newLine; 3700 } 3701 3702 currentIndent += indentWidth; 3703 3704 if (printerFlags.shouldPrintValueUsers()) { 3705 for (BlockArgument arg : block->getArguments()) { 3706 os.indent(currentIndent); 3707 printUsersComment(arg); 3708 } 3709 } 3710 3711 bool hasTerminator = 3712 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 3713 auto range = llvm::make_range( 3714 block->begin(), 3715 std::prev(block->end(), 3716 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 3717 for (auto &op : range) { 3718 printFullOpWithIndentAndLoc(&op); 3719 os << newLine; 3720 } 3721 currentIndent -= indentWidth; 3722 } 3723 3724 void OperationPrinter::printValueID(Value value, bool printResultNo, 3725 raw_ostream *streamOverride) const { 3726 state.getSSANameState().printValueID(value, printResultNo, 3727 streamOverride ? *streamOverride : os); 3728 } 3729 3730 void OperationPrinter::printOperationID(Operation *op, 3731 raw_ostream *streamOverride) const { 3732 state.getSSANameState().printOperationID(op, streamOverride ? 
*streamOverride 3733 : os); 3734 } 3735 3736 void OperationPrinter::printSuccessor(Block *successor) { 3737 printBlockName(successor); 3738 } 3739 3740 void OperationPrinter::printSuccessorAndUseList(Block *successor, 3741 ValueRange succOperands) { 3742 printBlockName(successor); 3743 if (succOperands.empty()) 3744 return; 3745 3746 os << '('; 3747 interleaveComma(succOperands, 3748 [this](Value operand) { printValueID(operand); }); 3749 os << " : "; 3750 interleaveComma(succOperands, 3751 [this](Value operand) { printType(operand.getType()); }); 3752 os << ')'; 3753 } 3754 3755 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 3756 bool printBlockTerminators, 3757 bool printEmptyBlock) { 3758 if (printerFlags.shouldSkipRegions()) { 3759 os << "{...}"; 3760 return; 3761 } 3762 os << "{" << newLine; 3763 if (!region.empty()) { 3764 auto restoreDefaultDialect = 3765 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); 3766 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) 3767 defaultDialectStack.push_back(iface.getDefaultDialect()); 3768 else 3769 defaultDialectStack.push_back(""); 3770 3771 auto *entryBlock = ®ion.front(); 3772 // Force printing the block header if printEmptyBlock is set and the block 3773 // is empty or if printEntryBlockArgs is set and there are arguments to 3774 // print. 
3775 bool shouldAlwaysPrintBlockHeader = 3776 (printEmptyBlock && entryBlock->empty()) || 3777 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 3778 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 3779 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 3780 print(&b); 3781 } 3782 os.indent(currentIndent) << "}"; 3783 } 3784 3785 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 3786 ValueRange operands) { 3787 if (!mapAttr) { 3788 os << "<<NULL AFFINE MAP>>"; 3789 return; 3790 } 3791 AffineMap map = mapAttr.getValue(); 3792 unsigned numDims = map.getNumDims(); 3793 auto printValueName = [&](unsigned pos, bool isSymbol) { 3794 unsigned index = isSymbol ? numDims + pos : pos; 3795 assert(index < operands.size()); 3796 if (isSymbol) 3797 os << "symbol("; 3798 printValueID(operands[index]); 3799 if (isSymbol) 3800 os << ')'; 3801 }; 3802 3803 interleaveComma(map.getResults(), [&](AffineExpr expr) { 3804 printAffineExpr(expr, printValueName); 3805 }); 3806 } 3807 3808 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 3809 ValueRange dimOperands, 3810 ValueRange symOperands) { 3811 auto printValueName = [&](unsigned pos, bool isSymbol) { 3812 if (!isSymbol) 3813 return printValueID(dimOperands[pos]); 3814 os << "symbol("; 3815 printValueID(symOperands[pos]); 3816 os << ')'; 3817 }; 3818 printAffineExpr(expr, printValueName); 3819 } 3820 3821 //===----------------------------------------------------------------------===// 3822 // print and dump methods 3823 //===----------------------------------------------------------------------===// 3824 3825 void Attribute::print(raw_ostream &os, bool elideType) const { 3826 if (!*this) { 3827 os << "<<NULL ATTRIBUTE>>"; 3828 return; 3829 } 3830 3831 AsmState state(getContext()); 3832 print(os, state, elideType); 3833 } 3834 void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { 3835 using AttrTypeElision = 
AsmPrinter::Impl::AttrTypeElision; 3836 AsmPrinter::Impl(os, state.getImpl()) 3837 .printAttribute(*this, elideType ? AttrTypeElision::Must 3838 : AttrTypeElision::Never); 3839 } 3840 3841 void Attribute::dump() const { 3842 print(llvm::errs()); 3843 llvm::errs() << "\n"; 3844 } 3845 3846 void Attribute::printStripped(raw_ostream &os, AsmState &state) const { 3847 if (!*this) { 3848 os << "<<NULL ATTRIBUTE>>"; 3849 return; 3850 } 3851 3852 AsmPrinter::Impl subPrinter(os, state.getImpl()); 3853 if (succeeded(subPrinter.printAlias(*this))) 3854 return; 3855 3856 auto &dialect = this->getDialect(); 3857 uint64_t posPrior = os.tell(); 3858 DialectAsmPrinter printer(subPrinter); 3859 dialect.printAttribute(*this, printer); 3860 if (posPrior != os.tell()) 3861 return; 3862 3863 // Fallback to printing with prefix if the above failed to write anything 3864 // to the output stream. 3865 print(os, state); 3866 } 3867 void Attribute::printStripped(raw_ostream &os) const { 3868 if (!*this) { 3869 os << "<<NULL ATTRIBUTE>>"; 3870 return; 3871 } 3872 3873 AsmState state(getContext()); 3874 printStripped(os, state); 3875 } 3876 3877 void Type::print(raw_ostream &os) const { 3878 if (!*this) { 3879 os << "<<NULL TYPE>>"; 3880 return; 3881 } 3882 3883 AsmState state(getContext()); 3884 print(os, state); 3885 } 3886 void Type::print(raw_ostream &os, AsmState &state) const { 3887 AsmPrinter::Impl(os, state.getImpl()).printType(*this); 3888 } 3889 3890 void Type::dump() const { 3891 print(llvm::errs()); 3892 llvm::errs() << "\n"; 3893 } 3894 3895 void AffineMap::dump() const { 3896 print(llvm::errs()); 3897 llvm::errs() << "\n"; 3898 } 3899 3900 void IntegerSet::dump() const { 3901 print(llvm::errs()); 3902 llvm::errs() << "\n"; 3903 } 3904 3905 void AffineExpr::print(raw_ostream &os) const { 3906 if (!expr) { 3907 os << "<<NULL AFFINE EXPR>>"; 3908 return; 3909 } 3910 AsmState state(getContext()); 3911 AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this); 3912 } 3913 3914 
void AffineExpr::dump() const { 3915 print(llvm::errs()); 3916 llvm::errs() << "\n"; 3917 } 3918 3919 void AffineMap::print(raw_ostream &os) const { 3920 if (!map) { 3921 os << "<<NULL AFFINE MAP>>"; 3922 return; 3923 } 3924 AsmState state(getContext()); 3925 AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this); 3926 } 3927 3928 void IntegerSet::print(raw_ostream &os) const { 3929 AsmState state(getContext()); 3930 AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this); 3931 } 3932 3933 void Value::print(raw_ostream &os) const { print(os, OpPrintingFlags()); } 3934 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { 3935 if (!impl) { 3936 os << "<<NULL VALUE>>"; 3937 return; 3938 } 3939 3940 if (auto *op = getDefiningOp()) 3941 return op->print(os, flags); 3942 // TODO: Improve BlockArgument print'ing. 3943 BlockArgument arg = llvm::cast<BlockArgument>(*this); 3944 os << "<block argument> of type '" << arg.getType() 3945 << "' at index: " << arg.getArgNumber(); 3946 } 3947 void Value::print(raw_ostream &os, AsmState &state) const { 3948 if (!impl) { 3949 os << "<<NULL VALUE>>"; 3950 return; 3951 } 3952 3953 if (auto *op = getDefiningOp()) 3954 return op->print(os, state); 3955 3956 // TODO: Improve BlockArgument print'ing. 3957 BlockArgument arg = llvm::cast<BlockArgument>(*this); 3958 os << "<block argument> of type '" << arg.getType() 3959 << "' at index: " << arg.getArgNumber(); 3960 } 3961 3962 void Value::dump() const { 3963 print(llvm::errs()); 3964 llvm::errs() << "\n"; 3965 } 3966 3967 void Value::printAsOperand(raw_ostream &os, AsmState &state) const { 3968 // TODO: This doesn't necessarily capture all potential cases. 3969 // Currently, region arguments can be shadowed when printing the main 3970 // operation. If the IR hasn't been printed, this will produce the old SSA 3971 // name and not the shadowed name. 
3972 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 3973 os); 3974 } 3975 3976 static Operation *findParent(Operation *op, bool shouldUseLocalScope) { 3977 do { 3978 // If we are printing local scope, stop at the first operation that is 3979 // isolated from above. 3980 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 3981 break; 3982 3983 // Otherwise, traverse up to the next parent. 3984 Operation *parentOp = op->getParentOp(); 3985 if (!parentOp) 3986 break; 3987 op = parentOp; 3988 } while (true); 3989 return op; 3990 } 3991 3992 void Value::printAsOperand(raw_ostream &os, 3993 const OpPrintingFlags &flags) const { 3994 Operation *op; 3995 if (auto result = llvm::dyn_cast<OpResult>(*this)) { 3996 op = result.getOwner(); 3997 } else { 3998 op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp(); 3999 if (!op) { 4000 os << "<<UNKNOWN SSA VALUE>>"; 4001 return; 4002 } 4003 } 4004 op = findParent(op, flags.shouldUseLocalScope()); 4005 AsmState state(op, flags); 4006 printAsOperand(os, state); 4007 } 4008 4009 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { 4010 // Find the operation to number from based upon the provided flags. 
4011 Operation *op = findParent(this, printerFlags.shouldUseLocalScope()); 4012 AsmState state(op, printerFlags); 4013 print(os, state); 4014 } 4015 void Operation::print(raw_ostream &os, AsmState &state) { 4016 OperationPrinter printer(os, state.getImpl()); 4017 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { 4018 state.getImpl().initializeAliases(this); 4019 printer.printTopLevelOperation(this); 4020 } else { 4021 printer.printFullOpWithIndentAndLoc(this); 4022 } 4023 } 4024 4025 void Operation::dump() { 4026 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 4027 llvm::errs() << "\n"; 4028 } 4029 4030 void Operation::dumpPretty() { 4031 print(llvm::errs(), OpPrintingFlags().useLocalScope().assumeVerified()); 4032 llvm::errs() << "\n"; 4033 } 4034 4035 void Block::print(raw_ostream &os) { 4036 Operation *parentOp = getParentOp(); 4037 if (!parentOp) { 4038 os << "<<UNLINKED BLOCK>>\n"; 4039 return; 4040 } 4041 // Get the top-level op. 4042 while (auto *nextOp = parentOp->getParentOp()) 4043 parentOp = nextOp; 4044 4045 AsmState state(parentOp); 4046 print(os, state); 4047 } 4048 void Block::print(raw_ostream &os, AsmState &state) { 4049 OperationPrinter(os, state.getImpl()).print(this); 4050 } 4051 4052 void Block::dump() { print(llvm::errs()); } 4053 4054 /// Print out the name of the block without printing its body. 
4055 void Block::printAsOperand(raw_ostream &os, bool printType) { 4056 Operation *parentOp = getParentOp(); 4057 if (!parentOp) { 4058 os << "<<UNLINKED BLOCK>>\n"; 4059 return; 4060 } 4061 AsmState state(parentOp); 4062 printAsOperand(os, state); 4063 } 4064 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 4065 OperationPrinter printer(os, state.getImpl()); 4066 printer.printBlockName(this); 4067 } 4068 4069 raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { 4070 block.print(os); 4071 return os; 4072 } 4073 4074 //===--------------------------------------------------------------------===// 4075 // Custom printers 4076 //===--------------------------------------------------------------------===// 4077 namespace mlir { 4078 4079 void printDimensionList(OpAsmPrinter &printer, Operation *op, 4080 ArrayRef<int64_t> dimensions) { 4081 if (dimensions.empty()) 4082 printer << "["; 4083 printer.printDimensionList(dimensions); 4084 if (dimensions.empty()) 4085 printer << "]"; 4086 } 4087 4088 ParseResult parseDimensionList(OpAsmParser &parser, 4089 DenseI64ArrayAttr &dimensions) { 4090 // Empty list case denoted by "[]". 4091 if (succeeded(parser.parseOptionalLSquare())) { 4092 if (failed(parser.parseRSquare())) { 4093 return parser.emitError(parser.getCurrentLocation()) 4094 << "Failed parsing dimension list."; 4095 } 4096 dimensions = 4097 DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>()); 4098 return success(); 4099 } 4100 4101 // Non-empty list case. 4102 SmallVector<int64_t> shapeArr; 4103 if (failed(parser.parseDimensionList(shapeArr, true, false))) { 4104 return parser.emitError(parser.getCurrentLocation()) 4105 << "Failed parsing dimension list."; 4106 } 4107 if (shapeArr.empty()) { 4108 return parser.emitError(parser.getCurrentLocation()) 4109 << "Failed parsing dimension list. Did you mean an empty list? 
It " 4110 "must be denoted by \"[]\"."; 4111 } 4112 dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); 4113 return success(); 4114 } 4115 4116 } // namespace mlir 4117