xref: /llvm-project/mlir/lib/IR/AsmPrinter.cpp (revision 3c64f86314fbf9a3cd578419f16e621a4de57eaa)
1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/BuiltinTypeInterfaces.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/DialectImplementation.h"
25 #include "mlir/IR/DialectResourceBlobManager.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpImplementation.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/Verifier.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/ADT/MapVector.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/ScopedHashTable.h"
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/ADT/SmallString.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/Endian.h"
46 #include "llvm/Support/ManagedStatic.h"
47 #include "llvm/Support/Regex.h"
48 #include "llvm/Support/SaveAndRestore.h"
49 #include "llvm/Support/Threading.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include <type_traits>
52 
53 #include <optional>
54 #include <tuple>
55 
56 using namespace mlir;
57 using namespace mlir::detail;
58 
59 #define DEBUG_TYPE "mlir-asm-printer"
60 
61 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
62 
63 void OperationName::dump() const { print(llvm::errs()); }
64 
65 //===--------------------------------------------------------------------===//
66 // AsmParser
67 //===--------------------------------------------------------------------===//
68 
69 AsmParser::~AsmParser() = default;
70 DialectAsmParser::~DialectAsmParser() = default;
71 OpAsmParser::~OpAsmParser() = default;
72 
73 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
74 
75 /// Parse a type list.
76 /// This is out-of-line to work-around
77 /// https://github.com/llvm/llvm-project/issues/62918
78 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
79   return parseCommaSeparatedList(
80       [&]() { return parseType(result.emplace_back()); });
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // DialectAsmPrinter
85 //===----------------------------------------------------------------------===//
86 
87 DialectAsmPrinter::~DialectAsmPrinter() = default;
88 
89 //===----------------------------------------------------------------------===//
90 // OpAsmPrinter
91 //===----------------------------------------------------------------------===//
92 
93 OpAsmPrinter::~OpAsmPrinter() = default;
94 
95 void OpAsmPrinter::printFunctionalType(Operation *op) {
96   auto &os = getStream();
97   os << '(';
98   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
99     // Print the types of null values as <<NULL TYPE>>.
100     *this << (operand ? operand.getType() : Type());
101   });
102   os << ") -> ";
103 
104   // Print the result list.  We don't parenthesize single result types unless
105   // it is a function (avoiding a grammar ambiguity).
106   bool wrapped = op->getNumResults() != 1;
107   if (!wrapped && op->getResult(0).getType() &&
108       llvm::isa<FunctionType>(op->getResult(0).getType()))
109     wrapped = true;
110 
111   if (wrapped)
112     os << '(';
113 
114   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
115     // Print the types of null values as <<NULL TYPE>>.
116     *this << (result ? result.getType() : Type());
117   });
118 
119   if (wrapped)
120     os << ')';
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Operation OpAsm interface.
125 //===----------------------------------------------------------------------===//
126 
127 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
128 #include "mlir/IR/OpAsmOpInterface.cpp.inc"
129 #include "mlir/IR/OpAsmTypeInterface.cpp.inc"
130 
131 LogicalResult
132 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
133   return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
134                            << "' for dialect '" << getDialect()->getNamespace()
135                            << "'";
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // OpPrintingFlags
140 //===----------------------------------------------------------------------===//
141 
142 namespace {
143 /// This struct contains command line options that can be used to initialize
144 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
145 /// for global command line options.
146 struct AsmPrinterOptions {
147   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
148       "mlir-print-elementsattrs-with-hex-if-larger",
149       llvm::cl::desc(
150           "Print DenseElementsAttrs with a hex string that have "
151           "more elements than the given upper limit (use -1 to disable)")};
152 
153   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
154       "mlir-elide-elementsattrs-if-larger",
155       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
156                      "more elements than the given upper limit")};
157 
158   llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
159       "mlir-elide-resource-strings-if-larger",
160       llvm::cl::desc(
161           "Elide printing value of resources if string is too long in chars.")};
162 
163   llvm::cl::opt<bool> printDebugInfoOpt{
164       "mlir-print-debuginfo", llvm::cl::init(false),
165       llvm::cl::desc("Print debug info in MLIR output")};
166 
167   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
168       "mlir-pretty-debuginfo", llvm::cl::init(false),
169       llvm::cl::desc("Print pretty debug info in MLIR output")};
170 
171   // Use the generic op output form in the operation printer even if the custom
172   // form is defined.
173   llvm::cl::opt<bool> printGenericOpFormOpt{
174       "mlir-print-op-generic", llvm::cl::init(false),
175       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
176 
177   llvm::cl::opt<bool> assumeVerifiedOpt{
178       "mlir-print-assume-verified", llvm::cl::init(false),
179       llvm::cl::desc("Skip op verification when using custom printers"),
180       llvm::cl::Hidden};
181 
182   llvm::cl::opt<bool> printLocalScopeOpt{
183       "mlir-print-local-scope", llvm::cl::init(false),
184       llvm::cl::desc("Print with local scope and inline information (eliding "
185                      "aliases for attributes, types, and locations")};
186 
187   llvm::cl::opt<bool> skipRegionsOpt{
188       "mlir-print-skip-regions", llvm::cl::init(false),
189       llvm::cl::desc("Skip regions when printing ops.")};
190 
191   llvm::cl::opt<bool> printValueUsers{
192       "mlir-print-value-users", llvm::cl::init(false),
193       llvm::cl::desc(
194           "Print users of operation results and block arguments as a comment")};
195 
196   llvm::cl::opt<bool> printUniqueSSAIDs{
197       "mlir-print-unique-ssa-ids", llvm::cl::init(false),
198       llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
199                      "and naming conflicts across all regions")};
200 
201   llvm::cl::opt<bool> useNameLocAsPrefix{
202       "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
203       llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
204 };
205 } // namespace
206 
207 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
208 
209 /// Register a set of useful command-line options that can be used to configure
210 /// various flags within the AsmPrinter.
211 void mlir::registerAsmPrinterCLOptions() {
212   // Make sure that the options struct has been initialized.
213   *clOptions;
214 }
215 
216 /// Initialize the printing flags with default supplied by the cl::opts above.
217 OpPrintingFlags::OpPrintingFlags()
218     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
219       printGenericOpFormFlag(false), skipRegionsFlag(false),
220       assumeVerifiedFlag(false), printLocalScope(false),
221       printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
222       useNameLocAsPrefix(false) {
223   // Initialize based upon command line options, if they are available.
224   if (!clOptions.isConstructed())
225     return;
226   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
227     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
228   if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences())
229     elementsAttrHexElementLimit =
230         clOptions->printElementsAttrWithHexIfLarger.getValue();
231   if (clOptions->elideResourceStringsIfLarger.getNumOccurrences())
232     resourceStringCharLimit = clOptions->elideResourceStringsIfLarger;
233   printDebugInfoFlag = clOptions->printDebugInfoOpt;
234   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
235   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
236   assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
237   printLocalScope = clOptions->printLocalScopeOpt;
238   skipRegionsFlag = clOptions->skipRegionsOpt;
239   printValueUsersFlag = clOptions->printValueUsers;
240   printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
241   useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
242 }
243 
244 /// Enable the elision of large elements attributes, by printing a '...'
245 /// instead of the element data, when the number of elements is greater than
246 /// `largeElementLimit`. Note: The IR generated with this option is not
247 /// parsable.
248 OpPrintingFlags &
249 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
250   elementsAttrElementLimit = largeElementLimit;
251   return *this;
252 }
253 
254 OpPrintingFlags &
255 OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) {
256   elementsAttrHexElementLimit = largeElementLimit;
257   return *this;
258 }
259 
260 OpPrintingFlags &
261 OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) {
262   resourceStringCharLimit = largeResourceLimit;
263   return *this;
264 }
265 
266 /// Enable printing of debug information. If 'prettyForm' is set to true,
267 /// debug information is printed in a more readable 'pretty' form.
268 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
269                                                   bool prettyForm) {
270   printDebugInfoFlag = enable;
271   printDebugInfoPrettyFormFlag = prettyForm;
272   return *this;
273 }
274 
275 /// Always print operations in the generic form.
276 OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
277   printGenericOpFormFlag = enable;
278   return *this;
279 }
280 
281 /// Always skip Regions.
282 OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
283   skipRegionsFlag = skip;
284   return *this;
285 }
286 
287 /// Do not verify the operation when using custom operation printers.
288 OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) {
289   assumeVerifiedFlag = enable;
290   return *this;
291 }
292 
293 /// Use local scope when printing the operation. This allows for using the
294 /// printer in a more localized and thread-safe setting, but may not necessarily
295 /// be identical of what the IR will look like when dumping the full module.
296 OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) {
297   printLocalScope = enable;
298   return *this;
299 }
300 
301 /// Print users of values as comments.
302 OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) {
303   printValueUsersFlag = enable;
304   return *this;
305 }
306 
307 /// Print unique SSA ID numbers for values, block arguments and naming conflicts
308 /// across all regions
309 OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) {
310   printUniqueSSAIDsFlag = enable;
311   return *this;
312 }
313 
314 /// Return if the given ElementsAttr should be elided.
315 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
316   return elementsAttrElementLimit &&
317          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
318          !llvm::isa<SplatElementsAttr>(attr);
319 }
320 
321 /// Return if the given ElementsAttr should be printed as hex string.
322 bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const {
323   // -1 is used to disable hex printing.
324   return (elementsAttrHexElementLimit != -1) &&
325          (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) &&
326          !llvm::isa<SplatElementsAttr>(attr);
327 }
328 
329 /// Return the size limit for printing large ElementsAttr.
330 std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
331   return elementsAttrElementLimit;
332 }
333 
334 /// Return the size limit for printing large ElementsAttr as hex string.
335 int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const {
336   return elementsAttrHexElementLimit;
337 }
338 
339 /// Return the size limit for printing large ElementsAttr.
340 std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const {
341   return resourceStringCharLimit;
342 }
343 
344 /// Return if debug information should be printed.
345 bool OpPrintingFlags::shouldPrintDebugInfo() const {
346   return printDebugInfoFlag;
347 }
348 
349 /// Return if debug information should be printed in the pretty form.
350 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
351   return printDebugInfoPrettyFormFlag;
352 }
353 
354 /// Return if operations should be printed in the generic form.
355 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
356   return printGenericOpFormFlag;
357 }
358 
359 /// Return if Region should be skipped.
360 bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
361 
362 /// Return if operation verification should be skipped.
363 bool OpPrintingFlags::shouldAssumeVerified() const {
364   return assumeVerifiedFlag;
365 }
366 
367 /// Return if the printer should use local scope when dumping the IR.
368 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
369 
370 /// Return if the printer should print users of values.
371 bool OpPrintingFlags::shouldPrintValueUsers() const {
372   return printValueUsersFlag;
373 }
374 
375 /// Return if the printer should use unique IDs.
376 bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
377   return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
378 }
379 
380 /// Return if the printer should use NameLocs as prefixes when printing SSA IDs.
381 bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
382   return useNameLocAsPrefix;
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // NewLineCounter
387 //===----------------------------------------------------------------------===//
388 
389 namespace {
390 /// This class is a simple formatter that emits a new line when inputted into a
391 /// stream, that enables counting the number of newlines emitted. This class
392 /// should be used whenever emitting newlines in the printer.
393 struct NewLineCounter {
394   unsigned curLine = 1;
395 };
396 
397 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
398   ++newLine.curLine;
399   return os << '\n';
400 }
401 } // namespace
402 
403 //===----------------------------------------------------------------------===//
404 // AsmPrinter::Impl
405 //===----------------------------------------------------------------------===//
406 
407 namespace mlir {
408 class AsmPrinter::Impl {
409 public:
410   Impl(raw_ostream &os, AsmStateImpl &state);
411   explicit Impl(Impl &other) : Impl(other.os, other.state) {}
412 
413   /// Returns the output stream of the printer.
414   raw_ostream &getStream() { return os; }
415 
416   template <typename Container, typename UnaryFunctor>
417   inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
418     llvm::interleaveComma(c, os, eachFn);
419   }
420 
421   /// This enum describes the different kinds of elision for the type of an
422   /// attribute when printing it.
423   enum class AttrTypeElision {
424     /// The type must not be elided,
425     Never,
426     /// The type may be elided when it matches the default used in the parser
427     /// (for example i64 is the default for integer attributes).
428     May,
429     /// The type must be elided.
430     Must
431   };
432 
433   /// Print the given attribute or an alias.
434   void printAttribute(Attribute attr,
435                       AttrTypeElision typeElision = AttrTypeElision::Never);
436   /// Print the given attribute without considering an alias.
437   void printAttributeImpl(Attribute attr,
438                           AttrTypeElision typeElision = AttrTypeElision::Never);
439 
440   /// Print the alias for the given attribute, return failure if no alias could
441   /// be printed.
442   LogicalResult printAlias(Attribute attr);
443 
444   /// Print the given type or an alias.
445   void printType(Type type);
446   /// Print the given type.
447   void printTypeImpl(Type type);
448 
449   /// Print the alias for the given type, return failure if no alias could
450   /// be printed.
451   LogicalResult printAlias(Type type);
452 
453   /// Print the given location to the stream. If `allowAlias` is true, this
454   /// allows for the internal location to use an attribute alias.
455   void printLocation(LocationAttr loc, bool allowAlias = false);
456 
457   /// Print a reference to the given resource that is owned by the given
458   /// dialect.
459   void printResourceHandle(const AsmDialectResourceHandle &resource);
460 
461   void printAffineMap(AffineMap map);
462   void
463   printAffineExpr(AffineExpr expr,
464                   function_ref<void(unsigned, bool)> printValueName = nullptr);
465   void printAffineConstraint(AffineExpr expr, bool isEq);
466   void printIntegerSet(IntegerSet set);
467 
468   LogicalResult pushCyclicPrinting(const void *opaquePointer);
469 
470   void popCyclicPrinting();
471 
472   void printDimensionList(ArrayRef<int64_t> shape);
473 
474 protected:
475   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
476                              ArrayRef<StringRef> elidedAttrs = {},
477                              bool withKeyword = false);
478   void printNamedAttribute(NamedAttribute attr);
479   void printTrailingLocation(Location loc, bool allowAlias = true);
480   void printLocationInternal(LocationAttr loc, bool pretty = false,
481                              bool isTopLevel = false);
482 
483   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
484   /// used instead of individual elements when the elements attr is large.
485   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
486 
487   /// Print a dense string elements attribute.
488   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
489 
490   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
491   /// used instead of individual elements when the elements attr is large.
492   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
493                                      bool allowHex);
494 
495   /// Print a dense array attribute.
496   void printDenseArrayAttr(DenseArrayAttr attr);
497 
498   void printDialectAttribute(Attribute attr);
499   void printDialectType(Type type);
500 
501   /// Print an escaped string, wrapped with "".
502   void printEscapedString(StringRef str);
503 
504   /// Print a hex string, wrapped with "".
505   void printHexString(StringRef str);
506   void printHexString(ArrayRef<char> data);
507 
508   /// This enum is used to represent the binding strength of the enclosing
509   /// context that an AffineExprStorage is being printed in, so we can
510   /// intelligently produce parens.
511   enum class BindingStrength {
512     Weak,   // + and -
513     Strong, // All other binary operators.
514   };
515   void printAffineExprInternal(
516       AffineExpr expr, BindingStrength enclosingTightness,
517       function_ref<void(unsigned, bool)> printValueName = nullptr);
518 
519   /// The output stream for the printer.
520   raw_ostream &os;
521 
522   /// An underlying assembly printer state.
523   AsmStateImpl &state;
524 
525   /// A set of flags to control the printer's behavior.
526   OpPrintingFlags printerFlags;
527 
528   /// A tracker for the number of new lines emitted during printing.
529   NewLineCounter newLine;
530 };
531 } // namespace mlir
532 
533 //===----------------------------------------------------------------------===//
534 // AliasInitializer
535 //===----------------------------------------------------------------------===//
536 
537 namespace {
538 /// This class represents a specific instance of a symbol Alias.
539 class SymbolAlias {
540 public:
541   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
542               bool isDeferrable)
543       : name(name), suffixIndex(suffixIndex), isType(isType),
544         isDeferrable(isDeferrable) {}
545 
546   /// Print this alias to the given stream.
547   void print(raw_ostream &os) const {
548     os << (isType ? "!" : "#") << name;
549     if (suffixIndex)
550       os << suffixIndex;
551   }
552 
553   /// Returns true if this is a type alias.
554   bool isTypeAlias() const { return isType; }
555 
556   /// Returns true if this alias supports deferred resolution when parsing.
557   bool canBeDeferred() const { return isDeferrable; }
558 
559 private:
560   /// The main name of the alias.
561   StringRef name;
562   /// The suffix index of the alias.
563   uint32_t suffixIndex : 30;
564   /// A flag indicating whether this alias is for a type.
565   bool isType : 1;
566   /// A flag indicating whether this alias may be deferred or not.
567   bool isDeferrable : 1;
568 
569 public:
570   /// Used to avoid printing incomplete aliases for recursive types.
571   bool isPrinted = false;
572 };
573 
574 /// This class represents a utility that initializes the set of attribute and
575 /// type aliases, without the need to store the extra information within the
576 /// main AliasState class or pass it around via function arguments.
577 class AliasInitializer {
578 public:
579   AliasInitializer(
580       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
581       llvm::BumpPtrAllocator &aliasAllocator)
582       : interfaces(interfaces), aliasAllocator(aliasAllocator),
583         aliasOS(aliasBuffer) {}
584 
585   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
586                   llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
587 
588   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
589   /// set to true if the originator of this attribute can resolve the alias
590   /// after parsing has completed (e.g. in the case of operation locations).
591   /// `elideType` indicates if the type of the attribute should be skipped when
592   /// looking for nested aliases. Returns the maximum alias depth of the
593   /// attribute, and the alias index of this attribute.
594   std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
595                                   bool elideType = false) {
596     return visitImpl(attr, aliases, canBeDeferred, elideType);
597   }
598 
599   /// Visit the given type to see if it has an alias. `canBeDeferred` is
600   /// set to true if the originator of this attribute can resolve the alias
601   /// after parsing has completed. Returns the maximum alias depth of the type,
602   /// and the alias index of this type.
603   std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
604     return visitImpl(type, aliases, canBeDeferred);
605   }
606 
607 private:
608   struct InProgressAliasInfo {
609     InProgressAliasInfo()
610         : aliasDepth(0), isType(false), canBeDeferred(false) {}
611     InProgressAliasInfo(StringRef alias)
612         : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {}
613 
614     bool operator<(const InProgressAliasInfo &rhs) const {
615       // Order first by depth, then by attr/type kind, and then by name.
616       if (aliasDepth != rhs.aliasDepth)
617         return aliasDepth < rhs.aliasDepth;
618       if (isType != rhs.isType)
619         return isType;
620       return alias < rhs.alias;
621     }
622 
623     /// The alias for the attribute or type, or std::nullopt if the value has no
624     /// alias.
625     std::optional<StringRef> alias;
626     /// The alias depth of this attribute or type, i.e. an indication of the
627     /// relative ordering of when to print this alias.
628     unsigned aliasDepth : 30;
629     /// If this alias represents a type or an attribute.
630     bool isType : 1;
631     /// If this alias can be deferred or not.
632     bool canBeDeferred : 1;
633     /// Indices for child aliases.
634     SmallVector<size_t> childIndices;
635   };
636 
637   /// Visit the given attribute or type to see if it has an alias.
638   /// `canBeDeferred` is set to true if the originator of this value can resolve
639   /// the alias after parsing has completed (e.g. in the case of operation
640   /// locations). Returns the maximum alias depth of the value, and its alias
641   /// index.
642   template <typename T, typename... PrintArgs>
643   std::pair<size_t, size_t>
644   visitImpl(T value,
645             llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
646             bool canBeDeferred, PrintArgs &&...printArgs);
647 
648   /// Mark the given alias as non-deferrable.
649   void markAliasNonDeferrable(size_t aliasIndex);
650 
651   /// Try to generate an alias for the provided symbol. If an alias is
652   /// generated, the provided alias mapping and reverse mapping are updated.
653   template <typename T>
654   void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
655 
656   /// Given a collection of aliases and symbols, initialize a mapping from a
657   /// symbol to a given alias.
658   static void initializeAliases(
659       llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
660       llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
661 
662   /// The set of asm interfaces within the context.
663   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
664 
665   /// An allocator used for alias names.
666   llvm::BumpPtrAllocator &aliasAllocator;
667 
668   /// The set of built aliases.
669   llvm::MapVector<const void *, InProgressAliasInfo> aliases;
670 
671   /// Storage and stream used when generating an alias.
672   SmallString<32> aliasBuffer;
673   llvm::raw_svector_ostream aliasOS;
674 };
675 
676 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
677 /// and merely collects the attributes and types that *would* be printed in a
678 /// normal print invocation so that we can generate proper aliases. This allows
679 /// for us to generate aliases only for the attributes and types that would be
680 /// in the output, and trims down unnecessary output.
681 class DummyAliasOperationPrinter : private OpAsmPrinter {
682 public:
683   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
684                                       AliasInitializer &initializer)
685       : printerFlags(printerFlags), initializer(initializer) {}
686 
687   /// Prints the entire operation with the custom assembly form, if available,
688   /// or the generic assembly form, otherwise.
689   void printCustomOrGenericOp(Operation *op) override {
690     // Visit the operation location.
691     if (printerFlags.shouldPrintDebugInfo())
692       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
693 
694     // If requested, always print the generic form.
695     if (!printerFlags.shouldPrintGenericOpForm()) {
696       op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
697       return;
698     }
699 
700     // Otherwise print with the generic assembly form.
701     printGenericOp(op);
702   }
703 
704 private:
705   /// Print the given operation in the generic form.
706   void printGenericOp(Operation *op, bool printOpName = true) override {
707     // Consider nested operations for aliases.
708     if (!printerFlags.shouldSkipRegions()) {
709       for (Region &region : op->getRegions())
710         printRegion(region, /*printEntryBlockArgs=*/true,
711                     /*printBlockTerminators=*/true);
712     }
713 
714     // Visit all the types used in the operation.
715     for (Type type : op->getOperandTypes())
716       printType(type);
717     for (Type type : op->getResultTypes())
718       printType(type);
719 
720     // Consider the attributes of the operation for aliases.
721     for (const NamedAttribute &attr : op->getAttrs())
722       printAttribute(attr.getValue());
723   }
724 
725   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
726   /// block are not printed. If 'printBlockTerminator' is false, the terminator
727   /// operation of the block is not printed.
728   void print(Block *block, bool printBlockArgs = true,
729              bool printBlockTerminator = true) {
730     // Consider the types of the block arguments for aliases if 'printBlockArgs'
731     // is set to true.
732     if (printBlockArgs) {
733       for (BlockArgument arg : block->getArguments()) {
734         printType(arg.getType());
735 
736         // Visit the argument location.
737         if (printerFlags.shouldPrintDebugInfo())
738           // TODO: Allow deferring argument locations.
739           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
740       }
741     }
742 
743     // Consider the operations within this block, ignoring the terminator if
744     // requested.
745     bool hasTerminator =
746         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
747     auto range = llvm::make_range(
748         block->begin(),
749         std::prev(block->end(),
750                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
751     for (Operation &op : range)
752       printCustomOrGenericOp(&op);
753   }
754 
755   /// Print the given region.
756   void printRegion(Region &region, bool printEntryBlockArgs,
757                    bool printBlockTerminators,
758                    bool printEmptyBlock = false) override {
759     if (region.empty())
760       return;
761     if (printerFlags.shouldSkipRegions()) {
762       os << "{...}";
763       return;
764     }
765 
766     auto *entryBlock = &region.front();
767     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
768     for (Block &b : llvm::drop_begin(region, 1))
769       print(&b);
770   }
771 
772   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
773                            bool omitType) override {
774     printType(arg.getType());
775     // Visit the argument location.
776     if (printerFlags.shouldPrintDebugInfo())
777       // TODO: Allow deferring argument locations.
778       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
779   }
780 
781   /// Consider the given type to be printed for an alias.
782   void printType(Type type) override { initializer.visit(type); }
783 
784   /// Consider the given attribute to be printed for an alias.
785   void printAttribute(Attribute attr) override { initializer.visit(attr); }
786   void printAttributeWithoutType(Attribute attr) override {
787     printAttribute(attr);
788   }
789   LogicalResult printAlias(Attribute attr) override {
790     initializer.visit(attr);
791     return success();
792   }
793   LogicalResult printAlias(Type type) override {
794     initializer.visit(type);
795     return success();
796   }
797 
798   /// Consider the given location to be printed for an alias.
799   void printOptionalLocationSpecifier(Location loc) override {
800     printAttribute(loc);
801   }
802 
803   /// Print the given set of attributes with names not included within
804   /// 'elidedAttrs'.
805   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
806                              ArrayRef<StringRef> elidedAttrs = {}) override {
807     if (attrs.empty())
808       return;
809     if (elidedAttrs.empty()) {
810       for (const NamedAttribute &attr : attrs)
811         printAttribute(attr.getValue());
812       return;
813     }
814     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
815                                                   elidedAttrs.end());
816     for (const NamedAttribute &attr : attrs)
817       if (!elidedAttrsSet.contains(attr.getName().strref()))
818         printAttribute(attr.getValue());
819   }
820   void printOptionalAttrDictWithKeyword(
821       ArrayRef<NamedAttribute> attrs,
822       ArrayRef<StringRef> elidedAttrs = {}) override {
823     printOptionalAttrDict(attrs, elidedAttrs);
824   }
825 
826   /// Return a null stream as the output stream, this will ignore any data fed
827   /// to it.
828   raw_ostream &getStream() const override { return os; }
829 
830   /// The following are hooks of `OpAsmPrinter` that are not necessary for
831   /// determining potential aliases.
832   void printFloat(const APFloat &) override {}
833   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
834   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
835   void printNewline() override {}
836   void increaseIndent() override {}
837   void decreaseIndent() override {}
838   void printOperand(Value) override {}
839   void printOperand(Value, raw_ostream &os) override {
840     // Users expect the output string to have at least the prefixed % to signal
841     // a value name. To maintain this invariant, emit a name even if it is
842     // guaranteed to go unused.
843     os << "%";
844   }
845   void printKeywordOrString(StringRef) override {}
846   void printString(StringRef) override {}
847   void printResourceHandle(const AsmDialectResourceHandle &) override {}
848   void printSymbolName(StringRef) override {}
849   void printSuccessor(Block *) override {}
850   void printSuccessorAndUseList(Block *, ValueRange) override {}
851   void shadowRegionArgs(Region &, ValueRange) override {}
852 
  /// The printer flags to use when determining potential aliases. Held by
  /// reference, so the flags object must outlive this printer.
  const OpPrintingFlags &printerFlags;

  /// The initializer to use when identifying aliases. Types, attributes and
  /// locations visited by this printer are forwarded to it. Held by reference.
  AliasInitializer &initializer;

  /// A dummy output stream that discards everything written to it. `mutable`
  /// because the const `getStream()` hands it out by non-const reference.
  mutable llvm::raw_null_ostream os;
