xref: /llvm-project/mlir/examples/toy/Ch4/parser/AST.cpp (revision ec6da0652282d29569faa628d2180909fa588906)
1 //===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST dump for the Toy language.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "toy/AST.h"
14 
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/Casting.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include <string>
21 
22 using namespace toy;
23 
24 namespace {
25 
// RAII helper to manage increasing/decreasing the indentation as we traverse
// the AST: the referenced counter is incremented on construction and
// decremented again on destruction.
struct Indent {
  /// Increment `level` for as long as this object is alive.
  explicit Indent(int &level) : level(level) { ++level; }

  /// Undo the increment performed by the constructor.
  ~Indent() { --level; }

  // A copy would share the same counter and decrement it a second time when
  // destroyed, so the guard is not copyable.
  Indent(const Indent &) = delete;
  Indent &operator=(const Indent &) = delete;

  /// The indentation counter managed by this guard.
  int &level;
};
33 
34 /// Helper class that implement the AST tree traversal and print the nodes along
35 /// the way. The only data member is the current indentation level.
36 class ASTDumper {
37 public:
38   void dump(ModuleAST *node);
39 
40 private:
41   void dump(const VarType &type);
42   void dump(VarDeclExprAST *varDecl);
43   void dump(ExprAST *expr);
44   void dump(ExprASTList *exprList);
45   void dump(NumberExprAST *num);
46   void dump(LiteralExprAST *node);
47   void dump(VariableExprAST *node);
48   void dump(ReturnExprAST *node);
49   void dump(BinaryExprAST *node);
50   void dump(CallExprAST *node);
51   void dump(PrintExprAST *node);
52   void dump(PrototypeAST *node);
53   void dump(FunctionAST *node);
54 
55   // Actually print spaces matching the current indentation level
indent()56   void indent() {
57     for (int i = 0; i < curIndent; i++)
58       llvm::errs() << "  ";
59   }
60   int curIndent = 0;
61 };
62 
63 } // namespace
64 
65 /// Return a formatted string for the location of any node
66 template <typename T>
loc(T * node)67 static std::string loc(T *node) {
68   const auto &loc = node->loc();
69   return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
70           llvm::Twine(loc.col))
71       .str();
72 }
73 
// Helper macro: bump the indentation level for the rest of the enclosing scope
// (through an `Indent` object named `level_`) and print the leading spaces for
// the new level. Because it declares a local variable it can be used at most
// once per scope. Invoke it as `INDENT();` -- the expansion intentionally does
// not end in a semicolon, so the call site's own `;` terminates the last
// statement instead of leaving an empty statement behind.
#define INDENT()                                                               \
  Indent level_(curIndent);                                                    \
  indent()
