xref: /llvm-project/mlir/examples/toy/Ch7/include/toy/AST.h (revision 0a81ace0047a2de93e71c82cdf0977fc989660df)
1 //===- AST.h - Node definition for the Toy AST ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST for the Toy language. It is optimized for
10 // simplicity, not efficiency. The AST forms a tree structure where each node
11 // references its children using std::unique_ptr<>.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef TOY_AST_H
16 #define TOY_AST_H
17 
18 #include "toy/Lexer.h"
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/Casting.h"
23 #include <utility>
24 #include <vector>
25 #include <optional>
26 
27 namespace toy {
28 
/// Declared type of a variable. A type is described either by a name (held in
/// `name`) or by shape information (held in `shape`); both members default to
/// empty.
struct VarType {
  /// Type name; empty when none was given.
  std::string name{};
  /// Shape dimensions; empty when none were given.
  std::vector<int64_t> shape{};
};
34 
35 /// Base class for all expression nodes.
36 class ExprAST {
37 public:
38   enum ExprASTKind {
39     Expr_VarDecl,
40     Expr_Return,
41     Expr_Num,
42     Expr_Literal,
43     Expr_StructLiteral,
44     Expr_Var,
45     Expr_BinOp,
46     Expr_Call,
47     Expr_Print,
48   };
49 
ExprAST(ExprASTKind kind,Location location)50   ExprAST(ExprASTKind kind, Location location)
51       : kind(kind), location(std::move(location)) {}
52   virtual ~ExprAST() = default;
53 
getKind()54   ExprASTKind getKind() const { return kind; }
55 
loc()56   const Location &loc() { return location; }
57 
58 private:
59   const ExprASTKind kind;
60   Location location;
61 };
62 
63 /// A block-list of expressions.
64 using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
65 
66 /// Expression class for numeric literals like "1.0".
67 class NumberExprAST : public ExprAST {
68   double val;
69 
70 public:
NumberExprAST(Location loc,double val)71   NumberExprAST(Location loc, double val)
72       : ExprAST(Expr_Num, std::move(loc)), val(val) {}
73 
getValue()74   double getValue() { return val; }
75 
76   /// LLVM style RTTI
classof(const ExprAST * c)77   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
78 };
79 
80 /// Expression class for a literal value.
81 class LiteralExprAST : public ExprAST {
82   std::vector<std::unique_ptr<ExprAST>> values;
83   std::vector<int64_t> dims;
84 
85 public:
LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)86   LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
87                  std::vector<int64_t> dims)
88       : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)),
89         dims(std::move(dims)) {}
90 
getValues()91   llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
getDims()92   llvm::ArrayRef<int64_t> getDims() { return dims; }
93 
94   /// LLVM style RTTI
classof(const ExprAST * c)95   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
96 };
97 
98 /// Expression class for a literal struct value.
99 class StructLiteralExprAST : public ExprAST {
100   std::vector<std::unique_ptr<ExprAST>> values;
101 
102 public:
StructLiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values)103   StructLiteralExprAST(Location loc,
104                        std::vector<std::unique_ptr<ExprAST>> values)
105       : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) {
106   }
107 
getValues()108   llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
109 
110   /// LLVM style RTTI
classof(const ExprAST * c)111   static bool classof(const ExprAST *c) {
112     return c->getKind() == Expr_StructLiteral;
113   }
114 };
115 
116 /// Expression class for referencing a variable, like "a".
117 class VariableExprAST : public ExprAST {
118   std::string name;
119 
120 public:
VariableExprAST(Location loc,llvm::StringRef name)121   VariableExprAST(Location loc, llvm::StringRef name)
122       : ExprAST(Expr_Var, std::move(loc)), name(name) {}
123 
getName()124   llvm::StringRef getName() { return name; }
125 
126   /// LLVM style RTTI
classof(const ExprAST * c)127   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
128 };
129 
130 /// Expression class for defining a variable.
131 class VarDeclExprAST : public ExprAST {
132   std::string name;
133   VarType type;
134   std::unique_ptr<ExprAST> initVal;
135 
136 public:
137   VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
138                  std::unique_ptr<ExprAST> initVal = nullptr)
ExprAST(Expr_VarDecl,std::move (loc))139       : ExprAST(Expr_VarDecl, std::move(loc)), name(name),
140         type(std::move(type)), initVal(std::move(initVal)) {}
141 
getName()142   llvm::StringRef getName() { return name; }
getInitVal()143   ExprAST *getInitVal() { return initVal.get(); }
getType()144   const VarType &getType() { return type; }
145 
146   /// LLVM style RTTI
classof(const ExprAST * c)147   static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
148 };
149 
150 /// Expression class for a return operator.
151 class ReturnExprAST : public ExprAST {
152   std::optional<std::unique_ptr<ExprAST>> expr;
153 
154 public:
ReturnExprAST(Location loc,std::optional<std::unique_ptr<ExprAST>> expr)155   ReturnExprAST(Location loc, std::optional<std::unique_ptr<ExprAST>> expr)
156       : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {}
157 
getExpr()158   std::optional<ExprAST *> getExpr() {
159     if (expr.has_value())
160       return expr->get();
161     return std::nullopt;
162   }
163 
164   /// LLVM style RTTI
classof(const ExprAST * c)165   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
166 };
167 
168 /// Expression class for a binary operator.
169 class BinaryExprAST : public ExprAST {
170   char op;
171   std::unique_ptr<ExprAST> lhs, rhs;
172 
173 public:
getOp()174   char getOp() { return op; }
getLHS()175   ExprAST *getLHS() { return lhs.get(); }
getRHS()176   ExprAST *getRHS() { return rhs.get(); }
177 
BinaryExprAST(Location loc,char op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)178   BinaryExprAST(Location loc, char op, std::unique_ptr<ExprAST> lhs,
179                 std::unique_ptr<ExprAST> rhs)
180       : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)),
181         rhs(std::move(rhs)) {}
182 
183   /// LLVM style RTTI
classof(const ExprAST * c)184   static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
185 };
186 
187 /// Expression class for function calls.
188 class CallExprAST : public ExprAST {
189   std::string callee;
190   std::vector<std::unique_ptr<ExprAST>> args;
191 
192 public:
CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)193   CallExprAST(Location loc, const std::string &callee,
194               std::vector<std::unique_ptr<ExprAST>> args)
195       : ExprAST(Expr_Call, std::move(loc)), callee(callee),
196         args(std::move(args)) {}
197 
getCallee()198   llvm::StringRef getCallee() { return callee; }
getArgs()199   llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
200 
201   /// LLVM style RTTI
classof(const ExprAST * c)202   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
203 };
204 
205 /// Expression class for builtin print calls.
206 class PrintExprAST : public ExprAST {
207   std::unique_ptr<ExprAST> arg;
208 
209 public:
PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)210   PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
211       : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}
212 
getArg()213   ExprAST *getArg() { return arg.get(); }
214 
215   /// LLVM style RTTI
classof(const ExprAST * c)216   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
217 };
218 
219 /// This class represents the "prototype" for a function, which captures its
220 /// name, and its argument names (thus implicitly the number of arguments the
221 /// function takes).
222 class PrototypeAST {
223   Location location;
224   std::string name;
225   std::vector<std::unique_ptr<VarDeclExprAST>> args;
226 
227 public:
PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> args)228   PrototypeAST(Location location, const std::string &name,
229                std::vector<std::unique_ptr<VarDeclExprAST>> args)
230       : location(std::move(location)), name(name), args(std::move(args)) {}
231 
loc()232   const Location &loc() { return location; }
getName()233   llvm::StringRef getName() const { return name; }
getArgs()234   llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; }
235 };
236 
/// This class represents a top level record in a module. The concrete
/// subclass is identified by `RecordASTKind`, which the subclasses' `classof`
/// checks for LLVM-style RTTI.
class RecordAST {
public:
  /// Discriminator identifying the concrete record subclass.
  enum RecordASTKind {
    Record_Function,
    Record_Struct,
  };

