1 //===- AST.h - Node definition for the Toy AST ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AST for the Toy language. It is optimized for 10 // simplicity, not efficiency. The AST forms a tree structure where each node 11 // references its children using std::unique_ptr<>. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef TOY_AST_H 16 #define TOY_AST_H 17 18 #include "toy/Lexer.h" 19 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/StringRef.h" 22 #include "llvm/Support/Casting.h" 23 #include <utility> 24 #include <vector> 25 #include <optional> 26 27 namespace toy { 28 29 /// A variable type with either name or shape information. 30 struct VarType { 31 std::string name; 32 std::vector<int64_t> shape; 33 }; 34 35 /// Base class for all expression nodes. 36 class ExprAST { 37 public: 38 enum ExprASTKind { 39 Expr_VarDecl, 40 Expr_Return, 41 Expr_Num, 42 Expr_Literal, 43 Expr_StructLiteral, 44 Expr_Var, 45 Expr_BinOp, 46 Expr_Call, 47 Expr_Print, 48 }; 49 ExprAST(ExprASTKind kind,Location location)50 ExprAST(ExprASTKind kind, Location location) 51 : kind(kind), location(std::move(location)) {} 52 virtual ~ExprAST() = default; 53 getKind()54 ExprASTKind getKind() const { return kind; } 55 loc()56 const Location &loc() { return location; } 57 58 private: 59 const ExprASTKind kind; 60 Location location; 61 }; 62 63 /// A block-list of expressions. 64 using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; 65 66 /// Expression class for numeric literals like "1.0". 67 class NumberExprAST : public ExprAST { 68 double val; 69 70 public: NumberExprAST(Location loc,double val)71 NumberExprAST(Location loc, double val) 72 : ExprAST(Expr_Num, std::move(loc)), val(val) {} 73 getValue()74 double getValue() { return val; } 75 76 /// LLVM style RTTI classof(const ExprAST * c)77 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } 78 }; 79 80 /// Expression class for a literal value. 81 class LiteralExprAST : public ExprAST { 82 std::vector<std::unique_ptr<ExprAST>> values; 83 std::vector<int64_t> dims; 84 85 public: LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)86 LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, 87 std::vector<int64_t> dims) 88 : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), 89 dims(std::move(dims)) {} 90 getValues()91 llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } getDims()92 llvm::ArrayRef<int64_t> getDims() { return dims; } 93 94 /// LLVM style RTTI classof(const ExprAST * c)95 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } 96 }; 97 98 /// Expression class for a literal struct value. 99 class StructLiteralExprAST : public ExprAST { 100 std::vector<std::unique_ptr<ExprAST>> values; 101 102 public: StructLiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values)103 StructLiteralExprAST(Location loc, 104 std::vector<std::unique_ptr<ExprAST>> values) 105 : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { 106 } 107 getValues()108 llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } 109 110 /// LLVM style RTTI classof(const ExprAST * c)111 static bool classof(const ExprAST *c) { 112 return c->getKind() == Expr_StructLiteral; 113 } 114 }; 115 116 /// Expression class for referencing a variable, like "a". 117 class VariableExprAST : public ExprAST { 118 std::string name; 119 120 public: VariableExprAST(Location loc,llvm::StringRef name)121 VariableExprAST(Location loc, llvm::StringRef name) 122 : ExprAST(Expr_Var, std::move(loc)), name(name) {} 123 getName()124 llvm::StringRef getName() { return name; } 125 126 /// LLVM style RTTI classof(const ExprAST * c)127 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } 128 }; 129 130 /// Expression class for defining a variable. 131 class VarDeclExprAST : public ExprAST { 132 std::string name; 133 VarType type; 134 std::unique_ptr<ExprAST> initVal; 135 136 public: 137 VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, 138 std::unique_ptr<ExprAST> initVal = nullptr) ExprAST(Expr_VarDecl,std::move (loc))139 : ExprAST(Expr_VarDecl, std::move(loc)), name(name), 140 type(std::move(type)), initVal(std::move(initVal)) {} 141 getName()142 llvm::StringRef getName() { return name; } getInitVal()143 ExprAST *getInitVal() { return initVal.get(); } getType()144 const VarType &getType() { return type; } 145 146 /// LLVM style RTTI classof(const ExprAST * c)147 static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } 148 }; 149 150 /// Expression class for a return operator. 151 class ReturnExprAST : public ExprAST { 152 std::optional<std::unique_ptr<ExprAST>> expr; 153 154 public: ReturnExprAST(Location loc,std::optional<std::unique_ptr<ExprAST>> expr)155 ReturnExprAST(Location loc, std::optional<std::unique_ptr<ExprAST>> expr) 156 : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} 157 getExpr()158 std::optional<ExprAST *> getExpr() { 159 if (expr.has_value()) 160 return expr->get(); 161 return std::nullopt; 162 } 163 164 /// LLVM style RTTI classof(const ExprAST * c)165 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } 166 }; 167 168 /// Expression class for a binary operator. 169 class BinaryExprAST : public ExprAST { 170 char op; 171 std::unique_ptr<ExprAST> lhs, rhs; 172 173 public: getOp()174 char getOp() { return op; } getLHS()175 ExprAST *getLHS() { return lhs.get(); } getRHS()176 ExprAST *getRHS() { return rhs.get(); } 177 BinaryExprAST(Location loc,char op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)178 BinaryExprAST(Location loc, char op, std::unique_ptr<ExprAST> lhs, 179 std::unique_ptr<ExprAST> rhs) 180 : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), 181 rhs(std::move(rhs)) {} 182 183 /// LLVM style RTTI classof(const ExprAST * c)184 static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } 185 }; 186 187 /// Expression class for function calls. 188 class CallExprAST : public ExprAST { 189 std::string callee; 190 std::vector<std::unique_ptr<ExprAST>> args; 191 192 public: CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)193 CallExprAST(Location loc, const std::string &callee, 194 std::vector<std::unique_ptr<ExprAST>> args) 195 : ExprAST(Expr_Call, std::move(loc)), callee(callee), 196 args(std::move(args)) {} 197 getCallee()198 llvm::StringRef getCallee() { return callee; } getArgs()199 llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } 200 201 /// LLVM style RTTI classof(const ExprAST * c)202 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } 203 }; 204 205 /// Expression class for builtin print calls. 206 class PrintExprAST : public ExprAST { 207 std::unique_ptr<ExprAST> arg; 208 209 public: PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)210 PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) 211 : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} 212 getArg()213 ExprAST *getArg() { return arg.get(); } 214 215 /// LLVM style RTTI classof(const ExprAST * c)216 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } 217 }; 218 219 /// This class represents the "prototype" for a function, which captures its 220 /// name, and its argument names (thus implicitly the number of arguments the 221 /// function takes). 222 class PrototypeAST { 223 Location location; 224 std::string name; 225 std::vector<std::unique_ptr<VarDeclExprAST>> args; 226 227 public: PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> args)228 PrototypeAST(Location location, const std::string &name, 229 std::vector<std::unique_ptr<VarDeclExprAST>> args) 230 : location(std::move(location)), name(name), args(std::move(args)) {} 231 loc()232 const Location &loc() { return location; } getName()233 llvm::StringRef getName() const { return name; } getArgs()234 llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; } 235 }; 236 237 /// This class represents a top level record in a module. 238 class RecordAST { 239 public: 240 enum RecordASTKind { 241 Record_Function, 242 Record_Struct, 243 }; 244 RecordAST(RecordASTKind kind)245 RecordAST(RecordASTKind kind) : kind(kind) {} 246 virtual ~RecordAST() = default; 247 getKind()248 RecordASTKind getKind() const { return kind; } 249 250 private: 251 const RecordASTKind kind; 252 }; 253 254 /// This class represents a function definition itself. 255 class FunctionAST : public RecordAST { 256 std::unique_ptr<PrototypeAST> proto; 257 std::unique_ptr<ExprASTList> body; 258 259 public: FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)260 FunctionAST(std::unique_ptr<PrototypeAST> proto, 261 std::unique_ptr<ExprASTList> body) 262 : RecordAST(Record_Function), proto(std::move(proto)), 263 body(std::move(body)) {} getProto()264 PrototypeAST *getProto() { return proto.get(); } getBody()265 ExprASTList *getBody() { return body.get(); } 266 267 /// LLVM style RTTI classof(const RecordAST * r)268 static bool classof(const RecordAST *r) { 269 return r->getKind() == Record_Function; 270 } 271 }; 272 273 /// This class represents a struct definition. 274 class StructAST : public RecordAST { 275 Location location; 276 std::string name; 277 std::vector<std::unique_ptr<VarDeclExprAST>> variables; 278 279 public: StructAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> variables)280 StructAST(Location location, const std::string &name, 281 std::vector<std::unique_ptr<VarDeclExprAST>> variables) 282 : RecordAST(Record_Struct), location(std::move(location)), name(name), 283 variables(std::move(variables)) {} 284 loc()285 const Location &loc() { return location; } getName()286 llvm::StringRef getName() const { return name; } getVariables()287 llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() { 288 return variables; 289 } 290 291 /// LLVM style RTTI classof(const RecordAST * r)292 static bool classof(const RecordAST *r) { 293 return r->getKind() == Record_Struct; 294 } 295 }; 296 297 /// This class represents a list of functions to be processed together 298 class ModuleAST { 299 std::vector<std::unique_ptr<RecordAST>> records; 300 301 public: ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)302 ModuleAST(std::vector<std::unique_ptr<RecordAST>> records) 303 : records(std::move(records)) {} 304 begin()305 auto begin() { return records.begin(); } end()306 auto end() { return records.end(); } 307 }; 308 309 void dump(ModuleAST &); 310 311 } // namespace toy 312 313 #endif // TOY_AST_H 314