xref: /llvm-project/mlir/examples/toy/Ch6/parser/AST.cpp (revision ec6da0652282d29569faa628d2180909fa588906)
10372eb41SRiver Riddle //===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
20372eb41SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60372eb41SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80372eb41SRiver Riddle //
90372eb41SRiver Riddle // This file implements the AST dump for the Toy language.
100372eb41SRiver Riddle //
110372eb41SRiver Riddle //===----------------------------------------------------------------------===//
120372eb41SRiver Riddle 
130372eb41SRiver Riddle #include "toy/AST.h"
140372eb41SRiver Riddle 
15*ec6da065SMehdi Amini #include "llvm/ADT/STLExtras.h"
160372eb41SRiver Riddle #include "llvm/ADT/Twine.h"
17ebf190fcSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
18*ec6da065SMehdi Amini #include "llvm/Support/Casting.h"
190372eb41SRiver Riddle #include "llvm/Support/raw_ostream.h"
20*ec6da065SMehdi Amini #include <string>
210372eb41SRiver Riddle 
220372eb41SRiver Riddle using namespace toy;
230372eb41SRiver Riddle 
240372eb41SRiver Riddle namespace {
250372eb41SRiver Riddle 
// RAII helper to manage increasing/decreasing the indentation as we traverse
// the AST: constructing an Indent bumps the referenced level by one, and
// destroying it (at the end of the enclosing scope) restores it.
//
// The guard is non-copyable: a copy would share the same referenced counter
// and decrement it a second time on destruction, corrupting the level.
struct Indent {
  explicit Indent(int &counter) : level(counter) { ++level; }
  ~Indent() { --level; }

  Indent(const Indent &) = delete;
  Indent &operator=(const Indent &) = delete;

  int &level;
};
330372eb41SRiver Riddle 
340372eb41SRiver Riddle /// Helper class that implement the AST tree traversal and print the nodes along
350372eb41SRiver Riddle /// the way. The only data member is the current indentation level.
360372eb41SRiver Riddle class ASTDumper {
370372eb41SRiver Riddle public:
3822cfff70SRiver Riddle   void dump(ModuleAST *node);
390372eb41SRiver Riddle 
400372eb41SRiver Riddle private:
4122cfff70SRiver Riddle   void dump(const VarType &type);
420372eb41SRiver Riddle   void dump(VarDeclExprAST *varDecl);
430372eb41SRiver Riddle   void dump(ExprAST *expr);
440372eb41SRiver Riddle   void dump(ExprASTList *exprList);
450372eb41SRiver Riddle   void dump(NumberExprAST *num);
4622cfff70SRiver Riddle   void dump(LiteralExprAST *node);
4722cfff70SRiver Riddle   void dump(VariableExprAST *node);
4822cfff70SRiver Riddle   void dump(ReturnExprAST *node);
4922cfff70SRiver Riddle   void dump(BinaryExprAST *node);
5022cfff70SRiver Riddle   void dump(CallExprAST *node);
5122cfff70SRiver Riddle   void dump(PrintExprAST *node);
5222cfff70SRiver Riddle   void dump(PrototypeAST *node);
5322cfff70SRiver Riddle   void dump(FunctionAST *node);
540372eb41SRiver Riddle 
550372eb41SRiver Riddle   // Actually print spaces matching the current indentation level
indent()560372eb41SRiver Riddle   void indent() {
570372eb41SRiver Riddle     for (int i = 0; i < curIndent; i++)
580372eb41SRiver Riddle       llvm::errs() << "  ";
590372eb41SRiver Riddle   }
600372eb41SRiver Riddle   int curIndent = 0;
610372eb41SRiver Riddle };
620372eb41SRiver Riddle 
630372eb41SRiver Riddle } // namespace
640372eb41SRiver Riddle 
/// Format the source location of any node as "@<file>:<line>:<col>".
template <typename T>
static std::string loc(T *node) {
  const auto &where = node->loc();
  std::string text = "@";
  text += *where.file;
  text += ":";
  text += std::to_string(where.line);
  text += ":";
  text += std::to_string(where.col);
  return text;
}
730372eb41SRiver Riddle 
// Helper macro: declare an Indent guard that bumps the indentation level until
// the end of the enclosing scope, then print the leading spaces for the new
// level. The call site supplies the terminating semicolon (`INDENT();`), so the
// expansion does not leave a stray empty statement behind.
#define INDENT()                                                               \
  Indent indentGuard_(curIndent);                                              \
  indent()
