xref: /llvm-project/mlir/examples/toy/Ch7/parser/AST.cpp (revision ec6da0652282d29569faa628d2180909fa588906)
1 //===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST dump for the Toy language.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "toy/AST.h"
14 
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/Casting.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include <string>
21 
22 using namespace toy;
23 
24 namespace {
25 
26 // RAII helper to manage increasing/decreasing the indentation as we traverse
27 // the AST
28 struct Indent {
Indent__anon3b6f7b840111::Indent29   Indent(int &level) : level(level) { ++level; }
~Indent__anon3b6f7b840111::Indent30   ~Indent() { --level; }
31   int &level;
32 };
33 
34 /// Helper class that implement the AST tree traversal and print the nodes along
35 /// the way. The only data member is the current indentation level.
36 class ASTDumper {
37 public:
38   void dump(ModuleAST *node);
39 
40 private:
41   void dump(const VarType &type);
42   void dump(VarDeclExprAST *varDecl);
43   void dump(ExprAST *expr);
44   void dump(ExprASTList *exprList);
45   void dump(NumberExprAST *num);
46   void dump(LiteralExprAST *node);
47   void dump(StructLiteralExprAST *node);
48   void dump(VariableExprAST *node);
49   void dump(ReturnExprAST *node);
50   void dump(BinaryExprAST *node);
51   void dump(CallExprAST *node);
52   void dump(PrintExprAST *node);
53   void dump(PrototypeAST *node);
54   void dump(FunctionAST *node);
55   void dump(StructAST *node);
56 
57   // Actually print spaces matching the current indentation level
indent()58   void indent() {
59     for (int i = 0; i < curIndent; i++)
60       llvm::errs() << "  ";
61   }
62   int curIndent = 0;
63 };
64 
65 } // namespace
66 
/// Format the source location of an AST node as "@<file>:<line>:<col>".
/// \p node may be any type whose `loc()` yields an object with a `file`
/// member (dereferenced to obtain the file name) and `line` / `col` members.
template <typename T>
static std::string loc(T *node) {
  const auto &where = node->loc();
  std::string text = "@";
  text += *where.file;
  text += ':';
  text += std::to_string(where.line);
  text += ':';
  text += std::to_string(where.col);
  return text;
}
75 
// Helper macro: bump the indentation level for the rest of the enclosing
// scope (through an `Indent` guard named `level_`) and print the leading
// spaces for the new level. Use it as a statement: `INDENT();`. The caller
// supplies the trailing semicolon, so the expansion no longer produces a
// stray empty statement (`indent();;`).
#define INDENT()                                                               \
  Indent level_(curIndent);                                                    \
  indent()