  /// `explicit` so that a bare kind value is never silently converted into a
  /// RecordAST (for example, in copy-initialization or argument passing).
  explicit RecordAST(RecordASTKind kind) : kind(kind) {}
  virtual ~RecordAST() = default;

  /// Returns the kind this record was constructed with.
  RecordASTKind getKind() const { return kind; }

private:
  const RecordASTKind kind;
};
253 
254 /// This class represents a function definition itself.
255 class FunctionAST : public RecordAST {
256   std::unique_ptr<PrototypeAST> proto;
257   std::unique_ptr<ExprASTList> body;
258 
259 public:
FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)260   FunctionAST(std::unique_ptr<PrototypeAST> proto,
261               std::unique_ptr<ExprASTList> body)
262       : RecordAST(Record_Function), proto(std::move(proto)),
263         body(std::move(body)) {}
getProto()264   PrototypeAST *getProto() { return proto.get(); }
getBody()265   ExprASTList *getBody() { return body.get(); }
266 
267   /// LLVM style RTTI
classof(const RecordAST * r)268   static bool classof(const RecordAST *r) {
269     return r->getKind() == Record_Function;
270   }
271 };
272 
273 /// This class represents a struct definition.
274 class StructAST : public RecordAST {
275   Location location;
276   std::string name;
277   std::vector<std::unique_ptr<VarDeclExprAST>> variables;
278 
279 public:
StructAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> variables)280   StructAST(Location location, const std::string &name,
281             std::vector<std::unique_ptr<VarDeclExprAST>> variables)
282       : RecordAST(Record_Struct), location(std::move(location)), name(name),
283         variables(std::move(variables)) {}
284 
loc()285   const Location &loc() { return location; }
getName()286   llvm::StringRef getName() const { return name; }
getVariables()287   llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() {
288     return variables;
289   }
290 
291   /// LLVM style RTTI
classof(const RecordAST * r)292   static bool classof(const RecordAST *r) {
293     return r->getKind() == Record_Struct;
294   }
295 };
296 
297 /// This class represents a list of functions to be processed together
298 class ModuleAST {
299   std::vector<std::unique_ptr<RecordAST>> records;
300 
301 public:
ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)302   ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)
303       : records(std::move(records)) {}
304 
begin()305   auto begin() { return records.begin(); }
end()306   auto end() { return records.end(); }
307 };
308 
309 void dump(ModuleAST &);
310 
311 } // namespace toy
312 
313 #endif // TOY_AST_H
314