790372eb41SRiver Riddle 
800372eb41SRiver Riddle /// Dispatch to a generic expressions to the appropriate subclass using RTTI
dump(ExprAST * expr)810372eb41SRiver Riddle void ASTDumper::dump(ExprAST *expr) {
82ebf190fcSRiver Riddle   llvm::TypeSwitch<ExprAST *>(expr)
8374278dd0SRiver Riddle       .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
8474278dd0SRiver Riddle             PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
855a0d4803SRiver Riddle           [&](auto *node) { this->dump(node); })
8674278dd0SRiver Riddle       .Default([&](ExprAST *) {
870372eb41SRiver Riddle         // No match, fallback to a generic message
880372eb41SRiver Riddle         INDENT();
890372eb41SRiver Riddle         llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
9074278dd0SRiver Riddle       });
910372eb41SRiver Riddle }
920372eb41SRiver Riddle 
930372eb41SRiver Riddle /// A variable declaration is printing the variable name, the type, and then
940372eb41SRiver Riddle /// recurse in the initializer value.
dump(VarDeclExprAST * varDecl)950372eb41SRiver Riddle void ASTDumper::dump(VarDeclExprAST *varDecl) {
960372eb41SRiver Riddle   INDENT();
970372eb41SRiver Riddle   llvm::errs() << "VarDecl " << varDecl->getName();
980372eb41SRiver Riddle   dump(varDecl->getType());
990372eb41SRiver Riddle   llvm::errs() << " " << loc(varDecl) << "\n";
1000372eb41SRiver Riddle   dump(varDecl->getInitVal());
1010372eb41SRiver Riddle }
1020372eb41SRiver Riddle 
1030372eb41SRiver Riddle /// A "block", or a list of expression
dump(ExprASTList * exprList)1040372eb41SRiver Riddle void ASTDumper::dump(ExprASTList *exprList) {
1050372eb41SRiver Riddle   INDENT();
1060372eb41SRiver Riddle   llvm::errs() << "Block {\n";
1070372eb41SRiver Riddle   for (auto &expr : *exprList)
1080372eb41SRiver Riddle     dump(expr.get());
1090372eb41SRiver Riddle   indent();
1100372eb41SRiver Riddle   llvm::errs() << "} // Block\n";
1110372eb41SRiver Riddle }
1120372eb41SRiver Riddle 
1130372eb41SRiver Riddle /// A literal number, just print the value.
dump(NumberExprAST * num)1140372eb41SRiver Riddle void ASTDumper::dump(NumberExprAST *num) {
1150372eb41SRiver Riddle   INDENT();
1160372eb41SRiver Riddle   llvm::errs() << num->getValue() << " " << loc(num) << "\n";
1170372eb41SRiver Riddle }
1180372eb41SRiver Riddle 
119f28c5acaSKazuaki Ishizaki /// Helper to print recursively a literal. This handles nested array like:
1200372eb41SRiver Riddle ///    [ [ 1, 2 ], [ 3, 4 ] ]
1210372eb41SRiver Riddle /// We print out such array with the dimensions spelled out at every level:
1220372eb41SRiver Riddle ///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
printLitHelper(ExprAST * litOrNum)12322cfff70SRiver Riddle void printLitHelper(ExprAST *litOrNum) {
1240372eb41SRiver Riddle   // Inside a literal expression we can have either a number or another literal
12502b6fb21SMehdi Amini   if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
1260372eb41SRiver Riddle     llvm::errs() << num->getValue();
1270372eb41SRiver Riddle     return;
1280372eb41SRiver Riddle   }
12922cfff70SRiver Riddle   auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
1300372eb41SRiver Riddle 
1310372eb41SRiver Riddle   // Print the dimension for this literal first
1320372eb41SRiver Riddle   llvm::errs() << "<";
1332f21a579SRiver Riddle   llvm::interleaveComma(literal->getDims(), llvm::errs());
1340372eb41SRiver Riddle   llvm::errs() << ">";
1350372eb41SRiver Riddle 
1360372eb41SRiver Riddle   // Now print the content, recursing on every element of the list
1370372eb41SRiver Riddle   llvm::errs() << "[ ";
1382f21a579SRiver Riddle   llvm::interleaveComma(literal->getValues(), llvm::errs(),
13922cfff70SRiver Riddle                         [&](auto &elt) { printLitHelper(elt.get()); });
1400372eb41SRiver Riddle   llvm::errs() << "]";
1410372eb41SRiver Riddle }
1420372eb41SRiver Riddle 
1430372eb41SRiver Riddle /// Print a literal, see the recursive helper above for the implementation.
dump(LiteralExprAST * node)14422cfff70SRiver Riddle void ASTDumper::dump(LiteralExprAST *node) {
1450372eb41SRiver Riddle   INDENT();
1460372eb41SRiver Riddle   llvm::errs() << "Literal: ";
14722cfff70SRiver Riddle   printLitHelper(node);
14822cfff70SRiver Riddle   llvm::errs() << " " << loc(node) << "\n";
1490372eb41SRiver Riddle }
1500372eb41SRiver Riddle 
1510372eb41SRiver Riddle /// Print a variable reference (just a name).
dump(VariableExprAST * node)15222cfff70SRiver Riddle void ASTDumper::dump(VariableExprAST *node) {
1530372eb41SRiver Riddle   INDENT();
15422cfff70SRiver Riddle   llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
1550372eb41SRiver Riddle }
1560372eb41SRiver Riddle 
1570372eb41SRiver Riddle /// Return statement print the return and its (optional) argument.
dump(ReturnExprAST * node)15822cfff70SRiver Riddle void ASTDumper::dump(ReturnExprAST *node) {
1590372eb41SRiver Riddle   INDENT();
1600372eb41SRiver Riddle   llvm::errs() << "Return\n";
1617430894aSFangrui Song   if (node->getExpr().has_value())
16222cfff70SRiver Riddle     return dump(*node->getExpr());
1630372eb41SRiver Riddle   {
1640372eb41SRiver Riddle     INDENT();
1650372eb41SRiver Riddle     llvm::errs() << "(void)\n";
1660372eb41SRiver Riddle   }
1670372eb41SRiver Riddle }
1680372eb41SRiver Riddle 
1690372eb41SRiver Riddle /// Print a binary operation, first the operator, then recurse into LHS and RHS.
dump(BinaryExprAST * node)17022cfff70SRiver Riddle void ASTDumper::dump(BinaryExprAST *node) {
1710372eb41SRiver Riddle   INDENT();
17222cfff70SRiver Riddle   llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
17322cfff70SRiver Riddle   dump(node->getLHS());
17422cfff70SRiver Riddle   dump(node->getRHS());
1750372eb41SRiver Riddle }
1760372eb41SRiver Riddle 
1770372eb41SRiver Riddle /// Print a call expression, first the callee name and the list of args by
1780372eb41SRiver Riddle /// recursing into each individual argument.
dump(CallExprAST * node)17922cfff70SRiver Riddle void ASTDumper::dump(CallExprAST *node) {
1800372eb41SRiver Riddle   INDENT();
18122cfff70SRiver Riddle   llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
18222cfff70SRiver Riddle   for (auto &arg : node->getArgs())
1830372eb41SRiver Riddle     dump(arg.get());
1840372eb41SRiver Riddle   indent();
1850372eb41SRiver Riddle   llvm::errs() << "]\n";
1860372eb41SRiver Riddle }
1870372eb41SRiver Riddle 
1880372eb41SRiver Riddle /// Print a builtin print call, first the builtin name and then the argument.
dump(PrintExprAST * node)18922cfff70SRiver Riddle void ASTDumper::dump(PrintExprAST *node) {
1900372eb41SRiver Riddle   INDENT();
19122cfff70SRiver Riddle   llvm::errs() << "Print [ " << loc(node) << "\n";
19222cfff70SRiver Riddle   dump(node->getArg());
1930372eb41SRiver Riddle   indent();
1940372eb41SRiver Riddle   llvm::errs() << "]\n";
1950372eb41SRiver Riddle }
1960372eb41SRiver Riddle 
1970372eb41SRiver Riddle /// Print type: only the shape is printed in between '<' and '>'
dump(const VarType & type)19822cfff70SRiver Riddle void ASTDumper::dump(const VarType &type) {
1990372eb41SRiver Riddle   llvm::errs() << "<";
2002f21a579SRiver Riddle   llvm::interleaveComma(type.shape, llvm::errs());
2010372eb41SRiver Riddle   llvm::errs() << ">";
2020372eb41SRiver Riddle }
2030372eb41SRiver Riddle 
2040372eb41SRiver Riddle /// Print a function prototype, first the function name, and then the list of
2050372eb41SRiver Riddle /// parameters names.
dump(PrototypeAST * node)20622cfff70SRiver Riddle void ASTDumper::dump(PrototypeAST *node) {
2070372eb41SRiver Riddle   INDENT();
2085633813bSRahul Joshi   llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
2090372eb41SRiver Riddle   indent();
2100372eb41SRiver Riddle   llvm::errs() << "Params: [";
2112f21a579SRiver Riddle   llvm::interleaveComma(node->getArgs(), llvm::errs(),
21222cfff70SRiver Riddle                         [](auto &arg) { llvm::errs() << arg->getName(); });
2130372eb41SRiver Riddle   llvm::errs() << "]\n";
2140372eb41SRiver Riddle }
2150372eb41SRiver Riddle 
2160372eb41SRiver Riddle /// Print a function, first the prototype and then the body.
dump(FunctionAST * node)21722cfff70SRiver Riddle void ASTDumper::dump(FunctionAST *node) {
2180372eb41SRiver Riddle   INDENT();
2190372eb41SRiver Riddle   llvm::errs() << "Function \n";
22022cfff70SRiver Riddle   dump(node->getProto());
22122cfff70SRiver Riddle   dump(node->getBody());
2220372eb41SRiver Riddle }
2230372eb41SRiver Riddle 
2240372eb41SRiver Riddle /// Print a module, actually loop over the functions and print them in sequence.
dump(ModuleAST * node)22522cfff70SRiver Riddle void ASTDumper::dump(ModuleAST *node) {
2260372eb41SRiver Riddle   INDENT();
2270372eb41SRiver Riddle   llvm::errs() << "Module:\n";
22822cfff70SRiver Riddle   for (auto &f : *node)
22922cfff70SRiver Riddle     dump(&f);
2300372eb41SRiver Riddle }
2310372eb41SRiver Riddle 
2320372eb41SRiver Riddle namespace toy {
2330372eb41SRiver Riddle 
2340372eb41SRiver Riddle // Public API
dump(ModuleAST & module)2350372eb41SRiver Riddle void dump(ModuleAST &module) { ASTDumper().dump(&module); }
2360372eb41SRiver Riddle 
2370372eb41SRiver Riddle } // namespace toy
238