861 };
862 
863 class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
864 public:
865   explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
866                                        bool canBeDeferred,
867                                        SmallVectorImpl<size_t> &childIndices)
868       : initializer(initializer), canBeDeferred(canBeDeferred),
869         childIndices(childIndices) {}
870 
871   /// Print the given attribute/type, visiting any nested aliases that would be
872   /// generated as part of printing. Returns the maximum alias depth found while
873   /// printing the given value.
874   template <typename T, typename... PrintArgs>
875   size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
876     printAndVisitNestedAliasesImpl(value, printArgs...);
877     return maxAliasDepth;
878   }
879 
880 private:
881   /// Print the given attribute/type, visiting any nested aliases that would be
882   /// generated as part of printing.
883   void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
884     if (!isa<BuiltinDialect>(attr.getDialect())) {
885       attr.getDialect().printAttribute(attr, *this);
886 
887       // Process the builtin attributes.
888     } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
889                          IntegerSetAttr, UnitAttr>(attr)) {
890       return;
891     } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
892       printAttribute(distinctAttr.getReferencedAttr());
893     } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
894       for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
895         printAttribute(nestedAttr.getName());
896         printAttribute(nestedAttr.getValue());
897       }
898     } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
899       for (Attribute nestedAttr : arrayAttr.getValue())
900         printAttribute(nestedAttr);
901     } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
902       printType(typeAttr.getValue());
903     } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
904       printAttribute(locAttr.getFallbackLocation());
905     } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
906       if (!isa<UnknownLoc>(locAttr.getChildLoc()))
907         printAttribute(locAttr.getChildLoc());
908     } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
909       printAttribute(locAttr.getCallee());
910       printAttribute(locAttr.getCaller());
911     } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
912       if (Attribute metadata = locAttr.getMetadata())
913         printAttribute(metadata);
914       for (Location nestedLoc : locAttr.getLocations())
915         printAttribute(nestedLoc);
916     }
917 
918     // Don't print the type if we must elide it, or if it is a None type.
919     if (!elideType) {
920       if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
921         Type attrType = typedAttr.getType();
922         if (!llvm::isa<NoneType>(attrType))
923           printType(attrType);
924       }
925     }
926   }
927   void printAndVisitNestedAliasesImpl(Type type) {
928     if (!isa<BuiltinDialect>(type.getDialect()))
929       return type.getDialect().printType(type, *this);
930 
931     // Only visit the layout of memref if it isn't the identity.
932     if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
933       printType(memrefTy.getElementType());
934       MemRefLayoutAttrInterface layout = memrefTy.getLayout();
935       if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
936         printAttribute(memrefTy.getLayout());
937       if (memrefTy.getMemorySpace())
938         printAttribute(memrefTy.getMemorySpace());
939       return;
940     }
941 
942     // For most builtin types, we can simply walk the sub elements.
943     auto visitFn = [&](auto element) {
944       if (element)
945         (void)printAlias(element);
946     };
947     type.walkImmediateSubElements(visitFn, visitFn);
948   }
949 
950   /// Consider the given type to be printed for an alias.
951   void printType(Type type) override {
952     recordAliasResult(initializer.visit(type, canBeDeferred));
953   }
954 
955   /// Consider the given attribute to be printed for an alias.
956   void printAttribute(Attribute attr) override {
957     recordAliasResult(initializer.visit(attr, canBeDeferred));
958   }
959   void printAttributeWithoutType(Attribute attr) override {
960     recordAliasResult(
961         initializer.visit(attr, canBeDeferred, /*elideType=*/true));
962   }
963   LogicalResult printAlias(Attribute attr) override {
964     printAttribute(attr);
965     return success();
966   }
967   LogicalResult printAlias(Type type) override {
968     printType(type);
969     return success();
970   }
971 
972   /// Record the alias result of a child element.
973   void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
974     childIndices.push_back(aliasDepthAndIndex.second);
975     if (aliasDepthAndIndex.first > maxAliasDepth)
976       maxAliasDepth = aliasDepthAndIndex.first;
977   }
978 
979   /// Return a null stream as the output stream, this will ignore any data fed
980   /// to it.
981   raw_ostream &getStream() const override { return os; }
982 
983   /// The following are hooks of `DialectAsmPrinter` that are not necessary for
984   /// determining potential aliases.
985   void printFloat(const APFloat &) override {}
986   void printKeywordOrString(StringRef) override {}
987   void printString(StringRef) override {}
988   void printSymbolName(StringRef) override {}
989   void printResourceHandle(const AsmDialectResourceHandle &) override {}
990 
991   LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
992     return success(cyclicPrintingStack.insert(opaquePointer));
993   }
994 
995   void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }
996 
997   /// Stack of potentially cyclic mutable attributes or type currently being
998   /// printed.
999   SetVector<const void *> cyclicPrintingStack;
1000 
1001   /// The initializer to use when identifying aliases.
1002   AliasInitializer &initializer;
1003 
1004   /// If the aliases visited by this printer can be deferred.
1005   bool canBeDeferred;
1006 
1007   /// The indices of child aliases.
1008   SmallVectorImpl<size_t> &childIndices;
1009 
1010   /// The maximum alias depth found by the printer.
1011   size_t maxAliasDepth = 0;
1012 
1013   /// A dummy output stream.
1014   mutable llvm::raw_null_ostream os;
1015 };
1016 } // namespace
1017 
1018 /// Sanitize the given name such that it can be used as a valid identifier. If
1019 /// the string needs to be modified in any way, the provided buffer is used to
1020 /// store the new copy,
1021 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1022                                     StringRef allowedPunctChars = "$._-",
1023                                     bool allowTrailingDigit = true) {
1024   assert(!name.empty() && "Shouldn't have an empty name here");
1025 
1026   auto validChar = [&](char ch) {
1027     return llvm::isAlnum(ch) || allowedPunctChars.contains(ch);
1028   };
1029 
1030   auto copyNameToBuffer = [&] {
1031     for (char ch : name) {
1032       if (validChar(ch))
1033         buffer.push_back(ch);
1034       else if (ch == ' ')
1035         buffer.push_back('_');
1036       else
1037         buffer.append(llvm::utohexstr((unsigned char)ch));
1038     }
1039   };
1040 
1041   // Check to see if this name is valid. If it starts with a digit, then it
1042   // could conflict with the autogenerated numeric ID's, so add an underscore
1043   // prefix to avoid problems.
1044   if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) {
1045     buffer.push_back('_');
1046     copyNameToBuffer();
1047     return buffer;
1048   }
1049 
1050   // If the name ends with a trailing digit, add a '_' to avoid potential
1051   // conflicts with autogenerated ID's.
1052   if (!allowTrailingDigit && isdigit(name.back())) {
1053     copyNameToBuffer();
1054     buffer.push_back('_');
1055     return buffer;
1056   }
1057 
1058   // Check to see that the name consists of only valid identifier characters.
1059   for (char ch : name) {
1060     if (!validChar(ch)) {
1061       copyNameToBuffer();
1062       return buffer;
1063     }
1064   }
1065 
1066   // If there are no invalid characters, return the original name.
1067   return name;
1068 }
1069 
1070 /// Given a collection of aliases and symbols, initialize a mapping from a
1071 /// symbol to a given alias.
1072 void AliasInitializer::initializeAliases(
1073     llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
1074     llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
1075   SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
1076       unprocessedAliases = visitedSymbols.takeVector();
1077   llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
1078     return lhs.second < rhs.second;
1079   });
1080 
1081   llvm::StringMap<unsigned> nameCounts;
1082   for (auto &[symbol, aliasInfo] : unprocessedAliases) {
1083     if (!aliasInfo.alias)
1084       continue;
1085     StringRef alias = *aliasInfo.alias;
1086     unsigned nameIndex = nameCounts[alias]++;
1087     symbolToAlias.insert(
1088         {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
1089                              aliasInfo.canBeDeferred)});
1090   }
1091 }
1092 
1093 void AliasInitializer::initialize(
1094     Operation *op, const OpPrintingFlags &printerFlags,
1095     llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
1096   // Use a dummy printer when walking the IR so that we can collect the
1097   // attributes/types that will actually be used during printing when
1098   // considering aliases.
1099   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
1100   aliasPrinter.printCustomOrGenericOp(op);
1101 
1102   // Initialize the aliases.
1103   initializeAliases(aliases, attrTypeToAlias);
1104 }
1105 
/// Visit `value` (an attribute or type) for aliasing and return
/// {alias depth of `value`, index of `value` in `aliases`}.
///
/// The first time a value is seen, an alias name is generated for it. The value
/// is then "printed" through a DummyAliasDialectAsmPrinter so that the nested
/// attributes/types it references are visited too. Their indices become the
/// entry's `childIndices`. If the deepest nested depth is non-zero, the entry's
/// depth becomes one more than it. On a repeat visit, the entry is only marked
/// non-deferrable when `canBeDeferred` is false.
///
/// NOTE(review): markAliasNonDeferrable indexes the member `aliases`, while
/// this function indexes the `aliases` parameter. This assumes callers always
/// pass the member map; confirm at the call sites.
template <typename T, typename... PrintArgs>
std::pair<size_t, size_t> AliasInitializer::visitImpl(
    T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
    bool canBeDeferred, PrintArgs &&...printArgs) {
  auto [it, inserted] =
      aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
  size_t aliasIndex = std::distance(aliases.begin(), it);
  if (!inserted) {
    // Make sure that the alias isn't deferred if we don't permit it.
    if (!canBeDeferred)
      markAliasNonDeferrable(aliasIndex);
    return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
  }

  // Try to generate an alias for this value.
  generateAlias(value, it->second, canBeDeferred);
  // Assigned after generateAlias, which may replace the whole entry with a
  // freshly constructed InProgressAliasInfo.
  it->second.isType = std::is_base_of_v<Type, T>;
  it->second.canBeDeferred = canBeDeferred;

  // Print the value, capturing any nested elements that require aliases.
  // This may recursively visit (and insert) other values into the map.
  SmallVector<size_t> childAliases;
  DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
  size_t maxAliasDepth =
      printer.printAndVisitNestedAliases(value, printArgs...);

  // Make sure to recompute `it` in case the map was reallocated.
  it = std::next(aliases.begin(), aliasIndex);

  // If we had sub elements, update to account for the depth.
  it->second.childIndices = std::move(childAliases);
  if (maxAliasDepth)
    it->second.aliasDepth = maxAliasDepth + 1;

  // Propagate the alias depth of the value.
  return {(size_t)it->second.aliasDepth, aliasIndex};
}
1142 
1143 void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
1144   auto *it = std::next(aliases.begin(), aliasIndex);
1145 
1146   // If already marked non-deferrable stop the recursion.
1147   // All children should already be marked non-deferrable as well.
1148   if (!it->second.canBeDeferred)
1149     return;
1150 
1151   it->second.canBeDeferred = false;
1152 
1153   // Propagate the non-deferrable flag to any child aliases.
1154   for (size_t childIndex : it->second.childIndices)
1155     markAliasNonDeferrable(childIndex);
1156 }
1157 
1158 template <typename T>
1159 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
1160                                      bool canBeDeferred) {
1161   SmallString<32> nameBuffer;
1162   for (const auto &interface : interfaces) {
1163     OpAsmDialectInterface::AliasResult result =
1164         interface.getAlias(symbol, aliasOS);
1165     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1166       continue;
1167     nameBuffer = std::move(aliasBuffer);
1168     assert(!nameBuffer.empty() && "expected valid alias name");
1169     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1170       break;
1171   }
1172 
1173   if (nameBuffer.empty())
1174     return;
1175 
1176   SmallString<16> tempBuffer;
1177   StringRef name =
1178       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1179                          /*allowTrailingDigit=*/false);
1180   name = name.copy(aliasAllocator);
1181   alias = InProgressAliasInfo(name);
1182 }
1183 
1184 //===----------------------------------------------------------------------===//
1185 // AliasState
1186 //===----------------------------------------------------------------------===//
1187 
1188 namespace {
1189 /// This class manages the state for type and attribute aliases.
1190 class AliasState {
1191 public:
1192   // Initialize the internal aliases.
1193   void
1194   initialize(Operation *op, const OpPrintingFlags &printerFlags,
1195              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
1196 
1197   /// Get an alias for the given attribute if it has one and print it in `os`.
1198   /// Returns success if an alias was printed, failure otherwise.
1199   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
1200 
1201   /// Get an alias for the given type if it has one and print it in `os`.
1202   /// Returns success if an alias was printed, failure otherwise.
1203   LogicalResult getAlias(Type ty, raw_ostream &os) const;
1204 
1205   /// Print all of the referenced aliases that can not be resolved in a deferred
1206   /// manner.
1207   void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1208     printAliases(p, newLine, /*isDeferred=*/false);
1209   }
1210 
1211   /// Print all of the referenced aliases that support deferred resolution.
1212   void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1213     printAliases(p, newLine, /*isDeferred=*/true);
1214   }
1215 
1216 private:
1217   /// Print all of the referenced aliases that support the provided resolution
1218   /// behavior.
1219   void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1220                     bool isDeferred);
1221 
1222   /// Mapping between attribute/type and alias.
1223   llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
1224 
1225   /// An allocator used for alias names.
1226   llvm::BumpPtrAllocator aliasAllocator;
1227 };
1228 } // namespace
1229 
1230 void AliasState::initialize(
1231     Operation *op, const OpPrintingFlags &printerFlags,
1232     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
1233   AliasInitializer initializer(interfaces, aliasAllocator);
1234   initializer.initialize(op, printerFlags, attrTypeToAlias);
1235 }
1236 
1237 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
1238   const auto *it = attrTypeToAlias.find(attr.getAsOpaquePointer());
1239   if (it == attrTypeToAlias.end())
1240     return failure();
1241   it->second.print(os);
1242   return success();
1243 }
1244 
1245 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
1246   const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
1247   if (it == attrTypeToAlias.end())
1248     return failure();
1249   if (!it->second.isPrinted)
1250     return failure();
1251 
1252   it->second.print(os);
1253   return success();
1254 }
1255 
1256 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1257                               bool isDeferred) {
1258   auto filterFn = [=](const auto &aliasIt) {
1259     return aliasIt.second.canBeDeferred() == isDeferred;
1260   };
1261   for (auto &[opaqueSymbol, alias] :
1262        llvm::make_filter_range(attrTypeToAlias, filterFn)) {
1263     alias.print(p.getStream());
1264     p.getStream() << " = ";
1265 
1266     if (alias.isTypeAlias()) {
1267       Type type = Type::getFromOpaquePointer(opaqueSymbol);
1268       p.printTypeImpl(type);
1269       alias.isPrinted = true;
1270     } else {
1271       // TODO: Support nested aliases in mutable attributes.
1272       Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
1273       if (attr.hasTrait<AttributeTrait::IsMutable>())
1274         p.getStream() << attr;
1275       else
1276         p.printAttributeImpl(attr);
1277     }
1278 
1279     p.getStream() << newLine;
1280   }
1281 }
1282 
1283 //===----------------------------------------------------------------------===//
1284 // SSANameState
1285 //===----------------------------------------------------------------------===//
1286 
1287 namespace {
1288 /// Info about block printing: a number which is its position in the visitation
1289 /// order, and a name that is used to print reference to it, e.g. ^bb42.
1290 struct BlockInfo {
1291   int ordering;
1292   StringRef name;
1293 };
1294 
1295 /// This class manages the state of SSA value names.
1296 class SSANameState {
1297 public:
1298   /// A sentinel value used for values with names set.
1299   enum : unsigned { NameSentinel = ~0U };
1300 
1301   SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
1302   SSANameState() = default;
1303 
1304   /// Print the SSA identifier for the given value to 'stream'. If
1305   /// 'printResultNo' is true, it also presents the result number ('#' number)
1306   /// of this value.
1307   void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
1308 
1309   /// Print the operation identifier.
1310   void printOperationID(Operation *op, raw_ostream &stream) const;
1311 
1312   /// Return the result indices for each of the result groups registered by this
1313   /// operation, or empty if none exist.
1314   ArrayRef<int> getOpResultGroups(Operation *op);
1315 
1316   /// Get the info for the given block.
1317   BlockInfo getBlockInfo(Block *block);
1318 
1319   /// Renumber the arguments for the specified region to the same names as the
1320   /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
1321   /// details.
1322   void shadowRegionArgs(Region &region, ValueRange namesToUse);
1323 
1324 private:
1325   /// Number the SSA values within the given IR unit.
1326   void numberValuesInRegion(Region &region);
1327   void numberValuesInBlock(Block &block);
1328   void numberValuesInOp(Operation &op);
1329 
1330   /// Given a result of an operation 'result', find the result group head
1331   /// 'lookupValue' and the result of 'result' within that group in
1332   /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
1333   /// has more than 1 result.
1334   void getResultIDAndNumber(OpResult result, Value &lookupValue,
1335                             std::optional<int> &lookupResultNo) const;
1336 
1337   /// Set a special value name for the given value.
1338   void setValueName(Value value, StringRef name);
1339 
1340   /// Uniques the given value name within the printer. If the given name
1341   /// conflicts, it is automatically renamed.
1342   StringRef uniqueValueName(StringRef name);
1343 
1344   /// This is the value ID for each SSA value. If this returns NameSentinel,
1345   /// then the valueID has an entry in valueNames.
1346   DenseMap<Value, unsigned> valueIDs;
1347   DenseMap<Value, StringRef> valueNames;
1348 
1349   /// When printing users of values, an operation without a result might
1350   /// be the user. This map holds ids for such operations.
1351   DenseMap<Operation *, unsigned> operationIDs;
1352 
1353   /// This is a map of operations that contain multiple named result groups,
1354   /// i.e. there may be multiple names for the results of the operation. The
1355   /// value of this map are the result numbers that start a result group.
1356   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
1357 
1358   /// This maps blocks to there visitation number in the current region as well
1359   /// as the string representing their name.
1360   DenseMap<Block *, BlockInfo> blockNames;
1361 
1362   /// This keeps track of all of the non-numeric names that are in flight,
1363   /// allowing us to check for duplicates.
1364   /// Note: the value of the map is unused.
1365   llvm::ScopedHashTable<StringRef, char> usedNames;
1366   llvm::BumpPtrAllocator usedNameAllocator;
1367 
1368   /// This is the next value ID to assign in numbering.
1369   unsigned nextValueID = 0;
1370   /// This is the next ID to assign to a region entry block argument.
1371   unsigned nextArgumentID = 0;
1372   /// This is the next ID to assign when a name conflict is detected.
1373   unsigned nextConflictID = 0;
1374 
1375   /// These are the printing flags.  They control, eg., whether to print in
1376   /// generic form.
1377   OpPrintingFlags printerFlags;
1378 };
1379 } // namespace
1380 
/// Number every SSA value, block and result-less operation nested under `op`.
///
/// Regions are processed depth-first using the explicit `nameContext` stack.
/// Each entry records the ID counters and the `usedNames` scope that were
/// current when the region was discovered. As a result, all regions
/// discovered while processing the same parent start from the same counters,
/// unless unique SSA IDs were requested, in which case the counters keep
/// increasing.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are scratch state for this traversal; they are restored to
  // the values they had on entry when the constructor returns.
  llvm::SaveAndRestore valueIDSaver(nextValueID);
  llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'ed into it and
  // destroyed explicitly below. The unwinding loops rely on destroying a scope
  // making its parent the current scope of `usedNames` again.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Note: the counters saved for the top-level regions are captured before
  // `op` itself is numbered below.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  // Pop from the back of the stack, i.e. process regions depth-first.
  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;

    if (printerFlags.shouldPrintUniqueSSAIDs())
      // To print unique SSA IDs, ignore saved ID counts from parent regions
      std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) =
          nameContext.pop_back_val();
    else
      // Otherwise resume numbering from the counters saved for this region.
      std::tie(region, nextValueID, nextArgumentID, nextConflictID,
               parentScope) = nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes(needless)
    // until the parent scope.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions of the ops in this region, capturing the counters as
    // they are after numbering this region. (The inner `op` and `region` shadow
    // the parameter and the outer loop variable.)
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
1445 
1446 void SSANameState::printValueID(Value value, bool printResultNo,
1447                                 raw_ostream &stream) const {
1448   if (!value) {
1449     stream << "<<NULL VALUE>>";
1450     return;
1451   }
1452 
1453   std::optional<int> resultNo;
1454   auto lookupValue = value;
1455 
1456   // If this is an operation result, collect the head lookup value of the result
1457   // group and the result number of 'result' within that group.
1458   if (OpResult result = dyn_cast<OpResult>(value))
1459     getResultIDAndNumber(result, lookupValue, resultNo);
1460 
1461   auto it = valueIDs.find(lookupValue);
1462   if (it == valueIDs.end()) {
1463     stream << "<<UNKNOWN SSA VALUE>>";
1464     return;
1465   }
1466 
1467   stream << '%';
1468   if (it->second != NameSentinel) {
1469     stream << it->second;
1470   } else {
1471     auto nameIt = valueNames.find(lookupValue);
1472     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1473     stream << nameIt->second;
1474   }
1475 
1476   if (resultNo && printResultNo)
1477     stream << '#' << *resultNo;
1478 }
1479 
1480 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1481   auto it = operationIDs.find(op);
1482   if (it == operationIDs.end()) {
1483     stream << "<<UNKNOWN OPERATION>>";
1484   } else {
1485     stream << '%' << it->second;
1486   }
1487 }
1488 
1489 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1490   auto it = opResultGroups.find(op);
1491   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1492 }
1493 
1494 BlockInfo SSANameState::getBlockInfo(Block *block) {
1495   auto it = blockNames.find(block);
1496   BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1497   return it != blockNames.end() ? it->second : invalidBlock;
1498 }
1499 
1500 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1501   assert(!region.empty() && "cannot shadow arguments of an empty region");
1502   assert(region.getNumArguments() == namesToUse.size() &&
1503          "incorrect number of names passed in");
1504   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1505          "only KnownIsolatedFromAbove ops can shadow names");
1506 
1507   SmallVector<char, 16> nameStr;
1508   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1509     auto nameToUse = namesToUse[i];
1510     if (nameToUse == nullptr)
1511       continue;
1512     auto nameToReplace = region.getArgument(i);
1513 
1514     nameStr.clear();
1515     llvm::raw_svector_ostream nameStream(nameStr);
1516     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1517 
1518     // Entry block arguments should already have a pretty "arg" name.
1519     assert(valueIDs[nameToReplace] == NameSentinel);
1520 
1521     // Use the name without the leading %.
1522     auto name = StringRef(nameStream.str()).drop_front();
1523 
1524     // Overwrite the name.
1525     valueNames[nameToReplace] = name.copy(usedNameAllocator);
1526   }
1527 }
1528 
1529 namespace {
1530 /// Try to get value name from value's location, fallback to `name`.
1531 StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
1532   if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>())
1533     return maybeNameLoc.getName();
1534   return name;
1535 }
1536 } // namespace
1537 
1538 void SSANameState::numberValuesInRegion(Region &region) {
1539   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1540     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1541     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
1542            "arg not defined in current region");
1543     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1544       name = maybeGetValueNameFromLoc(arg, name);
1545     setValueName(arg, name);
1546   };
1547 
1548   if (!printerFlags.shouldPrintGenericOpForm()) {
1549     if (Operation *op = region.getParentOp()) {
1550       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1551         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1552     }
1553   }
1554 
1555   // Number the values within this region in a breadth-first order.
1556   unsigned nextBlockID = 0;
1557   for (auto &block : region) {
1558     // Each block gets a unique ID, and all of the operations within it get
1559     // numbered as well.
1560     auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1561     if (blockInfoIt.second) {
1562       // This block hasn't been named through `getAsmBlockArgumentNames`, use
1563       // default `^bbNNN` format.
1564       std::string name;
1565       llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1566       blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1567     }
1568     blockInfoIt.first->second.ordering = nextBlockID++;
1569 
1570     numberValuesInBlock(block);
1571   }
1572 }
1573 
1574 void SSANameState::numberValuesInBlock(Block &block) {
1575   // Number the block arguments. We give entry block arguments a special name
1576   // 'arg'.
1577   bool isEntryBlock = block.isEntryBlock();
1578   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1579   llvm::raw_svector_ostream specialName(specialNameBuffer);
1580   for (auto arg : block.getArguments()) {
1581     if (valueIDs.count(arg))
1582       continue;
1583     if (isEntryBlock) {
1584       specialNameBuffer.resize(strlen("arg"));
1585       specialName << nextArgumentID++;
1586     }
1587     StringRef specialNameStr = specialName.str();
1588     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1589       specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr);
1590     setValueName(arg, specialNameStr);
1591   }
1592 
1593   // Number the operations in this block.
1594   for (auto &op : block)
1595     numberValuesInOp(op);
1596 }
1597 
1598 void SSANameState::numberValuesInOp(Operation &op) {
1599   // Function used to set the special result names for the operation.
1600   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1601   auto setResultNameFn = [&](Value result, StringRef name) {
1602     assert(!valueIDs.count(result) && "result numbered multiple times");
1603     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1604     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1605       name = maybeGetValueNameFromLoc(result, name);
1606     setValueName(result, name);
1607 
1608     // Record the result number for groups not anchored at 0.
1609     if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
1610       resultGroups.push_back(resultNo);
1611   };
1612   // Operations can customize the printing of block names in OpAsmOpInterface.
1613   auto setBlockNameFn = [&](Block *block, StringRef name) {
1614     assert(block->getParentOp() == &op &&
1615            "getAsmBlockArgumentNames callback invoked on a block not directly "
1616            "nested under the current operation");
1617     assert(!blockNames.count(block) && "block numbered multiple times");
1618     SmallString<16> tmpBuffer{"^"};
1619     name = sanitizeIdentifier(name, tmpBuffer);
1620     if (name.data() != tmpBuffer.data()) {
1621       tmpBuffer.append(name);
1622       name = tmpBuffer.str();
1623     }
1624     name = name.copy(usedNameAllocator);
1625     blockNames[block] = {-1, name};
1626   };
1627 
1628   if (!printerFlags.shouldPrintGenericOpForm()) {
1629     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1630       asmInterface.getAsmBlockNames(setBlockNameFn);
1631       asmInterface.getAsmResultNames(setResultNameFn);
1632     }
1633   }
1634 
1635   unsigned numResults = op.getNumResults();
1636   if (numResults == 0) {
1637     // If value users should be printed, operations with no result need an id.
1638     if (printerFlags.shouldPrintValueUsers()) {
1639       if (operationIDs.try_emplace(&op, nextValueID).second)
1640         ++nextValueID;
1641     }
1642     return;
1643   }
1644   Value resultBegin = op.getResult(0);
1645 
1646   if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
1647     if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) {
1648       setValueName(resultBegin, nameLoc.getName());
1649     }
1650   }
1651 
1652   // If the first result wasn't numbered, give it a default number.
1653   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1654     ++nextValueID;
1655 
1656   // If this operation has multiple result groups, mark it.
1657   if (resultGroups.size() != 1) {
1658     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1659     opResultGroups.try_emplace(&op, std::move(resultGroups));
1660   }
1661 }
1662 
/// Resolve which value and (optional) result offset should be used to print
/// `result`.
///
/// For a single-result owner both out-parameters are left untouched. If the
/// owner has no recorded result groups, `lookupValue` becomes result 0 and
/// `lookupResultNo` the absolute result number. Otherwise `lookupValue`
/// becomes the first result of the group containing `result`, and
/// `lookupResultNo` is set to the offset within that group only when the group
/// has more than one result.
// NOTE(review): callers presumably pre-initialize `lookupValue` to `result`
// and `lookupResultNo` to std::nullopt; confirm at the call sites.
void SSANameState::getResultIDAndNumber(
    OpResult result, Value &lookupValue,
    std::optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // If this operation has multiple result groups, we will need to find the
  // one corresponding to this result.
  auto resultGroupIt = opResultGroups.find(owner);
  if (resultGroupIt == opResultGroups.end()) {
    // If not, just use the first result.
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // Find the correct index using a binary search, as the groups are ordered.
  // Groups recorded by numberValuesInOp always start with 0, so `it` is never
  // `begin()` and `std::prev(it)` below is valid.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  const auto *it = llvm::upper_bound(resultGroups, resultNo);
  int groupResultNo = 0, groupSize = 0;

  // If no group starts after `resultNo`, the result is in the last group,
  // which runs to the end of the owner's results.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    // Otherwise, the group starting just before `it` contains the result.
    groupResultNo = *std::prev(it);
    groupSize = *it - groupResultNo;
  }

  // We only record the result number for a group of size greater than 1.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(groupResultNo);
}
1701 
1702 void SSANameState::setValueName(Value value, StringRef name) {
1703   // If the name is empty, the value uses the default numbering.
1704   if (name.empty()) {
1705     valueIDs[value] = nextValueID++;
1706     return;
1707   }
1708 
1709   valueIDs[value] = NameSentinel;
1710   valueNames[value] = uniqueValueName(name);
1711 }
1712 
1713 StringRef SSANameState::uniqueValueName(StringRef name) {
1714   SmallString<16> tmpBuffer;
1715   name = sanitizeIdentifier(name, tmpBuffer);
1716 
1717   // Check to see if this name is already unique.
1718   if (!usedNames.count(name)) {
1719     name = name.copy(usedNameAllocator);
1720   } else {
1721     // Otherwise, we had a conflict - probe until we find a unique name. This
1722     // is guaranteed to terminate (and usually in a single iteration) because it
1723     // generates new names by incrementing nextConflictID.
1724     SmallString<64> probeName(name);
1725     probeName.push_back('_');
1726     while (true) {
1727       probeName += llvm::utostr(nextConflictID++);
1728       if (!usedNames.count(probeName)) {
1729         name = probeName.str().copy(usedNameAllocator);
1730         break;
1731       }
1732       probeName.resize(name.size() + 1);
1733     }
1734   }
1735 
1736   usedNames.insert(name, char());
1737   return name;
1738 }
1739 
1740 //===----------------------------------------------------------------------===//
1741 // DistinctState
1742 //===----------------------------------------------------------------------===//
1743 
1744 namespace {
1745 /// This class manages the state for distinct attributes.
1746 class DistinctState {
1747 public:
1748   /// Returns a unique identifier for the given distinct attribute.
1749   uint64_t getId(DistinctAttr distinctAttr);
1750 
1751 private:
1752   uint64_t distinctCounter = 0;
1753   DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
1754 };
1755 } // namespace
1756 
1757 uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1758   auto [it, inserted] =
1759       distinctAttrMap.try_emplace(distinctAttr, distinctCounter);
1760   if (inserted)
1761     distinctCounter++;
1762   return it->getSecond();
1763 }
1764 
1765 //===----------------------------------------------------------------------===//
1766 // Resources
1767 //===----------------------------------------------------------------------===//
1768 
// Out-of-line (defaulted) destructors for the resource helper classes.
// NOTE(review): presumably defined here to anchor each class's vtable in this
// translation unit; confirm against the declarations in AsmState.h.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
1773 
1774 StringRef mlir::toString(AsmResourceEntryKind kind) {
1775   switch (kind) {
1776   case AsmResourceEntryKind::Blob:
1777     return "blob";
1778   case AsmResourceEntryKind::Bool:
1779     return "bool";
1780   case AsmResourceEntryKind::String:
1781     return "string";
1782   }
1783   llvm_unreachable("unknown AsmResourceEntryKind");
1784 }
1785 
1786 AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
1787   std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
1788   if (!collection)
1789     collection = std::make_unique<ResourceCollection>(key);
1790   return *collection;
1791 }
1792 
1793 std::vector<std::unique_ptr<AsmResourcePrinter>>
1794 FallbackAsmResourceMap::getPrinters() {
1795   std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
1796   for (auto &it : keyToResources) {
1797     ResourceCollection *collection = it.second.get();
1798     auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
1799       return collection->buildResources(op, builder);
1800     };
1801     printers.emplace_back(
1802         AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
1803   }
1804   return printers;
1805 }
1806 
1807 LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
1808     AsmParsedResourceEntry &entry) {
1809   switch (entry.getKind()) {
1810   case AsmResourceEntryKind::Blob: {
1811     FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
1812     if (failed(blob))
1813       return failure();
1814     resources.emplace_back(entry.getKey(), std::move(*blob));
1815     return success();
1816   }
1817   case AsmResourceEntryKind::Bool: {
1818     FailureOr<bool> value = entry.parseAsBool();
1819     if (failed(value))
1820       return failure();
1821     resources.emplace_back(entry.getKey(), *value);
1822     break;
1823   }
1824   case AsmResourceEntryKind::String: {
1825     FailureOr<std::string> str = entry.parseAsString();
1826     if (failed(str))
1827       return failure();
1828     resources.emplace_back(entry.getKey(), std::move(*str));
1829     break;
1830   }
1831   }
1832   return success();
1833 }
1834 
1835 void FallbackAsmResourceMap::ResourceCollection::buildResources(
1836     Operation *op, AsmResourceBuilder &builder) const {
1837   for (const auto &entry : resources) {
1838     if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value))
1839       builder.buildBlob(entry.key, *value);
1840     else if (const auto *value = std::get_if<bool>(&entry.value))
1841       builder.buildBool(entry.key, *value);
1842     else if (const auto *value = std::get_if<std::string>(&entry.value))
1843       builder.buildString(entry.key, *value);
1844     else
1845       llvm_unreachable("unknown AsmResourceEntryKind");
1846   }
1847 }
1848 
1849 //===----------------------------------------------------------------------===//
1850 // AsmState
1851 //===----------------------------------------------------------------------===//
1852 
namespace mlir {
namespace detail {
/// Shared state backing an AsmState: dialect interfaces, alias/SSA-name/
/// distinct-attribute numbering, resource bookkeeping, printer flags, an
/// optional location map, and a stack used to detect cyclic printing.
class AsmStateImpl {
public:
  /// Create state for printing `op`; SSA names are computed from `op`.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}
  /// Create state from a context only; `nameState` is default-constructed.
  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Get the state used for distinct attribute identifiers.
  DistinctState &getDistinctState() { return distinctState; }