79 
80 /// Dispatch to a generic expressions to the appropriate subclass using RTTI
dump(ExprAST * expr)81 void ASTDumper::dump(ExprAST *expr) {
82   llvm::TypeSwitch<ExprAST *>(expr)
83       .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
84             PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
85           [&](auto *node) { this->dump(node); })
86       .Default([&](ExprAST *) {
87         // No match, fallback to a generic message
88         INDENT();
89         llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
90       });
91 }
92 
93 /// A variable declaration is printing the variable name, the type, and then
94 /// recurse in the initializer value.
dump(VarDeclExprAST * varDecl)95 void ASTDumper::dump(VarDeclExprAST *varDecl) {
96   INDENT();
97   llvm::errs() << "VarDecl " << varDecl->getName();
98   dump(varDecl->getType());
99   llvm::errs() << " " << loc(varDecl) << "\n";
100   dump(varDecl->getInitVal());
101 }
102 
103 /// A "block", or a list of expression
dump(ExprASTList * exprList)104 void ASTDumper::dump(ExprASTList *exprList) {
105   INDENT();
106   llvm::errs() << "Block {\n";
107   for (auto &expr : *exprList)
108     dump(expr.get());
109   indent();
110   llvm::errs() << "} // Block\n";
111 }
112 
113 /// A literal number, just print the value.
dump(NumberExprAST * num)114 void ASTDumper::dump(NumberExprAST *num) {
115   INDENT();
116   llvm::errs() << num->getValue() << " " << loc(num) << "\n";
117 }
118 
119 /// Helper to print recursively a literal. This handles nested array like:
120 ///    [ [ 1, 2 ], [ 3, 4 ] ]
121 /// We print out such array with the dimensions spelled out at every level:
122 ///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
printLitHelper(ExprAST * litOrNum)123 void printLitHelper(ExprAST *litOrNum) {
124   // Inside a literal expression we can have either a number or another literal
125   if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
126     llvm::errs() << num->getValue();
127     return;
128   }
129   auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
130 
131   // Print the dimension for this literal first
132   llvm::errs() << "<";
133   llvm::interleaveComma(literal->getDims(), llvm::errs());
134   llvm::errs() << ">";
135 
136   // Now print the content, recursing on every element of the list
137   llvm::errs() << "[ ";
138   llvm::interleaveComma(literal->getValues(), llvm::errs(),
139                         [&](auto &elt) { printLitHelper(elt.get()); });
140   llvm::errs() << "]";
141 }
142 
143 /// Print a literal, see the recursive helper above for the implementation.
dump(LiteralExprAST * node)144 void ASTDumper::dump(LiteralExprAST *node) {
145   INDENT();
146   llvm::errs() << "Literal: ";
147   printLitHelper(node);
148   llvm::errs() << " " << loc(node) << "\n";
149 }
150 
151 /// Print a variable reference (just a name).
dump(VariableExprAST * node)152 void ASTDumper::dump(VariableExprAST *node) {
153   INDENT();
154   llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
155 }
156 
157 /// Return statement print the return and its (optional) argument.
dump(ReturnExprAST * node)158 void ASTDumper::dump(ReturnExprAST *node) {
159   INDENT();
160   llvm::errs() << "Return\n";
161   if (node->getExpr().has_value())
162     return dump(*node->getExpr());
163   {
164     INDENT();
165     llvm::errs() << "(void)\n";
166   }
167 }
168 
169 /// Print a binary operation, first the operator, then recurse into LHS and RHS.
dump(BinaryExprAST * node)170 void ASTDumper::dump(BinaryExprAST *node) {
171   INDENT();
172   llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
173   dump(node->getLHS());
174   dump(node->getRHS());
175 }
176 
177 /// Print a call expression, first the callee name and the list of args by
178 /// recursing into each individual argument.
dump(CallExprAST * node)179 void ASTDumper::dump(CallExprAST *node) {
180   INDENT();
181   llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
182   for (auto &arg : node->getArgs())
183     dump(arg.get());
184   indent();
185   llvm::errs() << "]\n";
186 }
187 
188 /// Print a builtin print call, first the builtin name and then the argument.
dump(PrintExprAST * node)189 void ASTDumper::dump(PrintExprAST *node) {
190   INDENT();
191   llvm::errs() << "Print [ " << loc(node) << "\n";
192   dump(node->getArg());
193   indent();
194   llvm::errs() << "]\n";
195 }
196 
197 /// Print type: only the shape is printed in between '<' and '>'
dump(const VarType & type)198 void ASTDumper::dump(const VarType &type) {
199   llvm::errs() << "<";
200   llvm::interleaveComma(type.shape, llvm::errs());
201   llvm::errs() << ">";
202 }
203 
204 /// Print a function prototype, first the function name, and then the list of
205 /// parameters names.
dump(PrototypeAST * node)206 void ASTDumper::dump(PrototypeAST *node) {
207   INDENT();
208   llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
209   indent();
210   llvm::errs() << "Params: [";
211   llvm::interleaveComma(node->getArgs(), llvm::errs(),
212                         [](auto &arg) { llvm::errs() << arg->getName(); });
213   llvm::errs() << "]\n";
214 }
215 
216 /// Print a function, first the prototype and then the body.
dump(FunctionAST * node)217 void ASTDumper::dump(FunctionAST *node) {
218   INDENT();
219   llvm::errs() << "Function \n";
220   dump(node->getProto());
221   dump(node->getBody());
222 }
223 
224 /// Print a module, actually loop over the functions and print them in sequence.
dump(ModuleAST * node)225 void ASTDumper::dump(ModuleAST *node) {
226   INDENT();
227   llvm::errs() << "Module:\n";
228   for (auto &f : *node)
229     dump(&f);
230 }
231 
232 namespace toy {
233 
234 // Public API
dump(ModuleAST & module)235 void dump(ModuleAST &module) { ASTDumper().dump(&module); }
236 
237 } // namespace toy
238