xref: /llvm-project/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp (revision 930916c7f3622870b40138dafcc5f94740404e8c)
1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace mlir::pdll::ast;
19 
20 //===----------------------------------------------------------------------===//
21 // NodePrinter
22 //===----------------------------------------------------------------------===//
23 
24 namespace {
25 class NodePrinter {
26 public:
NodePrinter(raw_ostream & os)27   NodePrinter(raw_ostream &os) : os(os) {}
28 
29   /// Print the given type to the stream.
30   void print(Type type);
31 
32   /// Print the given node to the stream.
33   void print(const Node *node);
34 
35 private:
36   /// Print a range containing children of a node.
37   template <typename RangeT,
38             std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
39                 * = nullptr>
printChildren(RangeT && range)40   void printChildren(RangeT &&range) {
41     if (range.empty())
42       return;
43 
44     // Print the first N-1 elements with a prefix of "|-".
45     auto it = std::begin(range);
46     for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
47       print(*it);
48 
49     // Print the last element.
50     elementIndentStack.back() = true;
51     print(*it);
52   }
53   template <typename RangeT, typename... OthersT,
54             std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
55                 * = nullptr>
printChildren(RangeT && range,OthersT &&...others)56   void printChildren(RangeT &&range, OthersT &&...others) {
57     printChildren(ArrayRef<const Node *>({range, others...}));
58   }
59   /// Print a range containing children of a node, nesting the children under
60   /// the given label.
61   template <typename RangeT>
printChildren(StringRef label,RangeT && range)62   void printChildren(StringRef label, RangeT &&range) {
63     if (range.empty())
64       return;
65     elementIndentStack.reserve(elementIndentStack.size() + 1);
66     llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
67 
68     printIndent();
69     os << label << "`\n";
70     elementIndentStack.push_back(/*isLastElt*/ false);
71     printChildren(std::forward<RangeT>(range));
72     elementIndentStack.pop_back();
73   }
74 
75   /// Print the given derived node to the stream.
76   void printImpl(const CompoundStmt *stmt);
77   void printImpl(const EraseStmt *stmt);
78   void printImpl(const LetStmt *stmt);
79   void printImpl(const ReplaceStmt *stmt);
80   void printImpl(const ReturnStmt *stmt);
81   void printImpl(const RewriteStmt *stmt);
82 
83   void printImpl(const AttributeExpr *expr);
84   void printImpl(const CallExpr *expr);
85   void printImpl(const DeclRefExpr *expr);
86   void printImpl(const MemberAccessExpr *expr);
87   void printImpl(const OperationExpr *expr);
88   void printImpl(const RangeExpr *expr);
89   void printImpl(const TupleExpr *expr);
90   void printImpl(const TypeExpr *expr);
91 
92   void printImpl(const AttrConstraintDecl *decl);
93   void printImpl(const OpConstraintDecl *decl);
94   void printImpl(const TypeConstraintDecl *decl);
95   void printImpl(const TypeRangeConstraintDecl *decl);
96   void printImpl(const UserConstraintDecl *decl);
97   void printImpl(const ValueConstraintDecl *decl);
98   void printImpl(const ValueRangeConstraintDecl *decl);
99   void printImpl(const NamedAttributeDecl *decl);
100   void printImpl(const OpNameDecl *decl);
101   void printImpl(const PatternDecl *decl);
102   void printImpl(const UserRewriteDecl *decl);
103   void printImpl(const VariableDecl *decl);
104   void printImpl(const Module *module);
105 
106   /// Print the current indent stack.
printIndent()107   void printIndent() {
108     if (elementIndentStack.empty())
109       return;
110 
111     for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
112       os << (isLastElt ? "  " : " |");
113     os << (elementIndentStack.back() ? " `" : " |");
114   }
115 
116   /// The raw output stream.
117   raw_ostream &os;
118 
119   /// A stack of indents and a flag indicating if the current element being
120   /// printed at that indent is the last element.
121   SmallVector<bool> elementIndentStack;
122 };
123 } // namespace
124 
print(Type type)125 void NodePrinter::print(Type type) {
126   // Protect against invalid inputs.
127   if (!type) {
128     os << "Type<NULL>";
129     return;
130   }
131 
132   TypeSwitch<Type>(type)
133       .Case([&](AttributeType) { os << "Attr"; })
134       .Case([&](ConstraintType) { os << "Constraint"; })
135       .Case([&](OperationType type) {
136         os << "Op";
137         if (std::optional<StringRef> name = type.getName())
138           os << "<" << *name << ">";
139       })
140       .Case([&](RangeType type) {
141         print(type.getElementType());
142         os << "Range";
143       })
144       .Case([&](RewriteType) { os << "Rewrite"; })
145       .Case([&](TupleType type) {
146         os << "Tuple<";
147         llvm::interleaveComma(
148             llvm::zip(type.getElementNames(), type.getElementTypes()), os,
149             [&](auto it) {
150               if (!std::get<0>(it).empty())
151                 os << std::get<0>(it) << ": ";
152               this->print(std::get<1>(it));
153             });
154         os << ">";
155       })
156       .Case([&](TypeType) { os << "Type"; })
157       .Case([&](ValueType) { os << "Value"; })
158       .Default([](Type) { llvm_unreachable("unknown AST type"); });
159 }
160 
print(const Node * node)161 void NodePrinter::print(const Node *node) {
162   printIndent();
163   os << "-";
164 
165   elementIndentStack.push_back(/*isLastElt*/ false);
166   TypeSwitch<const Node *>(node)
167       .Case<
168           // Statements.
169           const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
170           const ReturnStmt, const RewriteStmt,
171 
172           // Expressions.
173           const AttributeExpr, const CallExpr, const DeclRefExpr,
174           const MemberAccessExpr, const OperationExpr, const RangeExpr,
175           const TupleExpr, const TypeExpr,
176 
177           // Decls.
178           const AttrConstraintDecl, const OpConstraintDecl,
179           const TypeConstraintDecl, const TypeRangeConstraintDecl,
180           const UserConstraintDecl, const ValueConstraintDecl,
181           const ValueRangeConstraintDecl, const NamedAttributeDecl,
182           const OpNameDecl, const PatternDecl, const UserRewriteDecl,
183           const VariableDecl,
184 
185           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
186       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
187   elementIndentStack.pop_back();
188 }
189 
printImpl(const CompoundStmt * stmt)190 void NodePrinter::printImpl(const CompoundStmt *stmt) {
191   os << "CompoundStmt " << stmt << "\n";
192   printChildren(stmt->getChildren());
193 }
194 
printImpl(const EraseStmt * stmt)195 void NodePrinter::printImpl(const EraseStmt *stmt) {
196   os << "EraseStmt " << stmt << "\n";
197   printChildren(stmt->getRootOpExpr());
198 }
199 
printImpl(const LetStmt * stmt)200 void NodePrinter::printImpl(const LetStmt *stmt) {
201   os << "LetStmt " << stmt << "\n";
202   printChildren(stmt->getVarDecl());
203 }
204 
printImpl(const ReplaceStmt * stmt)205 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
206   os << "ReplaceStmt " << stmt << "\n";
207   printChildren(stmt->getRootOpExpr());
208   printChildren("ReplValues", stmt->getReplExprs());
209 }
210 
printImpl(const ReturnStmt * stmt)211 void NodePrinter::printImpl(const ReturnStmt *stmt) {
212   os << "ReturnStmt " << stmt << "\n";
213   printChildren(stmt->getResultExpr());
214 }
215 
printImpl(const RewriteStmt * stmt)216 void NodePrinter::printImpl(const RewriteStmt *stmt) {
217   os << "RewriteStmt " << stmt << "\n";
218   printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
219 }
220 
printImpl(const AttributeExpr * expr)221 void NodePrinter::printImpl(const AttributeExpr *expr) {
222   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
223 }
224 
printImpl(const CallExpr * expr)225 void NodePrinter::printImpl(const CallExpr *expr) {
226   os << "CallExpr " << expr << " Type<";
227   print(expr->getType());
228   os << ">";
229   if (expr->getIsNegated())
230     os << " Negated";
231   os << "\n";
232   printChildren(expr->getCallableExpr());
233   printChildren("Arguments", expr->getArguments());
234 }
235 
printImpl(const DeclRefExpr * expr)236 void NodePrinter::printImpl(const DeclRefExpr *expr) {
237   os << "DeclRefExpr " << expr << " Type<";
238   print(expr->getType());
239   os << ">\n";
240   printChildren(expr->getDecl());
241 }
242 
printImpl(const MemberAccessExpr * expr)243 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
244   os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
245      << "> Type<";
246   print(expr->getType());
247   os << ">\n";
248   printChildren(expr->getParentExpr());
249 }
250 
printImpl(const OperationExpr * expr)251 void NodePrinter::printImpl(const OperationExpr *expr) {
252   os << "OperationExpr " << expr << " Type<";
253   print(expr->getType());
254   os << ">\n";
255 
256   printChildren(expr->getNameDecl());
257   printChildren("Operands", expr->getOperands());
258   printChildren("Result Types", expr->getResultTypes());
259   printChildren("Attributes", expr->getAttributes());
260 }
261 
printImpl(const RangeExpr * expr)262 void NodePrinter::printImpl(const RangeExpr *expr) {
263   os << "RangeExpr " << expr << " Type<";
264   print(expr->getType());
265   os << ">\n";
266 
267   printChildren(expr->getElements());
268 }
269 
printImpl(const TupleExpr * expr)270 void NodePrinter::printImpl(const TupleExpr *expr) {
271   os << "TupleExpr " << expr << " Type<";
272   print(expr->getType());
273   os << ">\n";
274 
275   printChildren(expr->getElements());
276 }
277 
printImpl(const TypeExpr * expr)278 void NodePrinter::printImpl(const TypeExpr *expr) {
279   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
280 }
281 
printImpl(const AttrConstraintDecl * decl)282 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
283   os << "AttrConstraintDecl " << decl << "\n";
284   if (const auto *typeExpr = decl->getTypeExpr())
285     printChildren(typeExpr);
286 }
287 
printImpl(const OpConstraintDecl * decl)288 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
289   os << "OpConstraintDecl " << decl << "\n";
290   printChildren(decl->getNameDecl());
291 }
292 
printImpl(const TypeConstraintDecl * decl)293 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
294   os << "TypeConstraintDecl " << decl << "\n";
295 }
296 
printImpl(const TypeRangeConstraintDecl * decl)297 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
298   os << "TypeRangeConstraintDecl " << decl << "\n";
299 }
300 
printImpl(const UserConstraintDecl * decl)301 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
302   os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
303      << "> ResultType<" << decl->getResultType() << ">";
304   if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
305     os << " Code<";
306     llvm::printEscapedString(*codeBlock, os);
307     os << ">";
308   }
309   os << "\n";
310   printChildren("Inputs", decl->getInputs());
311   printChildren("Results", decl->getResults());
312   if (const CompoundStmt *body = decl->getBody())
313     printChildren(body);
314 }
315 
printImpl(const ValueConstraintDecl * decl)316 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
317   os << "ValueConstraintDecl " << decl << "\n";
318   if (const auto *typeExpr = decl->getTypeExpr())
319     printChildren(typeExpr);
320 }
321 
printImpl(const ValueRangeConstraintDecl * decl)322 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
323   os << "ValueRangeConstraintDecl " << decl << "\n";
324   if (const auto *typeExpr = decl->getTypeExpr())
325     printChildren(typeExpr);
326 }
327 
printImpl(const NamedAttributeDecl * decl)328 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
329   os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
330      << ">\n";
331   printChildren(decl->getValue());
332 }
333 
printImpl(const OpNameDecl * decl)334 void NodePrinter::printImpl(const OpNameDecl *decl) {
335   os << "OpNameDecl " << decl;
336   if (std::optional<StringRef> name = decl->getName())
337     os << " Name<" << *name << ">";
338   os << "\n";
339 }
340 
printImpl(const PatternDecl * decl)341 void NodePrinter::printImpl(const PatternDecl *decl) {
342   os << "PatternDecl " << decl;
343   if (const Name *name = decl->getName())
344     os << " Name<" << name->getName() << ">";
345   if (std::optional<uint16_t> benefit = decl->getBenefit())
346     os << " Benefit<" << *benefit << ">";
347   if (decl->hasBoundedRewriteRecursion())
348     os << " Recursion";
349 
350   os << "\n";
351   printChildren(decl->getBody());
352 }
353 
printImpl(const UserRewriteDecl * decl)354 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
355   os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
356      << "> ResultType<" << decl->getResultType() << ">";
357   if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
358     os << " Code<";
359     llvm::printEscapedString(*codeBlock, os);
360     os << ">";
361   }
362   os << "\n";
363   printChildren("Inputs", decl->getInputs());
364   printChildren("Results", decl->getResults());
365   if (const CompoundStmt *body = decl->getBody())
366     printChildren(body);
367 }
368 
printImpl(const VariableDecl * decl)369 void NodePrinter::printImpl(const VariableDecl *decl) {
370   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
371      << "> Type<";
372   print(decl->getType());
373   os << ">\n";
374   if (Expr *initExpr = decl->getInitExpr())
375     printChildren(initExpr);
376 
377   auto constraints =
378       llvm::map_range(decl->getConstraints(),
379                       [](const ConstraintRef &ref) { return ref.constraint; });
380   printChildren("Constraints", constraints);
381 }
382 
printImpl(const Module * module)383 void NodePrinter::printImpl(const Module *module) {
384   os << "Module " << module << "\n";
385   printChildren(module->getChildren());
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // Entry point
390 //===----------------------------------------------------------------------===//
391 
print(raw_ostream & os) const392 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
393 
print(raw_ostream & os) const394 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
395