  /// Return the dialects within the context that implement
  /// OpAsmDialectInterface.
  DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
    return interfaces;
  }

  /// Return the non-dialect resource printers.
  auto getResourcePrinters() {
    return llvm::make_pointee_range(externalResourcePrinters);
  }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. No-op when no location map was provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

  /// Return the referenced dialect resources within the printer.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
  getDialectResources() {
    return dialectResources;
  }

  /// Push `opaquePointer` onto the cyclic-printing stack. Returns failure if
  /// it is already on the stack, i.e. it is currently being printed.
  LogicalResult pushCyclicPrinting(const void *opaquePointer) {
    return success(cyclicPrintingStack.insert(opaquePointer));
  }

  /// Pop the most recently pushed entry from the cyclic-printing stack.
  void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// A collection of non-dialect resource printers.
  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

  /// A set of dialect resources that were referenced during printing.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// The state used for distinct attribute identifiers.
  DistinctState distinctState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated.
  AsmState::LocationMap *locationMap;

  /// Stack of potentially cyclic mutable attributes or type currently being
  /// printed.
  SetVector<const void *> cyclicPrintingStack;

  // Allow direct access to the impl fields.
  friend AsmState;
};

/// Print `shape` as dimension sizes joined by 'x', using '?' for dynamic
/// sizes (e.g. "4x?x8").
template <typename Range>
void printDimensionList(raw_ostream &stream, Range &&shape) {
  llvm::interleave(
      shape, stream,
      [&stream](const auto &dimSize) {
        if (ShapedType::isDynamic(dimSize))
          stream << "?";
        else
          stream << dimSize;
      },
      "x");
}

} // namespace detail
} // namespace mlir
1960 
1961 /// Verifies the operation and switches to generic op printing if verification
1962 /// fails. We need to do this because custom print functions may fail for
1963 /// invalid ops.
1964 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1965                                               OpPrintingFlags printerFlags) {
1966   if (printerFlags.shouldPrintGenericOpForm() ||
1967       printerFlags.shouldAssumeVerified())
1968     return printerFlags;
1969 
1970   // Ignore errors emitted by the verifier. We check the thread id to avoid
1971   // consuming other threads' errors.
1972   auto parentThreadId = llvm::get_threadid();
1973   ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1974     if (parentThreadId == llvm::get_threadid()) {
1975       LLVM_DEBUG({
1976         diag.print(llvm::dbgs());
1977         llvm::dbgs() << "\n";
1978       });
1979       return success();
1980     }
1981     return failure();
1982   });
1983   if (failed(verify(op))) {
1984     LLVM_DEBUG(llvm::dbgs()
1985                << DEBUG_TYPE << ": '" << op->getName()
1986                << "' failed to verify and will be printed in generic form\n");
1987     printerFlags.printGenericOpForm();
1988   }
1989 
1990   return printerFlags;
1991 }
1992 
/// Construct printing state for `op`. The flags are first passed through
/// verifyOpAndAdjustFlags, so an op that fails verification is printed in
/// generic form. If `map` is non-null it is attached as a fallback resource
/// printer.
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap, FallbackAsmResourceMap *map)
    : impl(std::make_unique<AsmStateImpl>(
          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
  if (map)
    attachFallbackResourcePrinter(*map);
}
/// Construct printing state from a context alone; no verification is done
/// and the flags are used as given.
AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap, FallbackAsmResourceMap *map)
    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
  if (map)
    attachFallbackResourcePrinter(*map);
}
// NOTE(review): presumably defaulted out-of-line so that AsmStateImpl is a
// complete type where the unique_ptr<AsmStateImpl> is destroyed.
AsmState::~AsmState() = default;