81 
82 /// Dispatch to a generic expressions to the appropriate subclass using RTTI
dump(ExprAST * expr)83 void ASTDumper::dump(ExprAST *expr) {
84   llvm::TypeSwitch<ExprAST *>(expr)
85       .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
86             PrintExprAST, ReturnExprAST, StructLiteralExprAST, VarDeclExprAST,
87             VariableExprAST>([&](auto *node) { this->dump(node); })
88       .Default([&](ExprAST *) {
89         // No match, fallback to a generic message
90         INDENT();
91         llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
92       });
93 }
94 
95 /// A variable declaration is printing the variable name, the type, and then
96 /// recurse in the initializer value.
dump(VarDeclExprAST * varDecl)97 void ASTDumper::dump(VarDeclExprAST *varDecl) {
98   INDENT();
99   llvm::errs() << "VarDecl " << varDecl->getName();
100   dump(varDecl->getType());
101   llvm::errs() << " " << loc(varDecl) << "\n";
102   if (auto *initVal = varDecl->getInitVal())
103     dump(initVal);
104 }
105 
106 /// A "block", or a list of expression
dump(ExprASTList * exprList)107 void ASTDumper::dump(ExprASTList *exprList) {
108   INDENT();
109   llvm::errs() << "Block {\n";
110   for (auto &expr : *exprList)
111     dump(expr.get());
112   indent();
113   llvm::errs() << "} // Block\n";
114 }
115 
116 /// A literal number, just print the value.
dump(NumberExprAST * num)117 void ASTDumper::dump(NumberExprAST *num) {
118   INDENT();
119   llvm::errs() << num->getValue() << " " << loc(num) << "\n";
120 }
121 
122 /// Helper to print recursively a literal. This handles nested array like:
123 ///    [ [ 1, 2 ], [ 3, 4 ] ]
124 /// We print out such array with the dimensions spelled out at every level:
125 ///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
printLitHelper(ExprAST * litOrNum)126 void printLitHelper(ExprAST *litOrNum) {
127   // Inside a literal expression we can have either a number or another literal
128   if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
129     llvm::errs() << num->getValue();
130     return;
131   }
132   auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
133 
134   // Print the dimension for this literal first
135   llvm::errs() << "<";
136   llvm::interleaveComma(literal->getDims(), llvm::errs());
137   llvm::errs() << ">";
138 
139   // Now print the content, recursing on every element of the list
140   llvm::errs() << "[ ";
141   llvm::interleaveComma(literal->getValues(), llvm::errs(),
142                         [&](auto &elt) { printLitHelper(elt.get()); });
143   llvm::errs() << "]";
144 }
145 
146 /// Print a literal, see the recursive helper above for the implementation.
dump(LiteralExprAST * node)147 void ASTDumper::dump(LiteralExprAST *node) {
148   INDENT();
149   llvm::errs() << "Literal: ";
150   printLitHelper(node);
151   llvm::errs() << " " << loc(node) << "\n";
152 }
153 
154 /// Print a struct literal.
dump(StructLiteralExprAST * node)155 void ASTDumper::dump(StructLiteralExprAST *node) {
156   INDENT();
157   llvm::errs() << "Struct Literal: ";
158   for (auto &value : node->getValues())
159     dump(value.get());
160   indent();
161   llvm::errs() << " " << loc(node) << "\n";
162 }
163 
164 /// Print a variable reference (just a name).
dump(VariableExprAST * node)165 void ASTDumper::dump(VariableExprAST *node) {
166   INDENT();
167   llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
168 }
169 
170 /// Return statement print the return and its (optional) argument.
dump(ReturnExprAST * node)171 void ASTDumper::dump(ReturnExprAST *node) {
172   INDENT();
173   llvm::errs() << "Return\n";
174   if (node->getExpr().has_value())
175     return dump(*node->getExpr());
176   {
177     INDENT();
178     llvm::errs() << "(void)\n";
179   }
180 }
181 
182 /// Print a binary operation, first the operator, then recurse into LHS and RHS.
dump(BinaryExprAST * node)183 void ASTDumper::dump(BinaryExprAST *node) {
184   INDENT();
185   llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
186   dump(node->getLHS());
187   dump(node->getRHS());
188 }
189 
190 /// Print a call expression, first the callee name and the list of args by
191 /// recursing into each individual argument.
dump(CallExprAST * node)192 void ASTDumper::dump(CallExprAST *node) {
193   INDENT();
194   llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
195   for (auto &arg : node->getArgs())
196     dump(arg.get());
197   indent();
198   llvm::errs() << "]\n";
199 }
200 
201 /// Print a builtin print call, first the builtin name and then the argument.
dump(PrintExprAST * node)202 void ASTDumper::dump(PrintExprAST *node) {
203   INDENT();
204   llvm::errs() << "Print [ " << loc(node) << "\n";
205   dump(node->getArg());
206   indent();
207   llvm::errs() << "]\n";
208 }
209 
210 /// Print type: only the shape is printed in between '<' and '>'
dump(const VarType & type)211 void ASTDumper::dump(const VarType &type) {
212   llvm::errs() << "<";
213   if (!type.name.empty())
214     llvm::errs() << type.name;
215   else
216     llvm::interleaveComma(type.shape, llvm::errs());
217   llvm::errs() << ">";
218 }
219 
220 /// Print a function prototype, first the function name, and then the list of
221 /// parameters names.
dump(PrototypeAST * node)222 void ASTDumper::dump(PrototypeAST *node) {
223   INDENT();
224   llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
225   indent();
226   llvm::errs() << "Params: [";
227   llvm::interleaveComma(node->getArgs(), llvm::errs(),
228                         [](auto &arg) { llvm::errs() << arg->getName(); });
229   llvm::errs() << "]\n";
230 }
231 
232 /// Print a function, first the prototype and then the body.
dump(FunctionAST * node)233 void ASTDumper::dump(FunctionAST *node) {
234   INDENT();
235   llvm::errs() << "Function \n";
236   dump(node->getProto());
237   dump(node->getBody());
238 }
239 
240 /// Print a struct.
dump(StructAST * node)241 void ASTDumper::dump(StructAST *node) {
242   INDENT();
243   llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n";
244 
245   {
246     INDENT();
247     llvm::errs() << "Variables: [\n";
248     for (auto &variable : node->getVariables())
249       dump(variable.get());
250     indent();
251     llvm::errs() << "]\n";
252   }
253 }
254 
255 /// Print a module, actually loop over the functions and print them in sequence.
dump(ModuleAST * node)256 void ASTDumper::dump(ModuleAST *node) {
257   INDENT();
258   llvm::errs() << "Module:\n";
259   for (auto &record : *node) {
260     if (FunctionAST *function = llvm::dyn_cast<FunctionAST>(record.get()))
261       dump(function);
262     else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get()))
263       dump(str);
264     else
265       llvm::errs() << "<unknown Record, kind " << record->getKind() << ">\n";
266   }
267 }
268 
269 namespace toy {
270 
271 // Public API
dump(ModuleAST & module)272 void dump(ModuleAST &module) { ASTDumper().dump(&module); }
273 
274 } // namespace toy
275