1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 #include <optional>
16
17 using namespace mlir;
18 using namespace mlir::pdll::ast;
19
20 //===----------------------------------------------------------------------===//
21 // NodePrinter
22 //===----------------------------------------------------------------------===//
23
24 namespace {
25 class NodePrinter {
26 public:
NodePrinter(raw_ostream & os)27 NodePrinter(raw_ostream &os) : os(os) {}
28
29 /// Print the given type to the stream.
30 void print(Type type);
31
32 /// Print the given node to the stream.
33 void print(const Node *node);
34
35 private:
36 /// Print a range containing children of a node.
37 template <typename RangeT,
38 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
39 * = nullptr>
printChildren(RangeT && range)40 void printChildren(RangeT &&range) {
41 if (range.empty())
42 return;
43
44 // Print the first N-1 elements with a prefix of "|-".
45 auto it = std::begin(range);
46 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
47 print(*it);
48
49 // Print the last element.
50 elementIndentStack.back() = true;
51 print(*it);
52 }
53 template <typename RangeT, typename... OthersT,
54 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
55 * = nullptr>
printChildren(RangeT && range,OthersT &&...others)56 void printChildren(RangeT &&range, OthersT &&...others) {
57 printChildren(ArrayRef<const Node *>({range, others...}));
58 }
59 /// Print a range containing children of a node, nesting the children under
60 /// the given label.
61 template <typename RangeT>
printChildren(StringRef label,RangeT && range)62 void printChildren(StringRef label, RangeT &&range) {
63 if (range.empty())
64 return;
65 elementIndentStack.reserve(elementIndentStack.size() + 1);
66 llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
67
68 printIndent();
69 os << label << "`\n";
70 elementIndentStack.push_back(/*isLastElt*/ false);
71 printChildren(std::forward<RangeT>(range));
72 elementIndentStack.pop_back();
73 }
74
75 /// Print the given derived node to the stream.
76 void printImpl(const CompoundStmt *stmt);
77 void printImpl(const EraseStmt *stmt);
78 void printImpl(const LetStmt *stmt);
79 void printImpl(const ReplaceStmt *stmt);
80 void printImpl(const ReturnStmt *stmt);
81 void printImpl(const RewriteStmt *stmt);
82
83 void printImpl(const AttributeExpr *expr);
84 void printImpl(const CallExpr *expr);
85 void printImpl(const DeclRefExpr *expr);
86 void printImpl(const MemberAccessExpr *expr);
87 void printImpl(const OperationExpr *expr);
88 void printImpl(const RangeExpr *expr);
89 void printImpl(const TupleExpr *expr);
90 void printImpl(const TypeExpr *expr);
91
92 void printImpl(const AttrConstraintDecl *decl);
93 void printImpl(const OpConstraintDecl *decl);
94 void printImpl(const TypeConstraintDecl *decl);
95 void printImpl(const TypeRangeConstraintDecl *decl);
96 void printImpl(const UserConstraintDecl *decl);
97 void printImpl(const ValueConstraintDecl *decl);
98 void printImpl(const ValueRangeConstraintDecl *decl);
99 void printImpl(const NamedAttributeDecl *decl);
100 void printImpl(const OpNameDecl *decl);
101 void printImpl(const PatternDecl *decl);
102 void printImpl(const UserRewriteDecl *decl);
103 void printImpl(const VariableDecl *decl);
104 void printImpl(const Module *module);
105
106 /// Print the current indent stack.
printIndent()107 void printIndent() {
108 if (elementIndentStack.empty())
109 return;
110
111 for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
112 os << (isLastElt ? " " : " |");
113 os << (elementIndentStack.back() ? " `" : " |");
114 }
115
116 /// The raw output stream.
117 raw_ostream &os;
118
119 /// A stack of indents and a flag indicating if the current element being
120 /// printed at that indent is the last element.
121 SmallVector<bool> elementIndentStack;
122 };
123 } // namespace
124
print(Type type)125 void NodePrinter::print(Type type) {
126 // Protect against invalid inputs.
127 if (!type) {
128 os << "Type<NULL>";
129 return;
130 }
131
132 TypeSwitch<Type>(type)
133 .Case([&](AttributeType) { os << "Attr"; })
134 .Case([&](ConstraintType) { os << "Constraint"; })
135 .Case([&](OperationType type) {
136 os << "Op";
137 if (std::optional<StringRef> name = type.getName())
138 os << "<" << *name << ">";
139 })
140 .Case([&](RangeType type) {
141 print(type.getElementType());
142 os << "Range";
143 })
144 .Case([&](RewriteType) { os << "Rewrite"; })
145 .Case([&](TupleType type) {
146 os << "Tuple<";
147 llvm::interleaveComma(
148 llvm::zip(type.getElementNames(), type.getElementTypes()), os,
149 [&](auto it) {
150 if (!std::get<0>(it).empty())
151 os << std::get<0>(it) << ": ";
152 this->print(std::get<1>(it));
153 });
154 os << ">";
155 })
156 .Case([&](TypeType) { os << "Type"; })
157 .Case([&](ValueType) { os << "Value"; })
158 .Default([](Type) { llvm_unreachable("unknown AST type"); });
159 }
160
print(const Node * node)161 void NodePrinter::print(const Node *node) {
162 printIndent();
163 os << "-";
164
165 elementIndentStack.push_back(/*isLastElt*/ false);
166 TypeSwitch<const Node *>(node)
167 .Case<
168 // Statements.
169 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
170 const ReturnStmt, const RewriteStmt,
171
172 // Expressions.
173 const AttributeExpr, const CallExpr, const DeclRefExpr,
174 const MemberAccessExpr, const OperationExpr, const RangeExpr,
175 const TupleExpr, const TypeExpr,
176
177 // Decls.
178 const AttrConstraintDecl, const OpConstraintDecl,
179 const TypeConstraintDecl, const TypeRangeConstraintDecl,
180 const UserConstraintDecl, const ValueConstraintDecl,
181 const ValueRangeConstraintDecl, const NamedAttributeDecl,
182 const OpNameDecl, const PatternDecl, const UserRewriteDecl,
183 const VariableDecl,
184
185 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
186 .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
187 elementIndentStack.pop_back();
188 }
189
printImpl(const CompoundStmt * stmt)190 void NodePrinter::printImpl(const CompoundStmt *stmt) {
191 os << "CompoundStmt " << stmt << "\n";
192 printChildren(stmt->getChildren());
193 }
194
printImpl(const EraseStmt * stmt)195 void NodePrinter::printImpl(const EraseStmt *stmt) {
196 os << "EraseStmt " << stmt << "\n";
197 printChildren(stmt->getRootOpExpr());
198 }
199
printImpl(const LetStmt * stmt)200 void NodePrinter::printImpl(const LetStmt *stmt) {
201 os << "LetStmt " << stmt << "\n";
202 printChildren(stmt->getVarDecl());
203 }
204
printImpl(const ReplaceStmt * stmt)205 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
206 os << "ReplaceStmt " << stmt << "\n";
207 printChildren(stmt->getRootOpExpr());
208 printChildren("ReplValues", stmt->getReplExprs());
209 }
210
printImpl(const ReturnStmt * stmt)211 void NodePrinter::printImpl(const ReturnStmt *stmt) {
212 os << "ReturnStmt " << stmt << "\n";
213 printChildren(stmt->getResultExpr());
214 }
215
printImpl(const RewriteStmt * stmt)216 void NodePrinter::printImpl(const RewriteStmt *stmt) {
217 os << "RewriteStmt " << stmt << "\n";
218 printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
219 }
220
printImpl(const AttributeExpr * expr)221 void NodePrinter::printImpl(const AttributeExpr *expr) {
222 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
223 }
224
printImpl(const CallExpr * expr)225 void NodePrinter::printImpl(const CallExpr *expr) {
226 os << "CallExpr " << expr << " Type<";
227 print(expr->getType());
228 os << ">";
229 if (expr->getIsNegated())
230 os << " Negated";
231 os << "\n";
232 printChildren(expr->getCallableExpr());
233 printChildren("Arguments", expr->getArguments());
234 }
235
printImpl(const DeclRefExpr * expr)236 void NodePrinter::printImpl(const DeclRefExpr *expr) {
237 os << "DeclRefExpr " << expr << " Type<";
238 print(expr->getType());
239 os << ">\n";
240 printChildren(expr->getDecl());
241 }
242
printImpl(const MemberAccessExpr * expr)243 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
244 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
245 << "> Type<";
246 print(expr->getType());
247 os << ">\n";
248 printChildren(expr->getParentExpr());
249 }
250
printImpl(const OperationExpr * expr)251 void NodePrinter::printImpl(const OperationExpr *expr) {
252 os << "OperationExpr " << expr << " Type<";
253 print(expr->getType());
254 os << ">\n";
255
256 printChildren(expr->getNameDecl());
257 printChildren("Operands", expr->getOperands());
258 printChildren("Result Types", expr->getResultTypes());
259 printChildren("Attributes", expr->getAttributes());
260 }
261
printImpl(const RangeExpr * expr)262 void NodePrinter::printImpl(const RangeExpr *expr) {
263 os << "RangeExpr " << expr << " Type<";
264 print(expr->getType());
265 os << ">\n";
266
267 printChildren(expr->getElements());
268 }
269
printImpl(const TupleExpr * expr)270 void NodePrinter::printImpl(const TupleExpr *expr) {
271 os << "TupleExpr " << expr << " Type<";
272 print(expr->getType());
273 os << ">\n";
274
275 printChildren(expr->getElements());
276 }
277
printImpl(const TypeExpr * expr)278 void NodePrinter::printImpl(const TypeExpr *expr) {
279 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
280 }
281
printImpl(const AttrConstraintDecl * decl)282 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
283 os << "AttrConstraintDecl " << decl << "\n";
284 if (const auto *typeExpr = decl->getTypeExpr())
285 printChildren(typeExpr);
286 }
287
printImpl(const OpConstraintDecl * decl)288 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
289 os << "OpConstraintDecl " << decl << "\n";
290 printChildren(decl->getNameDecl());
291 }
292
printImpl(const TypeConstraintDecl * decl)293 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
294 os << "TypeConstraintDecl " << decl << "\n";
295 }
296
printImpl(const TypeRangeConstraintDecl * decl)297 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
298 os << "TypeRangeConstraintDecl " << decl << "\n";
299 }
300
printImpl(const UserConstraintDecl * decl)301 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
302 os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
303 << "> ResultType<" << decl->getResultType() << ">";
304 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
305 os << " Code<";
306 llvm::printEscapedString(*codeBlock, os);
307 os << ">";
308 }
309 os << "\n";
310 printChildren("Inputs", decl->getInputs());
311 printChildren("Results", decl->getResults());
312 if (const CompoundStmt *body = decl->getBody())
313 printChildren(body);
314 }
315
printImpl(const ValueConstraintDecl * decl)316 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
317 os << "ValueConstraintDecl " << decl << "\n";
318 if (const auto *typeExpr = decl->getTypeExpr())
319 printChildren(typeExpr);
320 }
321
printImpl(const ValueRangeConstraintDecl * decl)322 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
323 os << "ValueRangeConstraintDecl " << decl << "\n";
324 if (const auto *typeExpr = decl->getTypeExpr())
325 printChildren(typeExpr);
326 }
327
printImpl(const NamedAttributeDecl * decl)328 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
329 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
330 << ">\n";
331 printChildren(decl->getValue());
332 }
333
printImpl(const OpNameDecl * decl)334 void NodePrinter::printImpl(const OpNameDecl *decl) {
335 os << "OpNameDecl " << decl;
336 if (std::optional<StringRef> name = decl->getName())
337 os << " Name<" << *name << ">";
338 os << "\n";
339 }
340
printImpl(const PatternDecl * decl)341 void NodePrinter::printImpl(const PatternDecl *decl) {
342 os << "PatternDecl " << decl;
343 if (const Name *name = decl->getName())
344 os << " Name<" << name->getName() << ">";
345 if (std::optional<uint16_t> benefit = decl->getBenefit())
346 os << " Benefit<" << *benefit << ">";
347 if (decl->hasBoundedRewriteRecursion())
348 os << " Recursion";
349
350 os << "\n";
351 printChildren(decl->getBody());
352 }
353
printImpl(const UserRewriteDecl * decl)354 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
355 os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
356 << "> ResultType<" << decl->getResultType() << ">";
357 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
358 os << " Code<";
359 llvm::printEscapedString(*codeBlock, os);
360 os << ">";
361 }
362 os << "\n";
363 printChildren("Inputs", decl->getInputs());
364 printChildren("Results", decl->getResults());
365 if (const CompoundStmt *body = decl->getBody())
366 printChildren(body);
367 }
368
printImpl(const VariableDecl * decl)369 void NodePrinter::printImpl(const VariableDecl *decl) {
370 os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
371 << "> Type<";
372 print(decl->getType());
373 os << ">\n";
374 if (Expr *initExpr = decl->getInitExpr())
375 printChildren(initExpr);
376
377 auto constraints =
378 llvm::map_range(decl->getConstraints(),
379 [](const ConstraintRef &ref) { return ref.constraint; });
380 printChildren("Constraints", constraints);
381 }
382
printImpl(const Module * module)383 void NodePrinter::printImpl(const Module *module) {
384 os << "Module " << module << "\n";
385 printChildren(module->getChildren());
386 }
387
388 //===----------------------------------------------------------------------===//
389 // Entry point
390 //===----------------------------------------------------------------------===//
391
print(raw_ostream & os) const392 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
393
print(raw_ostream & os) const394 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
395