/// Return the (possibly verification-adjusted) printing flags.
const OpPrintingFlags &AsmState::getPrinterFlags() const {
  return impl->getPrinterFlags();
}

/// Take ownership of an additional, non-dialect resource printer.
void AsmState::attachResourcePrinter(
    std::unique_ptr<AsmResourcePrinter> printer) {
  impl->externalResourcePrinters.emplace_back(std::move(printer));
}

/// Return the dialect resources referenced while printing, grouped by dialect.
DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
AsmState::getDialectResources() const {
  return impl->getDialectResources();
}
2021 
2022 //===----------------------------------------------------------------------===//
2023 // AsmPrinter::Impl
2024 //===----------------------------------------------------------------------===//
2025 
/// Bind the printer to its output stream and shared state; `printerFlags` is
/// initialized from the state's flags.
AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
    : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
2028 
2029 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
2030   // Check to see if we are printing debug information.
2031   if (!printerFlags.shouldPrintDebugInfo())
2032     return;
2033 
2034   os << " ";
2035   printLocation(loc, /*allowAlias=*/allowAlias);
2036 }
2037 
2038 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
2039                                              bool isTopLevel) {
2040   // If this isn't a top-level location, check for an alias.
2041   if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os)))
2042     return;
2043 
2044   TypeSwitch<LocationAttr>(loc)
2045       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
2046         printLocationInternal(loc.getFallbackLocation(), pretty);
2047       })
2048       .Case<UnknownLoc>([&](UnknownLoc loc) {
2049         if (pretty)
2050           os << "[unknown]";
2051         else
2052           os << "unknown";
2053       })
2054       .Case<FileLineColRange>([&](FileLineColRange loc) {
2055         if (pretty)
2056           os << loc.getFilename().getValue();
2057         else
2058           printEscapedString(loc.getFilename());
2059         if (loc.getEndColumn() == loc.getStartColumn() &&
2060             loc.getStartLine() == loc.getEndLine()) {
2061           os << ':' << loc.getStartLine() << ':' << loc.getStartColumn();
2062           return;
2063         }
2064         if (loc.getStartLine() == loc.getEndLine()) {
2065           os << ':' << loc.getStartLine() << ':' << loc.getStartColumn()
2066              << " to :" << loc.getEndColumn();
2067           return;
2068         }
2069         os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to "
2070            << loc.getEndLine() << ':' << loc.getEndColumn();
2071       })
2072       .Case<NameLoc>([&](NameLoc loc) {
2073         printEscapedString(loc.getName());
2074 
2075         // Print the child if it isn't unknown.
2076         auto childLoc = loc.getChildLoc();
2077         if (!llvm::isa<UnknownLoc>(childLoc)) {
2078           os << '(';
2079           printLocationInternal(childLoc, pretty);
2080           os << ')';
2081         }
2082       })
2083       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
2084         Location caller = loc.getCaller();
2085         Location callee = loc.getCallee();
2086         if (!pretty)
2087           os << "callsite(";
2088         printLocationInternal(callee, pretty);
2089         if (pretty) {
2090           if (llvm::isa<NameLoc>(callee)) {
2091             if (llvm::isa<FileLineColLoc>(caller)) {
2092               os << " at ";
2093             } else {
2094               os << newLine << " at ";
2095             }
2096           } else {
2097             os << newLine << " at ";
2098           }
2099         } else {
2100           os << " at ";
2101         }
2102         printLocationInternal(caller, pretty);
2103         if (!pretty)
2104           os << ")";
2105       })
2106       .Case<FusedLoc>([&](FusedLoc loc) {
2107         if (!pretty)
2108           os << "fused";
2109         if (Attribute metadata = loc.getMetadata()) {
2110           os << '<';
2111           printAttribute(metadata);
2112           os << '>';
2113         }
2114         os << '[';
2115         interleave(
2116             loc.getLocations(),
2117             [&](Location loc) { printLocationInternal(loc, pretty); },
2118             [&]() { os << ", "; });
2119         os << ']';
2120       })
2121       .Default([&](LocationAttr loc) {
2122         // Assumes that this is a dialect-specific attribute and prints it
2123         // directly.
2124         printAttribute(loc);
2125       });
2126 }
2127 
2128 /// Print a floating point value in a way that the parser will be able to
2129 /// round-trip losslessly.
2130 static void printFloatValue(const APFloat &apValue, raw_ostream &os,
2131                             bool *printedHex = nullptr) {
2132   // We would like to output the FP constant value in exponential notation,
2133   // but we cannot do this if doing so will lose precision.  Check here to
2134   // make sure that we only output it in exponential format if we can parse
2135   // the value back and get the same value.
2136   bool isInf = apValue.isInfinity();
2137   bool isNaN = apValue.isNaN();
2138   if (!isInf && !isNaN) {
2139     SmallString<128> strValue;
2140     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
2141                      /*TruncateZero=*/false);
2142 
2143     // Check to make sure that the stringized number is not some string like
2144     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
2145     // that the string matches the "[-+]?[0-9]" regex.
2146     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
2147             ((strValue[0] == '-' || strValue[0] == '+') &&
2148              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
2149            "[-+]?[0-9] regex does not match!");
2150 
2151     // Parse back the stringized version and check that the value is equal
2152     // (i.e., there is no precision loss).
2153     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
2154       os << strValue;
2155       return;
2156     }
2157 
2158     // If it is not, use the default format of APFloat instead of the
2159     // exponential notation.
2160     strValue.clear();
2161     apValue.toString(strValue);
2162 
2163     // Make sure that we can parse the default form as a float.
2164     if (strValue.str().contains('.')) {
2165       os << strValue;
2166       return;
2167     }
2168   }
2169 
2170   // Print special values in hexadecimal format. The sign bit should be included
2171   // in the literal.
2172   if (printedHex)
2173     *printedHex = true;
2174   SmallVector<char, 16> str;
2175   APInt apInt = apValue.bitcastToAPInt();
2176   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
2177                  /*formatAsCLiteral=*/true);
2178   os << str;
2179 }
2180 
2181 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
2182   if (printerFlags.shouldPrintDebugInfoPrettyForm())
2183     return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
2184 
2185   os << "loc(";
2186   if (!allowAlias || failed(printAlias(loc)))
2187     printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
2188   os << ')';
2189 }
2190 
2191 void AsmPrinter::Impl::printResourceHandle(
2192     const AsmDialectResourceHandle &resource) {
2193   auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
2194   os << interface->getResourceKey(resource);
2195   state.getDialectResources()[resource.getDialect()].insert(resource);
2196 }
2197 
2198 /// Returns true if the given dialect symbol data is simple enough to print in
2199 /// the pretty form. This is essentially when the symbol takes the form:
2200 ///   identifier (`<` body `>`)?
2201 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
2202   // The name must start with an identifier.
2203   if (symName.empty() || !isalpha(symName.front()))
2204     return false;
2205 
2206   // Ignore all the characters that are valid in an identifier in the symbol
2207   // name.
2208   symName = symName.drop_while(
2209       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
2210   if (symName.empty())
2211     return true;
2212 
2213   // If we got to an unexpected character, then it must be a <>. Check that the
2214   // rest of the symbol is wrapped within <>.
2215   return symName.front() == '<' && symName.back() == '>';
2216 }
2217 
2218 /// Print the given dialect symbol to the stream.
2219 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
2220                                StringRef dialectName, StringRef symString) {
2221   os << symPrefix << dialectName;
2222 
2223   // If this symbol name is simple enough, print it directly in pretty form,
2224   // otherwise, we print it as an escaped string.
2225   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
2226     os << '.' << symString;
2227     return;
2228   }
2229 
2230   os << '<' << symString << '>';
2231 }
2232 
2233 /// Returns true if the given string can be represented as a bare identifier.
2234 static bool isBareIdentifier(StringRef name) {
2235   // By making this unsigned, the value passed in to isalnum will always be
2236   // in the range 0-255. This is important when building with MSVC because
2237   // its implementation will assert. This situation can arise when dealing
2238   // with UTF-8 multibyte characters.
2239   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
2240     return false;
2241   return llvm::all_of(name.drop_front(), [](unsigned char c) {
2242     return isalnum(c) || c == '_' || c == '$' || c == '.';
2243   });
2244 }
2245 
2246 /// Print the given string as a keyword, or a quoted and escaped string if it
2247 /// has any special or non-printable characters in it.
2248 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
2249   // If it can be represented as a bare identifier, write it directly.
2250   if (isBareIdentifier(keyword)) {
2251     os << keyword;
2252     return;
2253   }
2254 
2255   // Otherwise, output the keyword wrapped in quotes with proper escaping.
2256   os << "\"";
2257   printEscapedString(keyword, os);
2258   os << '"';
2259 }
2260 
2261 /// Print the given string as a symbol reference. A symbol reference is
2262 /// represented as a string prefixed with '@'. The reference is surrounded with
2263 /// ""'s and escaped if it has any special or non-printable characters in it.
2264 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
2265   if (symbolRef.empty()) {
2266     os << "@<<INVALID EMPTY SYMBOL>>";
2267     return;
2268   }
2269   os << '@';
2270   printKeywordOrString(symbolRef, os);
2271 }
2272 
2273 // Print out a valid ElementsAttr that is succinct and can represent any
2274 // potential shape/type, for use when eliding a large ElementsAttr.
2275 //
2276 // We choose to use a dense resource ElementsAttr literal with conspicuous
2277 // content to hopefully alert readers to the fact that this has been elided.
2278 static void printElidedElementsAttr(raw_ostream &os) {
2279   os << R"(dense_resource<__elided__>)";
2280 }
2281 
2282 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
2283   return state.getAliasState().getAlias(attr, os);
2284 }
2285 
2286 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
2287   return state.getAliasState().getAlias(type, os);
2288 }
2289 
2290 void AsmPrinter::Impl::printAttribute(Attribute attr,
2291                                       AttrTypeElision typeElision) {
2292   if (!attr) {
2293     os << "<<NULL ATTRIBUTE>>";
2294     return;
2295   }
2296 
2297   // Try to print an alias for this attribute.
2298   if (succeeded(printAlias(attr)))
2299     return;
2300   return printAttributeImpl(attr, typeElision);
2301 }
2302 
2303 void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
2304                                           AttrTypeElision typeElision) {
2305   if (!isa<BuiltinDialect>(attr.getDialect())) {
2306     printDialectAttribute(attr);
2307   } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
2308     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
2309                        opaqueAttr.getAttrData());
2310   } else if (llvm::isa<UnitAttr>(attr)) {
2311     os << "unit";
2312     return;
2313   } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2314     os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
2315     if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) {
2316       printAttribute(distinctAttr.getReferencedAttr());
2317     }
2318     os << '>';
2319     return;
2320   } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
2321     os << '{';
2322     interleaveComma(dictAttr.getValue(),
2323                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2324     os << '}';
2325 
2326   } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
2327     Type intType = intAttr.getType();
2328     if (intType.isSignlessInteger(1)) {
2329       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
2330 
2331       // Boolean integer attributes always elides the type.
2332       return;
2333     }
2334 
2335     // Only print attributes as unsigned if they are explicitly unsigned or are
2336     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
2337     // values print as signed.
2338     bool isUnsigned =
2339         intType.isUnsignedInteger() || intType.isSignlessInteger(1);
2340     intAttr.getValue().print(os, !isUnsigned);
2341 
2342     // IntegerAttr elides the type if I64.
2343     if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
2344       return;
2345 
2346   } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2347     bool printedHex = false;
2348     printFloatValue(floatAttr.getValue(), os, &printedHex);
2349 
2350     // FloatAttr elides the type if F64.
2351     if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() &&
2352         !printedHex)
2353       return;
2354 
2355   } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
2356     printEscapedString(strAttr.getValue());
2357 
2358   } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
2359     os << '[';
2360     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
2361       printAttribute(attr, AttrTypeElision::May);
2362     });
2363     os << ']';
2364 
2365   } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
2366     os << "affine_map<";
2367     affineMapAttr.getValue().print(os);
2368     os << '>';
2369 
2370     // AffineMap always elides the type.
2371     return;
2372 
2373   } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
2374     os << "affine_set<";
2375     integerSetAttr.getValue().print(os);
2376     os << '>';
2377 
2378     // IntegerSet always elides the type.
2379     return;
2380 
2381   } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
2382     printType(typeAttr.getValue());
2383 
2384   } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
2385     printSymbolReference(refAttr.getRootReference().getValue(), os);
2386     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
2387       os << "::";
2388       printSymbolReference(nestedRef.getValue(), os);
2389     }
2390 
2391   } else if (auto intOrFpEltAttr =
2392                  llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
2393     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
2394       printElidedElementsAttr(os);
2395     } else {
2396       os << "dense<";
2397       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
2398       os << '>';
2399     }
2400 
2401   } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
2402     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
2403       printElidedElementsAttr(os);
2404     } else {
2405       os << "dense<";
2406       printDenseStringElementsAttr(strEltAttr);
2407       os << '>';
2408     }
2409 
2410   } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
2411     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
2412         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
2413       printElidedElementsAttr(os);
2414     } else {
2415       os << "sparse<";
2416       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
2417       if (indices.getNumElements() != 0) {
2418         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
2419         os << ", ";
2420         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
2421       }
2422       os << '>';
2423     }
2424   } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
2425     stridedLayoutAttr.print(os);
2426   } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
2427     os << "array<";
2428     printType(denseArrayAttr.getElementType());
2429     if (!denseArrayAttr.empty()) {
2430       os << ": ";
2431       printDenseArrayAttr(denseArrayAttr);
2432     }
2433     os << ">";
2434     return;
2435   } else if (auto resourceAttr =
2436                  llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
2437     os << "dense_resource<";
2438     printResourceHandle(resourceAttr.getRawHandle());
2439     os << ">";
2440   } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
2441     printLocation(locAttr);
2442   } else {
2443     llvm::report_fatal_error("Unknown builtin attribute");
2444   }
2445   // Don't print the type if we must elide it, or if it is a None type.
2446   if (typeElision != AttrTypeElision::Must) {
2447     if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
2448       Type attrType = typedAttr.getType();
2449       if (!llvm::isa<NoneType>(attrType)) {
2450         os << " : ";
2451         printType(attrType);
2452       }
2453     }
2454   }
2455 }
2456 
2457 /// Print the integer element of a DenseElementsAttr.
2458 static void printDenseIntElement(const APInt &value, raw_ostream &os,
2459                                  Type type) {
2460   if (type.isInteger(1))
2461     os << (value.getBoolValue() ? "true" : "false");
2462   else
2463     value.print(os, !type.isUnsignedInteger());
2464 }
2465 
2466 static void
2467 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
2468                            function_ref<void(unsigned)> printEltFn) {
2469   // Special case for 0-d and splat tensors.
2470   if (isSplat)
2471     return printEltFn(0);
2472 
2473   // Special case for degenerate tensors.
2474   auto numElements = type.getNumElements();
2475   if (numElements == 0)
2476     return;
2477 
2478   // We use a mixed-radix counter to iterate through the shape. When we bump a
2479   // non-least-significant digit, we emit a close bracket. When we next emit an
2480   // element we re-open all closed brackets.
2481 
2482   // The mixed-radix counter, with radices in 'shape'.
2483   int64_t rank = type.getRank();
2484   SmallVector<unsigned, 4> counter(rank, 0);
2485   // The number of brackets that have been opened and not closed.
2486   unsigned openBrackets = 0;
2487 
2488   auto shape = type.getShape();
2489   auto bumpCounter = [&] {
2490     // Bump the least significant digit.
2491     ++counter[rank - 1];
2492     // Iterate backwards bubbling back the increment.
2493     for (unsigned i = rank - 1; i > 0; --i)
2494       if (counter[i] >= shape[i]) {
2495         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
2496         counter[i] = 0;
2497         ++counter[i - 1];
2498         --openBrackets;
2499         os << ']';
2500       }
2501   };
2502 
2503   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
2504     if (idx != 0)
2505       os << ", ";
2506     while (openBrackets++ < rank)
2507       os << '[';
2508     openBrackets = rank;
2509     printEltFn(idx);
2510     bumpCounter();
2511   }
2512   while (openBrackets-- > 0)
2513     os << ']';
2514 }
2515 
2516 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
2517                                               bool allowHex) {
2518   if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
2519     return printDenseStringElementsAttr(stringAttr);
2520 
2521   printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
2522                                 allowHex);
2523 }
2524 
2525 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
2526     DenseIntOrFPElementsAttr attr, bool allowHex) {
2527   auto type = attr.getType();
2528   auto elementType = type.getElementType();
2529 
2530   // Check to see if we should format this attribute as a hex string.
2531   if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) {
2532     ArrayRef<char> rawData = attr.getRawData();
2533     if (llvm::endianness::native == llvm::endianness::big) {
2534       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
2535       // machines. It is converted here to print in LE format.
2536       SmallVector<char, 64> outDataVec(rawData.size());
2537       MutableArrayRef<char> convRawData(outDataVec);
2538       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
2539           rawData, convRawData, type);
2540       printHexString(convRawData);
2541     } else {
2542       printHexString(rawData);
2543     }
2544 
2545     return;
2546   }
2547 
2548   if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
2549     Type complexElementType = complexTy.getElementType();
2550     // Note: The if and else below had a common lambda function which invoked
2551     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
2552     // and hence was replaced.
2553     if (llvm::isa<IntegerType>(complexElementType)) {
2554       auto valueIt = attr.value_begin<std::complex<APInt>>();
2555       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2556         auto complexValue = *(valueIt + index);
2557         os << "(";
2558         printDenseIntElement(complexValue.real(), os, complexElementType);
2559         os << ",";
2560         printDenseIntElement(complexValue.imag(), os, complexElementType);
2561         os << ")";
2562       });
2563     } else {
2564       auto valueIt = attr.value_begin<std::complex<APFloat>>();
2565       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2566         auto complexValue = *(valueIt + index);
2567         os << "(";
2568         printFloatValue(complexValue.real(), os);
2569         os << ",";
2570         printFloatValue(complexValue.imag(), os);
2571         os << ")";
2572       });
2573     }
2574   } else if (elementType.isIntOrIndex()) {
2575     auto valueIt = attr.value_begin<APInt>();
2576     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2577       printDenseIntElement(*(valueIt + index), os, elementType);
2578     });
2579   } else {
2580     assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
2581     auto valueIt = attr.value_begin<APFloat>();
2582     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2583       printFloatValue(*(valueIt + index), os);
2584     });
2585   }
2586 }
2587 
2588 void AsmPrinter::Impl::printDenseStringElementsAttr(
2589     DenseStringElementsAttr attr) {
2590   ArrayRef<StringRef> data = attr.getRawStringData();
2591   auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
2592   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2593 }
2594 
2595 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
2596   Type type = attr.getElementType();
2597   unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
2598   unsigned byteSize = bitwidth / 8;
2599   ArrayRef<char> data = attr.getRawData();
2600 
2601   auto printElementAt = [&](unsigned i) {
2602     APInt value(bitwidth, 0);
2603     if (bitwidth) {
2604       llvm::LoadIntFromMemory(
2605           value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
2606           byteSize);
2607     }
2608     // Print the data as-is or as a float.
2609     if (type.isIntOrIndex()) {
2610       printDenseIntElement(value, getStream(), type);
2611     } else {
2612       APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value);
2613       printFloatValue(fltVal, getStream());
2614     }
2615   };
2616   llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
2617                         printElementAt);
2618 }
2619 
2620 void AsmPrinter::Impl::printType(Type type) {
2621   if (!type) {
2622     os << "<<NULL TYPE>>";
2623     return;
2624   }
2625 
2626   // Try to print an alias for this type.
2627   if (succeeded(printAlias(type)))
2628     return;
2629   return printTypeImpl(type);
2630 }
2631 
2632 void AsmPrinter::Impl::printTypeImpl(Type type) {
2633   TypeSwitch<Type>(type)
2634       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
2635         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
2636                            opaqueTy.getTypeData());
2637       })
2638       .Case<IndexType>([&](Type) { os << "index"; })
2639       .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
2640       .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
2641       .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
2642       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
2643       .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
2644       .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
2645       .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
2646       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
2647       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
2648       .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
2649       .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
2650       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
2651       .Case<Float16Type>([&](Type) { os << "f16"; })
2652       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
2653       .Case<Float32Type>([&](Type) { os << "f32"; })
2654       .Case<Float64Type>([&](Type) { os << "f64"; })
2655       .Case<Float80Type>([&](Type) { os << "f80"; })
2656       .Case<Float128Type>([&](Type) { os << "f128"; })
2657       .Case<IntegerType>([&](IntegerType integerTy) {
2658         if (integerTy.isSigned())
2659           os << 's';
2660         else if (integerTy.isUnsigned())
2661           os << 'u';
2662         os << 'i' << integerTy.getWidth();
2663       })
2664       .Case<FunctionType>([&](FunctionType funcTy) {
2665         os << '(';
2666         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
2667         os << ") -> ";
2668         ArrayRef<Type> results = funcTy.getResults();
2669         if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
2670           printType(results[0]);
2671         } else {
2672           os << '(';
2673           interleaveComma(results, [&](Type ty) { printType(ty); });
2674           os << ')';
2675         }
2676       })
2677       .Case<VectorType>([&](VectorType vectorTy) {
2678         auto scalableDims = vectorTy.getScalableDims();
2679         os << "vector<";
2680         auto vShape = vectorTy.getShape();
2681         unsigned lastDim = vShape.size();
2682         unsigned dimIdx = 0;
2683         for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
2684           if (!scalableDims.empty() && scalableDims[dimIdx])
2685             os << '[';
2686           os << vShape[dimIdx];
2687           if (!scalableDims.empty() && scalableDims[dimIdx])
2688             os << ']';
2689           os << 'x';
2690         }
2691         printType(vectorTy.getElementType());
2692         os << '>';
2693       })
2694       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
2695         os << "tensor<";
2696         printDimensionList(tensorTy.getShape());
2697         if (!tensorTy.getShape().empty())
2698           os << 'x';
2699         printType(tensorTy.getElementType());
2700         // Only print the encoding attribute value if set.
2701         if (tensorTy.getEncoding()) {
2702           os << ", ";
2703           printAttribute(tensorTy.getEncoding());
2704         }
2705         os << '>';
2706       })
2707       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
2708         os << "tensor<*x";
2709         printType(tensorTy.getElementType());
2710         os << '>';
2711       })
2712       .Case<MemRefType>([&](MemRefType memrefTy) {
2713         os << "memref<";
2714         printDimensionList(memrefTy.getShape());
2715         if (!memrefTy.getShape().empty())
2716           os << 'x';
2717         printType(memrefTy.getElementType());
2718         MemRefLayoutAttrInterface layout = memrefTy.getLayout();
2719         if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
2720           os << ", ";
2721           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2722         }
2723         // Only print the memory space if it is the non-default one.
2724         if (memrefTy.getMemorySpace()) {
2725           os << ", ";
2726           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2727         }
2728         os << '>';
2729       })
2730       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2731         os << "memref<*x";
2732         printType(memrefTy.getElementType());
2733         // Only print the memory space if it is the non-default one.
2734         if (memrefTy.getMemorySpace()) {
2735           os << ", ";
2736           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2737         }
2738         os << '>';
2739       })
2740       .Case<ComplexType>([&](ComplexType complexTy) {
2741         os << "complex<";
2742         printType(complexTy.getElementType());
2743         os << '>';
2744       })
2745       .Case<TupleType>([&](TupleType tupleTy) {
2746         os << "tuple<";
2747         interleaveComma(tupleTy.getTypes(),
2748                         [&](Type type) { printType(type); });
2749         os << '>';
2750       })
2751       .Case<NoneType>([&](Type) { os << "none"; })
2752       .Default([&](Type type) { return printDialectType(type); });
2753 }
2754 
2755 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2756                                              ArrayRef<StringRef> elidedAttrs,
2757                                              bool withKeyword) {
2758   // If there are no attributes, then there is nothing to be done.
2759   if (attrs.empty())
2760     return;
2761 
2762   // Functor used to print a filtered attribute list.
2763   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2764     // Print the 'attributes' keyword if necessary.
2765     if (withKeyword)
2766       os << " attributes";
2767 
2768     // Otherwise, print them all out in braces.
2769     os << " {";
2770     interleaveComma(filteredAttrs,
2771                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2772     os << '}';
2773   };
2774 
2775   // If no attributes are elided, we can directly print with no filtering.
2776   if (elidedAttrs.empty())
2777     return printFilteredAttributesFn(attrs);
2778 
2779   // Otherwise, filter out any attributes that shouldn't be included.
2780   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2781                                                 elidedAttrs.end());
2782   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2783     return !elidedAttrsSet.contains(attr.getName().strref());
2784   });
2785   if (!filteredAttrs.empty())
2786     printFilteredAttributesFn(filteredAttrs);
2787 }
2788 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2789   // Print the name without quotes if possible.
2790   ::printKeywordOrString(attr.getName().strref(), os);
2791 
2792   // Pretty printing elides the attribute value for unit attributes.
2793   if (llvm::isa<UnitAttr>(attr.getValue()))
2794     return;
2795 
2796   os << " = ";
2797   printAttribute(attr.getValue());
2798 }
2799 
2800 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2801   auto &dialect = attr.getDialect();
2802 
2803   // Ask the dialect to serialize the attribute to a string.
2804   std::string attrName;
2805   {
2806     llvm::raw_string_ostream attrNameStr(attrName);
2807     Impl subPrinter(attrNameStr, state);
2808     DialectAsmPrinter printer(subPrinter);
2809     dialect.printAttribute(attr, printer);
2810   }
2811   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2812 }
2813 
2814 void AsmPrinter::Impl::printDialectType(Type type) {
2815   auto &dialect = type.getDialect();
2816 
2817   // Ask the dialect to serialize the type to a string.
2818   std::string typeName;
2819   {
2820     llvm::raw_string_ostream typeNameStr(typeName);
2821     Impl subPrinter(typeNameStr, state);
2822     DialectAsmPrinter printer(subPrinter);
2823     dialect.printType(type, printer);
2824   }
2825   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2826 }
2827 
2828 void AsmPrinter::Impl::printEscapedString(StringRef str) {
2829   os << "\"";
2830   llvm::printEscapedString(str, os);
2831   os << "\"";
2832 }
2833 
2834 void AsmPrinter::Impl::printHexString(StringRef str) {
2835   os << "\"0x" << llvm::toHex(str) << "\"";
2836 }
2837 void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2838   printHexString(StringRef(data.data(), data.size()));
2839 }
2840 
2841 LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
2842   return state.pushCyclicPrinting(opaquePointer);
2843 }
2844 
2845 void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
2846 
2847 void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
2848   detail::printDimensionList(os, shape);
2849 }
2850 
2851 //===--------------------------------------------------------------------===//
2852 // AsmPrinter
2853 //===--------------------------------------------------------------------===//
2854 
2855 AsmPrinter::~AsmPrinter() = default;
2856 
2857 raw_ostream &AsmPrinter::getStream() const {
2858   assert(impl && "expected AsmPrinter::getStream to be overriden");
2859   return impl->getStream();
2860 }
2861 
2862 /// Print the given floating point value in a stablized form.
2863 void AsmPrinter::printFloat(const APFloat &value) {
2864   assert(impl && "expected AsmPrinter::printFloat to be overriden");
2865   printFloatValue(value, impl->getStream());
2866 }
2867 
2868 void AsmPrinter::printType(Type type) {
2869   assert(impl && "expected AsmPrinter::printType to be overriden");
2870   impl->printType(type);
2871 }
2872 
2873 void AsmPrinter::printAttribute(Attribute attr) {
2874   assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2875   impl->printAttribute(attr);
2876 }
2877 
2878 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2879   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2880   return impl->printAlias(attr);
2881 }
2882 
2883 LogicalResult AsmPrinter::printAlias(Type type) {
2884   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2885   return impl->printAlias(type);
2886 }
2887 
2888 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2889   assert(impl &&
2890          "expected AsmPrinter::printAttributeWithoutType to be overriden");
2891   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2892 }
2893 
2894 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2895   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2896   ::printKeywordOrString(keyword, impl->getStream());
2897 }
2898 
2899 void AsmPrinter::printString(StringRef keyword) {
2900   assert(impl && "expected AsmPrinter::printString to be overriden");
2901   *this << '"';
2902   printEscapedString(keyword, getStream());
2903   *this << '"';
2904 }
2905 
2906 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2907   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2908   ::printSymbolReference(symbolRef, impl->getStream());
2909 }
2910 
2911 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2912   assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2913   impl->printResourceHandle(resource);
2914 }
2915 
2916 void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
2917   detail::printDimensionList(getStream(), shape);
2918 }
2919 
2920 LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
2921   return impl->pushCyclicPrinting(opaquePointer);
2922 }
2923 
2924 void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
2925 
2926 //===----------------------------------------------------------------------===//
2927 // Affine expressions and maps
2928 //===----------------------------------------------------------------------===//
2929 
2930 void AsmPrinter::Impl::printAffineExpr(
2931     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2932   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2933 }
2934 
2935 void AsmPrinter::Impl::printAffineExprInternal(
2936     AffineExpr expr, BindingStrength enclosingTightness,
2937     function_ref<void(unsigned, bool)> printValueName) {
2938   const char *binopSpelling = nullptr;
2939   switch (expr.getKind()) {
2940   case AffineExprKind::SymbolId: {
2941     unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
2942     if (printValueName)
2943       printValueName(pos, /*isSymbol=*/true);
2944     else
2945       os << 's' << pos;
2946     return;
2947   }
2948   case AffineExprKind::DimId: {
2949     unsigned pos = cast<AffineDimExpr>(expr).getPosition();
2950     if (printValueName)
2951       printValueName(pos, /*isSymbol=*/false);
2952     else
2953       os << 'd' << pos;
2954     return;
2955   }
2956   case AffineExprKind::Constant:
2957     os << cast<AffineConstantExpr>(expr).getValue();
2958     return;
2959   case AffineExprKind::Add:
2960     binopSpelling = " + ";
2961     break;
2962   case AffineExprKind::Mul:
2963     binopSpelling = " * ";
2964     break;
2965   case AffineExprKind::FloorDiv:
2966     binopSpelling = " floordiv ";
2967     break;
2968   case AffineExprKind::CeilDiv:
2969     binopSpelling = " ceildiv ";
2970     break;
2971   case AffineExprKind::Mod:
2972     binopSpelling = " mod ";
2973     break;
2974   }
2975 
2976   auto binOp = cast<AffineBinaryOpExpr>(expr);
2977   AffineExpr lhsExpr = binOp.getLHS();
2978   AffineExpr rhsExpr = binOp.getRHS();
2979 
2980   // Handle tightly binding binary operators.
2981   if (binOp.getKind() != AffineExprKind::Add) {
2982     if (enclosingTightness == BindingStrength::Strong)
2983       os << '(';
2984 
2985     // Pretty print multiplication with -1.
2986     auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr);
2987     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2988         rhsConst.getValue() == -1) {
2989       os << "-";
2990       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2991       if (enclosingTightness == BindingStrength::Strong)
2992         os << ')';
2993       return;
2994     }
2995 
2996     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2997 
2998     os << binopSpelling;
2999     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
3000 
3001     if (enclosingTightness == BindingStrength::Strong)
3002       os << ')';
3003     return;
3004   }
3005 
3006   // Print out special "pretty" forms for add.
3007   if (enclosingTightness == BindingStrength::Strong)
3008     os << '(';
3009 
3010   // Pretty print addition to a product that has a negative operand as a
3011   // subtraction.
3012   if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
3013     if (rhs.getKind() == AffineExprKind::Mul) {
3014       AffineExpr rrhsExpr = rhs.getRHS();
3015       if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
3016         if (rrhs.getValue() == -1) {
3017           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
3018                                   printValueName);
3019           os << " - ";
3020           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
3021             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
3022                                     printValueName);
3023           } else {
3024             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
3025                                     printValueName);
3026           }
3027 
3028           if (enclosingTightness == BindingStrength::Strong)
3029             os << ')';
3030           return;
3031         }
3032 
3033         if (rrhs.getValue() < -1) {
3034           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
3035                                   printValueName);
3036           os << " - ";
3037           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
3038                                   printValueName);
3039           os << " * " << -rrhs.getValue();
3040           if (enclosingTightness == BindingStrength::Strong)
3041             os << ')';
3042           return;
3043         }
3044       }
3045     }
3046   }
3047 
3048   // Pretty print addition to a negative number as a subtraction.
3049   if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) {
3050     if (rhsConst.getValue() < 0) {
3051       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
3052       os << " - " << -rhsConst.getValue();
3053       if (enclosingTightness == BindingStrength::Strong)
3054         os << ')';
3055       return;
3056     }
3057   }
3058 
3059   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
3060 
3061   os << " + ";
3062   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
3063 
3064   if (enclosingTightness == BindingStrength::Strong)
3065     os << ')';
3066 }
3067 
3068 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
3069   printAffineExprInternal(expr, BindingStrength::Weak);
3070   isEq ? os << " == 0" : os << " >= 0";
3071 }
3072 
3073 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
3074   // Dimension identifiers.
3075   os << '(';
3076   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
3077     os << 'd' << i << ", ";
3078   if (map.getNumDims() >= 1)
3079     os << 'd' << map.getNumDims() - 1;
3080   os << ')';
3081 
3082   // Symbolic identifiers.
3083   if (map.getNumSymbols() != 0) {
3084     os << '[';
3085     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
3086       os << 's' << i << ", ";
3087     if (map.getNumSymbols() >= 1)
3088       os << 's' << map.getNumSymbols() - 1;
3089     os << ']';
3090   }
3091 
3092   // Result affine expressions.
3093   os << " -> (";
3094   interleaveComma(map.getResults(),
3095                   [&](AffineExpr expr) { printAffineExpr(expr); });
3096   os << ')';
3097 }
3098 
3099 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
3100   // Dimension identifiers.
3101   os << '(';
3102   for (unsigned i = 1; i < set.getNumDims(); ++i)
3103     os << 'd' << i - 1 << ", ";
3104   if (set.getNumDims() >= 1)
3105     os << 'd' << set.getNumDims() - 1;
3106   os << ')';
3107 
3108   // Symbolic identifiers.
3109   if (set.getNumSymbols() != 0) {
3110     os << '[';
3111     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
3112       os << 's' << i << ", ";
3113     if (set.getNumSymbols() >= 1)
3114       os << 's' << set.getNumSymbols() - 1;
3115     os << ']';
3116   }
3117 
3118   // Print constraints.
3119   os << " : (";
3120   int numConstraints = set.getNumConstraints();
3121   for (int i = 1; i < numConstraints; ++i) {
3122     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
3123     os << ", ";
3124   }
3125   if (numConstraints >= 1)
3126     printAffineConstraint(set.getConstraint(numConstraints - 1),
3127                           set.isEq(numConstraints - 1));
3128   os << ')';
3129 }
3130 
3131 //===----------------------------------------------------------------------===//
3132 // OperationPrinter
3133 //===----------------------------------------------------------------------===//
3134 
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
/// It combines the generic AsmPrinter::Impl printing machinery with the
/// OpAsmPrinter hooks (inherited privately) that custom operation printers
/// call back into, and tracks the indentation and default-dialect state needed
/// while recursing into nested regions.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  // Bring Impl's printType into this class's scope (presumably to avoid
  // ambiguity with the identically named OpAsmPrinter base member).
  using Impl::printType;

  /// The OpAsmPrinter base is constructed on top of this object's own Impl
  /// subobject, so both bases write to the same stream and state.
  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print the given top-level operation.
  void printTopLevelOperation(Operation *op);

  /// Print the given operation, including its left-hand side and its right-hand
  /// side, with its indent and location.
  void printFullOpWithIndentAndLoc(Operation *op);
  /// Print the given operation, including its left-hand side and its right-hand
  /// side, but not including indentation and location.
  void printFullOp(Operation *op);
  /// Print the right-hand side of the given operation in the custom or generic
  /// form.
  void printCustomOrGenericOp(Operation *op) override;
  /// Print the right-hand side of the given operation in the generic form.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number.
  /// Output goes to `streamOverride` when non-null, otherwise to `os`.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  /// Print the ID of the given operation. Output goes to `streamOverride`
  /// when non-null, otherwise to `os`.
  void printOperationID(Operation *op,
                        raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Print a loc(...) specifier if printing debug info is enabled. Locations
  /// may be deferred with an alias.
  void printOptionalLocationSpecifier(Location loc) override {
    printTrailingLocation(loc);
  }

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }

  /// Increase indentation by one level (`indentWidth` spaces).
  void increaseIndent() override { currentIndent += indentWidth; }

  /// Decrease indentation by one level (`indentWidth` spaces).
  void decreaseIndent() override { currentIndent -= indentWidth; }

  /// Print a block argument in the usual format of:
  ///   %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;

  /// Print the ID for the given value, to the printer's stream or to `os`.
  void printOperand(Value value) override { printValueID(value); }
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as printOptionalAttrDict, but with the keyword form enabled
  /// (`withKeyword=true`).
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state.getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

  /// Print users of this operation or id of this operation if it has no result.
  void printUsersComment(Operation *op);

  /// Print users of this block arg.
  void printUsersComment(BlockArgument arg);

  /// Print the users of a value.
  void printValueUsers(Value value);

  /// Print either the ids of the result values or the id of the operation if
  /// the operation has no results.
  void printUserIDs(Operation *user, bool prefixComma = false);

private:
  /// This class represents a resource builder implementation for the MLIR
  /// textual assembly format. Each build* call forwards the key and a
  /// value-printing callback to `printFn`, which decides the surrounding
  /// formatting.
  class ResourceBuilder : public AsmResourceBuilder {
  public:
    using ValueFn = function_ref<void(raw_ostream &)>;
    using PrintFn = function_ref<void(StringRef, ValueFn)>;

    ResourceBuilder(PrintFn printFn) : printFn(printFn) {}
    ~ResourceBuilder() override = default;

    /// Booleans are printed as the bare words `true` / `false`.
    void buildBool(StringRef key, bool data) final {
      printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
    }

    /// Strings are printed double-quoted with their contents escaped.
    void buildString(StringRef key, StringRef data) final {
      printFn(key, [&](raw_ostream &os) {
        os << "\"";
        llvm::printEscapedString(data, os);
        os << "\"";
      });
    }

    /// Blobs are printed as a quoted "0x..." hex string: the alignment as
    /// 4 little-endian bytes, followed by the raw data bytes.
    void buildBlob(StringRef key, ArrayRef<char> data,
                   uint32_t dataAlignment) final {
      printFn(key, [&](raw_ostream &os) {
        // Store the blob in a hex string containing the alignment and the data.
        llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
        os << "\"0x"
           << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
                                    sizeof(dataAlignment)))
           << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
      });
    }

  private:
    PrintFn printFn;
  };

  /// Print the metadata dictionary for the file, eliding it if it is empty.
  void printFileMetadataDictionary(Operation *op);

  /// Print the resource sections for the file metadata dictionary.
  /// `checkAddMetadataDict` is used to indicate that metadata is going to be
  /// added, and the file metadata dictionary should be started if it hasn't
  /// yet.
  void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
                                 Operation *op);

  // Contains the stack of default dialects to use when printing regions.
  // A new dialect is pushed to the stack before printing regions nested under
  // an operation implementing `OpAsmOpInterface`, and popped when done. At the
  // top-level we start with "builtin" as the default, so that the top-level
  // `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  unsigned currentIndent = 0;
};
} // namespace
3333 
3334 void OperationPrinter::printTopLevelOperation(Operation *op) {
3335   // Output the aliases at the top level that can't be deferred.
3336   state.getAliasState().printNonDeferredAliases(*this, newLine);
3337 
3338   // Print the module.
3339   printFullOpWithIndentAndLoc(op);
3340   os << newLine;
3341 
3342   // Output the aliases at the top level that can be deferred.
3343   state.getAliasState().printDeferredAliases(*this, newLine);
3344 
3345   // Output any file level metadata.
3346   printFileMetadataDictionary(op);
3347 }
3348 
3349 void OperationPrinter::printFileMetadataDictionary(Operation *op) {
3350   bool sawMetadataEntry = false;
3351   auto checkAddMetadataDict = [&] {
3352     if (!std::exchange(sawMetadataEntry, true))
3353       os << newLine << "{-#" << newLine;
3354   };
3355 
3356   // Add the various types of metadata.
3357   printResourceFileMetadata(checkAddMetadataDict, op);
3358 
3359   // If the file dictionary exists, close it.
3360   if (sawMetadataEntry)
3361     os << newLine << "#-}" << newLine;
3362 }
3363 
/// Print the `dialect_resources` and `external_resources` sections of the file
/// metadata dictionary. Each section is emitted as
///   `  <dictName>_resources: {` ... `  }`
/// and contains one `    <name>: {` ... `    }` group per provider that
/// produced at least one entry, with each entry printed as `      key: value`.
/// Sections and groups with no entries produce no output.
/// `checkAddMetadataDict` is invoked before each entry so the caller can open
/// the enclosing `{-#` dictionary on demand.
void OperationPrinter::printResourceFileMetadata(
    function_ref<void()> checkAddMetadataDict, Operation *op) {
  // State shared by the entry-printing functor below:
  //  - hadResource: the current `*_resources: {` section header was emitted.
  //  - needResourceComma: a previous section was emitted, so a new section
  //    header must be preceded by a comma.
  //  - needEntryComma: a previous provider group in this section was emitted,
  //    so a new group header must be preceded by a comma.
  bool hadResource = false;
  bool needResourceComma = false;
  bool needEntryComma = false;
  // Functor used to add data entries to the file metadata dictionary for one
  // provider; `providerArgs` are forwarded to the provider's buildResources.
  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
                             auto &&...providerArgs) {
    // Whether this provider's `name: {` group header has been emitted.
    bool hadEntry = false;
    auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
      // NOTE(review): this runs before the size-limit check below, so if every
      // entry ends up elided the caller's metadata dictionary may still have
      // been opened with no entries inside it — confirm this is intended.
      checkAddMetadataDict();

      // Emit any pending section/group headers and the separator that must
      // precede this entry.
      auto printFormatting = [&]() {
        // Emit the top-level resource entry if we haven't yet.
        if (!std::exchange(hadResource, true)) {
          if (needResourceComma)
            os << "," << newLine;
          os << "  " << dictName << "_resources: {" << newLine;
        }
        // Emit the parent resource entry if we haven't yet.
        if (!std::exchange(hadEntry, true)) {
          if (needEntryComma)
            os << "," << newLine;
          os << "    " << name << ": {" << newLine;
        } else {
          os << "," << newLine;
        }
      };

      std::optional<uint64_t> charLimit =
          printerFlags.getLargeResourceStringLimit();
      if (charLimit.has_value()) {
        // Render the value into a buffer first so its length can be checked.
        std::string resourceStr;
        llvm::raw_string_ostream ss(resourceStr);
        valueFn(ss);

        // Only print the entry if its string is small enough.
        if (resourceStr.size() > charLimit.value())
          return;

        printFormatting();
        os << "      " << key << ": " << resourceStr;
      } else {
        printFormatting();
        os << "      " << key << ": ";
        valueFn(os);
      }
    };
    ResourceBuilder entryBuilder(printFn);
    provider.buildResources(op, providerArgs..., entryBuilder);

    // Close this provider's group if anything was printed into it.
    needEntryComma |= hadEntry;
    if (hadEntry)
      os << newLine << "    }";
  };

  // Print the `dialect_resources` section if we have any dialects with
  // resources. Dialects with no recorded resource handles still get a chance
  // to build resources, with an empty handle set.
  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
    auto &dialectResources = state.getDialectResources();
    StringRef name = interface.getDialect()->getNamespace();
    auto it = dialectResources.find(interface.getDialect());
    if (it != dialectResources.end())
      processProvider("dialect", name, interface, it->second);
    else
      processProvider("dialect", name, interface,
                      SetVector<AsmDialectResourceHandle>());
  }
  if (hadResource)
    os << newLine << "  }";

  // Print the `external_resources` section if we have any external clients with
  // resources. Reset the per-section state; a comma is needed before this
  // section only if the dialect section was emitted.
  needEntryComma = false;
  needResourceComma = hadResource;
  hadResource = false;
  for (const auto &printer : state.getResourcePrinters())
    processProvider("external", printer.getName(), printer);
  if (hadResource)
    os << newLine << "  }";
}
3445 
3446 /// Print a block argument in the usual format of:
3447 ///   %ssaName : type {attr1=42} loc("here")
3448 /// where location printing is controlled by the standard internal option.
3449 /// You may pass omitType=true to not print a type, and pass an empty
3450 /// attribute list if you don't care for attributes.
3451 void OperationPrinter::printRegionArgument(BlockArgument arg,
3452                                            ArrayRef<NamedAttribute> argAttrs,
3453                                            bool omitType) {
3454   printOperand(arg);
3455   if (!omitType) {
3456     os << ": ";
3457     printType(arg.getType());
3458   }
3459   printOptionalAttrDict(argAttrs);
3460   // TODO: We should allow location aliases on block arguments.
3461   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
3462 }
3463 
3464 void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
3465   // Track the location of this operation.
3466   state.registerOperationLocation(op, newLine.curLine, currentIndent);
3467 
3468   os.indent(currentIndent);
3469   printFullOp(op);
3470   printTrailingLocation(op->getLoc());
3471   if (printerFlags.shouldPrintValueUsers())
3472     printUsersComment(op);
3473 }
3474 
/// Print `op` including its result list (`%a, %b:2 = `) but without
/// indentation or trailing location. Results are printed as groups: each
/// group shows the ID of its first result and, if it spans more than one
/// result, a `:count` suffix.
void OperationPrinter::printFullOp(Operation *op) {
  if (size_t numResults = op->getNumResults()) {
    // Print the first result of a group, plus `:N` when the group covers N>1
    // results.
    auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
      printValueID(op->getResult(resultNo), /*printResultNo=*/false);
      if (resultCount > 1)
        os << ':' << resultCount;
    };

    // Check to see if this operation has multiple result groups.
    // `resultGroups` holds the starting result index of each group.
    // NOTE(review): the code below assumes that a non-empty list holds at
    // least two groups (a single-entry list would print a leading ", ");
    // presumably guaranteed by SSANameState — confirm.
    ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
    if (!resultGroups.empty()) {
      // Interleave the groups excluding the last one, this one will be handled
      // separately.
      interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
        printResultGroup(resultGroups[i],
                         resultGroups[i + 1] - resultGroups[i]);
      });
      os << ", ";
      // The last group extends to the final result of the operation.
      printResultGroup(resultGroups.back(), numResults - resultGroups.back());

    } else {
      // A single group covering every result.
      printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
    }

    os << " = ";
  }

  printCustomOrGenericOp(op);
}
3504 
3505 void OperationPrinter::printUsersComment(Operation *op) {
3506   unsigned numResults = op->getNumResults();
3507   if (!numResults && op->getNumOperands()) {
3508     os << " // id: ";
3509     printOperationID(op);
3510   } else if (numResults && op->use_empty()) {
3511     os << " // unused";
3512   } else if (numResults && !op->use_empty()) {
3513     // Print "user" if the operation has one result used to compute one other
3514     // result, or is used in one operation with no result.
3515     unsigned usedInNResults = 0;
3516     unsigned usedInNOperations = 0;
3517     SmallPtrSet<Operation *, 1> userSet;
3518     for (Operation *user : op->getUsers()) {
3519       if (userSet.insert(user).second) {
3520         ++usedInNOperations;
3521         usedInNResults += user->getNumResults();
3522       }
3523     }
3524 
3525     // We already know that users is not empty.
3526     bool exactlyOneUniqueUse =
3527         usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
3528     os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
3529     bool shouldPrintBrackets = numResults > 1;
3530     auto printOpResult = [&](OpResult opResult) {
3531       if (shouldPrintBrackets)
3532         os << "(";
3533       printValueUsers(opResult);
3534       if (shouldPrintBrackets)
3535         os << ")";
3536     };
3537 
3538     interleaveComma(op->getResults(), printOpResult);
3539   }
3540 }
3541 
3542 void OperationPrinter::printUsersComment(BlockArgument arg) {
3543   os << "// ";
3544   printValueID(arg);
3545   if (arg.use_empty()) {
3546     os << " is unused";
3547   } else {
3548     os << " is used by ";
3549     printValueUsers(arg);
3550   }
3551   os << newLine;
3552 }
3553 
3554 void OperationPrinter::printValueUsers(Value value) {
3555   if (value.use_empty())
3556     os << "unused";
3557 
3558   // One value might be used as the operand of an operation more than once.
3559   // Only print the operations results once in that case.
3560   SmallPtrSet<Operation *, 1> userSet;
3561   for (auto [index, user] : enumerate(value.getUsers())) {
3562     if (userSet.insert(user).second)
3563       printUserIDs(user, index);
3564   }
3565 }
3566 
3567 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
3568   if (prefixComma)
3569     os << ", ";
3570 
3571   if (!user->getNumResults()) {
3572     printOperationID(user);
3573   } else {
3574     interleaveComma(user->getResults(),
3575                     [this](Value result) { printValueID(result); });
3576   }
3577 }
3578 
3579 void OperationPrinter::printCustomOrGenericOp(Operation *op) {
3580   // If requested, always print the generic form.
3581   if (!printerFlags.shouldPrintGenericOpForm()) {
3582     // Check to see if this is a known operation. If so, use the registered
3583     // custom printer hook.
3584     if (auto opInfo = op->getRegisteredInfo()) {
3585       opInfo->printAssembly(op, *this, defaultDialectStack.back());
3586       return;
3587     }
3588     // Otherwise try to dispatch to the dialect, if available.
3589     if (Dialect *dialect = op->getDialect()) {
3590       if (auto opPrinter = dialect->getOperationPrinter(op)) {
3591         // Print the op name first.
3592         StringRef name = op->getName().getStringRef();
3593         // Only drop the default dialect prefix when it cannot lead to
3594         // ambiguities.
3595         if (name.count('.') == 1)
3596           name.consume_front((defaultDialectStack.back() + ".").str());
3597         os << name;
3598 
3599         // Print the rest of the op now.
3600         opPrinter(op, *this);
3601         return;
3602       }
3603     }
3604   }
3605 
3606   // Otherwise print with the generic assembly form.
3607   printGenericOp(op, /*printOpName=*/true);
3608 }
3609 
3610 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
3611   if (printOpName)
3612     printEscapedString(op->getName().getStringRef());
3613   os << '(';
3614   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
3615   os << ')';
3616 
3617   // For terminators, print the list of successors and their operands.
3618   if (op->getNumSuccessors() != 0) {
3619     os << '[';
3620     interleaveComma(op->getSuccessors(),
3621                     [&](Block *successor) { printBlockName(successor); });
3622     os << ']';
3623   }
3624 
3625   // Print the properties.
3626   if (Attribute prop = op->getPropertiesAsAttribute()) {
3627     os << " <";
3628     Impl::printAttribute(prop);
3629     os << '>';
3630   }
3631 
3632   // Print regions.
3633   if (op->getNumRegions() != 0) {
3634     os << " (";
3635     interleaveComma(op->getRegions(), [&](Region &region) {
3636       printRegion(region, /*printEntryBlockArgs=*/true,
3637                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
3638     });
3639     os << ')';
3640   }
3641 
3642   printOptionalAttrDict(op->getPropertiesStorage()
3643                             ? llvm::to_vector(op->getDiscardableAttrs())
3644                             : op->getAttrs());
3645 
3646   // Print the type signature of the operation.
3647   os << " : ";
3648   printFunctionalType(op);
3649 }
3650 
3651 void OperationPrinter::printBlockName(Block *block) {
3652   os << state.getSSANameState().getBlockInfo(block).name;
3653 }
3654 
/// Print `block`. When `printBlockArgs` is set, a header line is emitted first.
/// The header consists of:
///  - the block name;
///  - a parenthesized `%arg: type loc` list, if the block has arguments;
///  - a trailing ':';
///  - a comment about the block's predecessors.
/// The block's operations are then printed one per line, one indentation level
/// deeper. The final terminator is skipped when `printBlockTerminator` is
/// false.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument arg) {
        printValueID(arg);
        os << ": ";
        printType(arg.getType());
        // TODO: We should allow location aliases on block arguments.
        printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    if (!block->getParent()) {
      os << "  // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      // Only non-entry blocks get the "no predecessors" note.
      if (!block->isEntryBlock())
        os << "  // no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << "  // pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in a stable order, not in
      // whatever order the use-list is in, so gather and sort them.
      SmallVector<BlockInfo, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back(state.getSSANameState().getBlockInfo(pred));
      llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
        return lhs.ordering < rhs.ordering;
      });

      os << "  // " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
    }
    os << newLine;
  }

  // The block body is indented one level deeper than its header.
  currentIndent += indentWidth;

  // When value-user comments are enabled, emit one comment line per block
  // argument before the operations.
  if (printerFlags.shouldPrintValueUsers()) {
    for (BlockArgument arg : block->getArguments()) {
      os.indent(currentIndent);
      printUsersComment(arg);
    }
  }

  // Exclude the last operation from the printed range only when it is a
  // terminator and the caller asked for terminators to be elided.
  bool hasTerminator =
      !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
  auto range = llvm::make_range(
      block->begin(),
      std::prev(block->end(),
                (!hasTerminator || printBlockTerminator) ? 0 : 1));
  for (auto &op : range) {
    printFullOpWithIndentAndLoc(&op);
    os << newLine;
  }
  currentIndent -= indentWidth;
}
3723 
3724 void OperationPrinter::printValueID(Value value, bool printResultNo,
3725                                     raw_ostream *streamOverride) const {
3726   state.getSSANameState().printValueID(value, printResultNo,
3727                                        streamOverride ? *streamOverride : os);
3728 }
3729 
3730 void OperationPrinter::printOperationID(Operation *op,
3731                                         raw_ostream *streamOverride) const {
3732   state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride
3733                                                               : os);
3734 }
3735 
3736 void OperationPrinter::printSuccessor(Block *successor) {
3737   printBlockName(successor);
3738 }
3739 
3740 void OperationPrinter::printSuccessorAndUseList(Block *successor,
3741                                                 ValueRange succOperands) {
3742   printBlockName(successor);
3743   if (succOperands.empty())
3744     return;
3745 
3746   os << '(';
3747   interleaveComma(succOperands,
3748                   [this](Value operand) { printValueID(operand); });
3749   os << " : ";
3750   interleaveComma(succOperands,
3751                   [this](Value operand) { printType(operand.getType()); });
3752   os << ')';
3753 }
3754 
3755 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3756                                    bool printBlockTerminators,
3757                                    bool printEmptyBlock) {
3758   if (printerFlags.shouldSkipRegions()) {
3759     os << "{...}";
3760     return;
3761   }
3762   os << "{" << newLine;
3763   if (!region.empty()) {
3764     auto restoreDefaultDialect =
3765         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
3766     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3767       defaultDialectStack.push_back(iface.getDefaultDialect());
3768     else
3769       defaultDialectStack.push_back("");
3770 
3771     auto *entryBlock = &region.front();
3772     // Force printing the block header if printEmptyBlock is set and the block
3773     // is empty or if printEntryBlockArgs is set and there are arguments to
3774     // print.
3775     bool shouldAlwaysPrintBlockHeader =
3776         (printEmptyBlock && entryBlock->empty()) ||
3777         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3778     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
3779     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
3780       print(&b);
3781   }
3782   os.indent(currentIndent) << "}";
3783 }
3784 
3785 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3786                                               ValueRange operands) {
3787   if (!mapAttr) {
3788     os << "<<NULL AFFINE MAP>>";
3789     return;
3790   }
3791   AffineMap map = mapAttr.getValue();
3792   unsigned numDims = map.getNumDims();
3793   auto printValueName = [&](unsigned pos, bool isSymbol) {
3794     unsigned index = isSymbol ? numDims + pos : pos;
3795     assert(index < operands.size());
3796     if (isSymbol)
3797       os << "symbol(";
3798     printValueID(operands[index]);
3799     if (isSymbol)
3800       os << ')';
3801   };
3802 
3803   interleaveComma(map.getResults(), [&](AffineExpr expr) {
3804     printAffineExpr(expr, printValueName);
3805   });
3806 }
3807 
3808 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3809                                                ValueRange dimOperands,
3810                                                ValueRange symOperands) {
3811   auto printValueName = [&](unsigned pos, bool isSymbol) {
3812     if (!isSymbol)
3813       return printValueID(dimOperands[pos]);
3814     os << "symbol(";
3815     printValueID(symOperands[pos]);
3816     os << ')';
3817   };
3818   printAffineExpr(expr, printValueName);
3819 }
3820 
3821 //===----------------------------------------------------------------------===//
3822 // print and dump methods
3823 //===----------------------------------------------------------------------===//
3824 
3825 void Attribute::print(raw_ostream &os, bool elideType) const {
3826   if (!*this) {
3827     os << "<<NULL ATTRIBUTE>>";
3828     return;
3829   }
3830 
3831   AsmState state(getContext());
3832   print(os, state, elideType);
3833 }
3834 void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
3835   using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision;
3836   AsmPrinter::Impl(os, state.getImpl())
3837       .printAttribute(*this, elideType ? AttrTypeElision::Must
3838                                        : AttrTypeElision::Never);
3839 }
3840 
3841 void Attribute::dump() const {
3842   print(llvm::errs());
3843   llvm::errs() << "\n";
3844 }
3845 
3846 void Attribute::printStripped(raw_ostream &os, AsmState &state) const {
3847   if (!*this) {
3848     os << "<<NULL ATTRIBUTE>>";
3849     return;
3850   }
3851 
3852   AsmPrinter::Impl subPrinter(os, state.getImpl());
3853   if (succeeded(subPrinter.printAlias(*this)))
3854     return;
3855 
3856   auto &dialect = this->getDialect();
3857   uint64_t posPrior = os.tell();
3858   DialectAsmPrinter printer(subPrinter);
3859   dialect.printAttribute(*this, printer);
3860   if (posPrior != os.tell())
3861     return;
3862 
3863   // Fallback to printing with prefix if the above failed to write anything
3864   // to the output stream.
3865   print(os, state);
3866 }
3867 void Attribute::printStripped(raw_ostream &os) const {
3868   if (!*this) {
3869     os << "<<NULL ATTRIBUTE>>";
3870     return;
3871   }
3872 
3873   AsmState state(getContext());
3874   printStripped(os, state);
3875 }
3876 
3877 void Type::print(raw_ostream &os) const {
3878   if (!*this) {
3879     os << "<<NULL TYPE>>";
3880     return;
3881   }
3882 
3883   AsmState state(getContext());
3884   print(os, state);
3885 }
3886 void Type::print(raw_ostream &os, AsmState &state) const {
3887   AsmPrinter::Impl(os, state.getImpl()).printType(*this);
3888 }
3889 
3890 void Type::dump() const {
3891   print(llvm::errs());
3892   llvm::errs() << "\n";
3893 }
3894 
3895 void AffineMap::dump() const {
3896   print(llvm::errs());
3897   llvm::errs() << "\n";
3898 }
3899 
3900 void IntegerSet::dump() const {
3901   print(llvm::errs());
3902   llvm::errs() << "\n";
3903 }
3904 
3905 void AffineExpr::print(raw_ostream &os) const {
3906   if (!expr) {
3907     os << "<<NULL AFFINE EXPR>>";
3908     return;
3909   }
3910   AsmState state(getContext());
3911   AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this);
3912 }
3913 
3914 void AffineExpr::dump() const {
3915   print(llvm::errs());
3916   llvm::errs() << "\n";
3917 }
3918 
3919 void AffineMap::print(raw_ostream &os) const {
3920   if (!map) {
3921     os << "<<NULL AFFINE MAP>>";
3922     return;
3923   }
3924   AsmState state(getContext());
3925   AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this);
3926 }
3927 
3928 void IntegerSet::print(raw_ostream &os) const {
3929   AsmState state(getContext());
3930   AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
3931 }
3932 
3933 void Value::print(raw_ostream &os) const { print(os, OpPrintingFlags()); }
3934 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const {
3935   if (!impl) {
3936     os << "<<NULL VALUE>>";
3937     return;
3938   }
3939 
3940   if (auto *op = getDefiningOp())
3941     return op->print(os, flags);
3942   // TODO: Improve BlockArgument print'ing.
3943   BlockArgument arg = llvm::cast<BlockArgument>(*this);
3944   os << "<block argument> of type '" << arg.getType()
3945      << "' at index: " << arg.getArgNumber();
3946 }
3947 void Value::print(raw_ostream &os, AsmState &state) const {
3948   if (!impl) {
3949     os << "<<NULL VALUE>>";
3950     return;
3951   }
3952 
3953   if (auto *op = getDefiningOp())
3954     return op->print(os, state);
3955 
3956   // TODO: Improve BlockArgument print'ing.
3957   BlockArgument arg = llvm::cast<BlockArgument>(*this);
3958   os << "<block argument> of type '" << arg.getType()
3959      << "' at index: " << arg.getArgNumber();
3960 }
3961 
3962 void Value::dump() const {
3963   print(llvm::errs());
3964   llvm::errs() << "\n";
3965 }
3966 
3967 void Value::printAsOperand(raw_ostream &os, AsmState &state) const {
3968   // TODO: This doesn't necessarily capture all potential cases.
3969   // Currently, region arguments can be shadowed when printing the main
3970   // operation. If the IR hasn't been printed, this will produce the old SSA
3971   // name and not the shadowed name.
3972   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
3973                                                  os);
3974 }
3975 
3976 static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
3977   do {
3978     // If we are printing local scope, stop at the first operation that is
3979     // isolated from above.
3980     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3981       break;
3982 
3983     // Otherwise, traverse up to the next parent.
3984     Operation *parentOp = op->getParentOp();
3985     if (!parentOp)
3986       break;
3987     op = parentOp;
3988   } while (true);
3989   return op;
3990 }
3991 
3992 void Value::printAsOperand(raw_ostream &os,
3993                            const OpPrintingFlags &flags) const {
3994   Operation *op;
3995   if (auto result = llvm::dyn_cast<OpResult>(*this)) {
3996     op = result.getOwner();
3997   } else {
3998     op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
3999     if (!op) {
4000       os << "<<UNKNOWN SSA VALUE>>";
4001       return;
4002     }
4003   }
4004   op = findParent(op, flags.shouldUseLocalScope());
4005   AsmState state(op, flags);
4006   printAsOperand(os, state);
4007 }
4008 
4009 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
4010   // Find the operation to number from based upon the provided flags.
4011   Operation *op = findParent(this, printerFlags.shouldUseLocalScope());
4012   AsmState state(op, printerFlags);
4013   print(os, state);
4014 }
4015 void Operation::print(raw_ostream &os, AsmState &state) {
4016   OperationPrinter printer(os, state.getImpl());
4017   if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
4018     state.getImpl().initializeAliases(this);
4019     printer.printTopLevelOperation(this);
4020   } else {
4021     printer.printFullOpWithIndentAndLoc(this);
4022   }
4023 }
4024 
4025 void Operation::dump() {
4026   print(llvm::errs(), OpPrintingFlags().useLocalScope());
4027   llvm::errs() << "\n";
4028 }
4029 
4030 void Operation::dumpPretty() {
4031   print(llvm::errs(), OpPrintingFlags().useLocalScope().assumeVerified());
4032   llvm::errs() << "\n";
4033 }
4034 
4035 void Block::print(raw_ostream &os) {
4036   Operation *parentOp = getParentOp();
4037   if (!parentOp) {
4038     os << "<<UNLINKED BLOCK>>\n";
4039     return;
4040   }
4041   // Get the top-level op.
4042   while (auto *nextOp = parentOp->getParentOp())
4043     parentOp = nextOp;
4044 
4045   AsmState state(parentOp);
4046   print(os, state);
4047 }
4048 void Block::print(raw_ostream &os, AsmState &state) {
4049   OperationPrinter(os, state.getImpl()).print(this);
4050 }
4051 
4052 void Block::dump() { print(llvm::errs()); }
4053 
4054 /// Print out the name of the block without printing its body.
4055 void Block::printAsOperand(raw_ostream &os, bool printType) {
4056   Operation *parentOp = getParentOp();
4057   if (!parentOp) {
4058     os << "<<UNLINKED BLOCK>>\n";
4059     return;
4060   }
4061   AsmState state(parentOp);
4062   printAsOperand(os, state);
4063 }
4064 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
4065   OperationPrinter printer(os, state.getImpl());
4066   printer.printBlockName(this);
4067 }
4068 
4069 raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) {
4070   block.print(os);
4071   return os;
4072 }
4073 
4074 //===--------------------------------------------------------------------===//
4075 // Custom printers
4076 //===--------------------------------------------------------------------===//
4077 namespace mlir {
4078 
4079 void printDimensionList(OpAsmPrinter &printer, Operation *op,
4080                         ArrayRef<int64_t> dimensions) {
4081   if (dimensions.empty())
4082     printer << "[";
4083   printer.printDimensionList(dimensions);
4084   if (dimensions.empty())
4085     printer << "]";
4086 }
4087 
4088 ParseResult parseDimensionList(OpAsmParser &parser,
4089                                DenseI64ArrayAttr &dimensions) {
4090   // Empty list case denoted by "[]".
4091   if (succeeded(parser.parseOptionalLSquare())) {
4092     if (failed(parser.parseRSquare())) {
4093       return parser.emitError(parser.getCurrentLocation())
4094              << "Failed parsing dimension list.";
4095     }
4096     dimensions =
4097         DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
4098     return success();
4099   }
4100 
4101   // Non-empty list case.
4102   SmallVector<int64_t> shapeArr;
4103   if (failed(parser.parseDimensionList(shapeArr, true, false))) {
4104     return parser.emitError(parser.getCurrentLocation())
4105            << "Failed parsing dimension list.";
4106   }
4107   if (shapeArr.empty()) {
4108     return parser.emitError(parser.getCurrentLocation())
4109            << "Failed parsing dimension list. Did you mean an empty list? It "
4110               "must be denoted by \"[]\".";
4111   }
4112   dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
4113   return success();
4114 }
4115 
4116 } // namespace mlir
4117