xref: /llvm-project/llvm/examples/Kaleidoscope/Chapter6/toy.cpp (revision 81c0f3023fc38e3ea720045407a17f47653ea2ac)
1 #include "../include/KaleidoscopeJIT.h"
2 #include "llvm/ADT/APFloat.h"
3 #include "llvm/ADT/STLExtras.h"
4 #include "llvm/IR/BasicBlock.h"
5 #include "llvm/IR/Constants.h"
6 #include "llvm/IR/DerivedTypes.h"
7 #include "llvm/IR/Function.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/Instructions.h"
10 #include "llvm/IR/LLVMContext.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/IR/Type.h"
14 #include "llvm/IR/Verifier.h"
15 #include "llvm/Passes/PassBuilder.h"
16 #include "llvm/Passes/StandardInstrumentations.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Target/TargetMachine.h"
19 #include "llvm/Transforms/InstCombine/InstCombine.h"
20 #include "llvm/Transforms/Scalar.h"
21 #include "llvm/Transforms/Scalar/GVN.h"
22 #include "llvm/Transforms/Scalar/Reassociate.h"
23 #include "llvm/Transforms/Scalar/SimplifyCFG.h"
24 #include <algorithm>
25 #include <cassert>
26 #include <cctype>
27 #include <cstdint>
28 #include <cstdio>
29 #include <cstdlib>
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <vector>
34 
35 using namespace llvm;
36 using namespace llvm::orc;
37 
38 //===----------------------------------------------------------------------===//
39 // Lexer
40 //===----------------------------------------------------------------------===//
41 
// Token codes handed out by the lexer. A character the lexer does not
// recognise is returned as its own value in [0-255]; everything it does
// recognise is reported through one of the negative codes below.
enum Token {
  tok_eof = -1, // end of input

  // top-level commands
  tok_def = -2,
  tok_extern = -3,

  // primary expressions
  tok_identifier = -4,
  tok_number = -5,

  // control-flow keywords
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // keywords introducing user-defined operators
  tok_binary = -11,
  tok_unary = -12
};
66 
67 static std::string IdentifierStr; // Filled in if tok_identifier
68 static double NumVal;             // Filled in if tok_number
69 
70 /// gettok - Return the next token from standard input.
71 static int gettok() {
72   static int LastChar = ' ';
73 
74   // Skip any whitespace.
75   while (isspace(LastChar))
76     LastChar = getchar();
77 
78   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
79     IdentifierStr = LastChar;
80     while (isalnum((LastChar = getchar())))
81       IdentifierStr += LastChar;
82 
83     if (IdentifierStr == "def")
84       return tok_def;
85     if (IdentifierStr == "extern")
86       return tok_extern;
87     if (IdentifierStr == "if")
88       return tok_if;
89     if (IdentifierStr == "then")
90       return tok_then;
91     if (IdentifierStr == "else")
92       return tok_else;
93     if (IdentifierStr == "for")
94       return tok_for;
95     if (IdentifierStr == "in")
96       return tok_in;
97     if (IdentifierStr == "binary")
98       return tok_binary;
99     if (IdentifierStr == "unary")
100       return tok_unary;
101     return tok_identifier;
102   }
103 
104   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
105     std::string NumStr;
106     do {
107       NumStr += LastChar;
108       LastChar = getchar();
109     } while (isdigit(LastChar) || LastChar == '.');
110 
111     NumVal = strtod(NumStr.c_str(), nullptr);
112     return tok_number;
113   }
114 
115   if (LastChar == '#') {
116     // Comment until end of line.
117     do
118       LastChar = getchar();
119     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
120 
121     if (LastChar != EOF)
122       return gettok();
123   }
124 
125   // Check for end of file.  Don't eat the EOF.
126   if (LastChar == EOF)
127     return tok_eof;
128 
129   // Otherwise, just return the character as its ascii value.
130   int ThisChar = LastChar;
131   LastChar = getchar();
132   return ThisChar;
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // Abstract Syntax Tree (aka Parse Tree)
137 //===----------------------------------------------------------------------===//
138 
139 namespace {
140 
141 /// ExprAST - Base class for all expression nodes.
142 class ExprAST {
143 public:
144   virtual ~ExprAST() = default;
145 
146   virtual Value *codegen() = 0;
147 };
148 
149 /// NumberExprAST - Expression class for numeric literals like "1.0".
150 class NumberExprAST : public ExprAST {
151   double Val;
152 
153 public:
154   NumberExprAST(double Val) : Val(Val) {}
155 
156   Value *codegen() override;
157 };
158 
159 /// VariableExprAST - Expression class for referencing a variable, like "a".
160 class VariableExprAST : public ExprAST {
161   std::string Name;
162 
163 public:
164   VariableExprAST(const std::string &Name) : Name(Name) {}
165 
166   Value *codegen() override;
167 };
168 
169 /// UnaryExprAST - Expression class for a unary operator.
170 class UnaryExprAST : public ExprAST {
171   char Opcode;
172   std::unique_ptr<ExprAST> Operand;
173 
174 public:
175   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
176       : Opcode(Opcode), Operand(std::move(Operand)) {}
177 
178   Value *codegen() override;
179 };
180 
181 /// BinaryExprAST - Expression class for a binary operator.
182 class BinaryExprAST : public ExprAST {
183   char Op;
184   std::unique_ptr<ExprAST> LHS, RHS;
185 
186 public:
187   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
188                 std::unique_ptr<ExprAST> RHS)
189       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
190 
191   Value *codegen() override;
192 };
193 
194 /// CallExprAST - Expression class for function calls.
195 class CallExprAST : public ExprAST {
196   std::string Callee;
197   std::vector<std::unique_ptr<ExprAST>> Args;
198 
199 public:
200   CallExprAST(const std::string &Callee,
201               std::vector<std::unique_ptr<ExprAST>> Args)
202       : Callee(Callee), Args(std::move(Args)) {}
203 
204   Value *codegen() override;
205 };
206 
207 /// IfExprAST - Expression class for if/then/else.
208 class IfExprAST : public ExprAST {
209   std::unique_ptr<ExprAST> Cond, Then, Else;
210 
211 public:
212   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
213             std::unique_ptr<ExprAST> Else)
214       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
215 
216   Value *codegen() override;
217 };
218 
219 /// ForExprAST - Expression class for for/in.
220 class ForExprAST : public ExprAST {
221   std::string VarName;
222   std::unique_ptr<ExprAST> Start, End, Step, Body;
223 
224 public:
225   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
226              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
227              std::unique_ptr<ExprAST> Body)
228       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
229         Step(std::move(Step)), Body(std::move(Body)) {}
230 
231   Value *codegen() override;
232 };
233 
234 /// PrototypeAST - This class represents the "prototype" for a function,
235 /// which captures its name, and its argument names (thus implicitly the number
236 /// of arguments the function takes), as well as if it is an operator.
237 class PrototypeAST {
238   std::string Name;
239   std::vector<std::string> Args;
240   bool IsOperator;
241   unsigned Precedence; // Precedence if a binary op.
242 
243 public:
244   PrototypeAST(const std::string &Name, std::vector<std::string> Args,
245                bool IsOperator = false, unsigned Prec = 0)
246       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
247         Precedence(Prec) {}
248 
249   Function *codegen();
250   const std::string &getName() const { return Name; }
251 
252   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
253   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
254 
255   char getOperatorName() const {
256     assert(isUnaryOp() || isBinaryOp());
257     return Name[Name.size() - 1];
258   }
259 
260   unsigned getBinaryPrecedence() const { return Precedence; }
261 };
262 
263 /// FunctionAST - This class represents a function definition itself.
264 class FunctionAST {
265   std::unique_ptr<PrototypeAST> Proto;
266   std::unique_ptr<ExprAST> Body;
267 
268 public:
269   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
270               std::unique_ptr<ExprAST> Body)
271       : Proto(std::move(Proto)), Body(std::move(Body)) {}
272 
273   Function *codegen();
274 };
275 
276 } // end anonymous namespace
277 
278 //===----------------------------------------------------------------------===//
279 // Parser
280 //===----------------------------------------------------------------------===//
281 
282 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
283 /// token the parser is looking at.  getNextToken reads another token from the
284 /// lexer and updates CurTok with its results.
285 static int CurTok;
286 static int getNextToken() { return CurTok = gettok(); }
287 
288 /// BinopPrecedence - This holds the precedence for each binary operator that is
289 /// defined.
290 static std::map<char, int> BinopPrecedence;
291 
292 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
293 static int GetTokPrecedence() {
294   if (!isascii(CurTok))
295     return -1;
296 
297   // Make sure it's a declared binop.
298   int TokPrec = BinopPrecedence[CurTok];
299   if (TokPrec <= 0)
300     return -1;
301   return TokPrec;
302 }
303 
304 /// Error* - These are little helper functions for error handling.
305 std::unique_ptr<ExprAST> LogError(const char *Str) {
306   fprintf(stderr, "Error: %s\n", Str);
307   return nullptr;
308 }
309 
310 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
311   LogError(Str);
312   return nullptr;
313 }
314 
315 static std::unique_ptr<ExprAST> ParseExpression();
316 
317 /// numberexpr ::= number
318 static std::unique_ptr<ExprAST> ParseNumberExpr() {
319   auto Result = std::make_unique<NumberExprAST>(NumVal);
320   getNextToken(); // consume the number
321   return std::move(Result);
322 }
323 
324 /// parenexpr ::= '(' expression ')'
325 static std::unique_ptr<ExprAST> ParseParenExpr() {
326   getNextToken(); // eat (.
327   auto V = ParseExpression();
328   if (!V)
329     return nullptr;
330 
331   if (CurTok != ')')
332     return LogError("expected ')'");
333   getNextToken(); // eat ).
334   return V;
335 }
336 
337 /// identifierexpr
338 ///   ::= identifier
339 ///   ::= identifier '(' expression* ')'
340 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
341   std::string IdName = IdentifierStr;
342 
343   getNextToken(); // eat identifier.
344 
345   if (CurTok != '(') // Simple variable ref.
346     return std::make_unique<VariableExprAST>(IdName);
347 
348   // Call.
349   getNextToken(); // eat (
350   std::vector<std::unique_ptr<ExprAST>> Args;
351   if (CurTok != ')') {
352     while (true) {
353       if (auto Arg = ParseExpression())
354         Args.push_back(std::move(Arg));
355       else
356         return nullptr;
357 
358       if (CurTok == ')')
359         break;
360 
361       if (CurTok != ',')
362         return LogError("Expected ')' or ',' in argument list");
363       getNextToken();
364     }
365   }
366 
367   // Eat the ')'.
368   getNextToken();
369 
370   return std::make_unique<CallExprAST>(IdName, std::move(Args));
371 }
372 
373 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
374 static std::unique_ptr<ExprAST> ParseIfExpr() {
375   getNextToken(); // eat the if.
376 
377   // condition.
378   auto Cond = ParseExpression();
379   if (!Cond)
380     return nullptr;
381 
382   if (CurTok != tok_then)
383     return LogError("expected then");
384   getNextToken(); // eat the then
385 
386   auto Then = ParseExpression();
387   if (!Then)
388     return nullptr;
389 
390   if (CurTok != tok_else)
391     return LogError("expected else");
392 
393   getNextToken();
394 
395   auto Else = ParseExpression();
396   if (!Else)
397     return nullptr;
398 
399   return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
400                                       std::move(Else));
401 }
402 
403 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
404 static std::unique_ptr<ExprAST> ParseForExpr() {
405   getNextToken(); // eat the for.
406 
407   if (CurTok != tok_identifier)
408     return LogError("expected identifier after for");
409 
410   std::string IdName = IdentifierStr;
411   getNextToken(); // eat identifier.
412 
413   if (CurTok != '=')
414     return LogError("expected '=' after for");
415   getNextToken(); // eat '='.
416 
417   auto Start = ParseExpression();
418   if (!Start)
419     return nullptr;
420   if (CurTok != ',')
421     return LogError("expected ',' after for start value");
422   getNextToken();
423 
424   auto End = ParseExpression();
425   if (!End)
426     return nullptr;
427 
428   // The step value is optional.
429   std::unique_ptr<ExprAST> Step;
430   if (CurTok == ',') {
431     getNextToken();
432     Step = ParseExpression();
433     if (!Step)
434       return nullptr;
435   }
436 
437   if (CurTok != tok_in)
438     return LogError("expected 'in' after for");
439   getNextToken(); // eat 'in'.
440 
441   auto Body = ParseExpression();
442   if (!Body)
443     return nullptr;
444 
445   return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
446                                        std::move(Step), std::move(Body));
447 }
448 
449 /// primary
450 ///   ::= identifierexpr
451 ///   ::= numberexpr
452 ///   ::= parenexpr
453 ///   ::= ifexpr
454 ///   ::= forexpr
455 static std::unique_ptr<ExprAST> ParsePrimary() {
456   switch (CurTok) {
457   default:
458     return LogError("unknown token when expecting an expression");
459   case tok_identifier:
460     return ParseIdentifierExpr();
461   case tok_number:
462     return ParseNumberExpr();
463   case '(':
464     return ParseParenExpr();
465   case tok_if:
466     return ParseIfExpr();
467   case tok_for:
468     return ParseForExpr();
469   }
470 }
471 
472 /// unary
473 ///   ::= primary
474 ///   ::= '!' unary
475 static std::unique_ptr<ExprAST> ParseUnary() {
476   // If the current token is not an operator, it must be a primary expr.
477   if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
478     return ParsePrimary();
479 
480   // If this is a unary operator, read it.
481   int Opc = CurTok;
482   getNextToken();
483   if (auto Operand = ParseUnary())
484     return std::make_unique<UnaryExprAST>(Opc, std::move(Operand));
485   return nullptr;
486 }
487 
488 /// binoprhs
489 ///   ::= ('+' unary)*
490 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
491                                               std::unique_ptr<ExprAST> LHS) {
492   // If this is a binop, find its precedence.
493   while (true) {
494     int TokPrec = GetTokPrecedence();
495 
496     // If this is a binop that binds at least as tightly as the current binop,
497     // consume it, otherwise we are done.
498     if (TokPrec < ExprPrec)
499       return LHS;
500 
501     // Okay, we know this is a binop.
502     int BinOp = CurTok;
503     getNextToken(); // eat binop
504 
505     // Parse the unary expression after the binary operator.
506     auto RHS = ParseUnary();
507     if (!RHS)
508       return nullptr;
509 
510     // If BinOp binds less tightly with RHS than the operator after RHS, let
511     // the pending operator take RHS as its LHS.
512     int NextPrec = GetTokPrecedence();
513     if (TokPrec < NextPrec) {
514       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
515       if (!RHS)
516         return nullptr;
517     }
518 
519     // Merge LHS/RHS.
520     LHS =
521         std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
522   }
523 }
524 
525 /// expression
526 ///   ::= unary binoprhs
527 ///
528 static std::unique_ptr<ExprAST> ParseExpression() {
529   auto LHS = ParseUnary();
530   if (!LHS)
531     return nullptr;
532 
533   return ParseBinOpRHS(0, std::move(LHS));
534 }
535 
536 /// prototype
537 ///   ::= id '(' id* ')'
538 ///   ::= binary LETTER number? (id, id)
539 ///   ::= unary LETTER (id)
540 static std::unique_ptr<PrototypeAST> ParsePrototype() {
541   std::string FnName;
542 
543   unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
544   unsigned BinaryPrecedence = 30;
545 
546   switch (CurTok) {
547   default:
548     return LogErrorP("Expected function name in prototype");
549   case tok_identifier:
550     FnName = IdentifierStr;
551     Kind = 0;
552     getNextToken();
553     break;
554   case tok_unary:
555     getNextToken();
556     if (!isascii(CurTok))
557       return LogErrorP("Expected unary operator");
558     FnName = "unary";
559     FnName += (char)CurTok;
560     Kind = 1;
561     getNextToken();
562     break;
563   case tok_binary:
564     getNextToken();
565     if (!isascii(CurTok))
566       return LogErrorP("Expected binary operator");
567     FnName = "binary";
568     FnName += (char)CurTok;
569     Kind = 2;
570     getNextToken();
571 
572     // Read the precedence if present.
573     if (CurTok == tok_number) {
574       if (NumVal < 1 || NumVal > 100)
575         return LogErrorP("Invalid precedence: must be 1..100");
576       BinaryPrecedence = (unsigned)NumVal;
577       getNextToken();
578     }
579     break;
580   }
581 
582   if (CurTok != '(')
583     return LogErrorP("Expected '(' in prototype");
584 
585   std::vector<std::string> ArgNames;
586   while (getNextToken() == tok_identifier)
587     ArgNames.push_back(IdentifierStr);
588   if (CurTok != ')')
589     return LogErrorP("Expected ')' in prototype");
590 
591   // success.
592   getNextToken(); // eat ')'.
593 
594   // Verify right number of names for operator.
595   if (Kind && ArgNames.size() != Kind)
596     return LogErrorP("Invalid number of operands for operator");
597 
598   return std::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
599                                          BinaryPrecedence);
600 }
601 
602 /// definition ::= 'def' prototype expression
603 static std::unique_ptr<FunctionAST> ParseDefinition() {
604   getNextToken(); // eat def.
605   auto Proto = ParsePrototype();
606   if (!Proto)
607     return nullptr;
608 
609   if (auto E = ParseExpression())
610     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
611   return nullptr;
612 }
613 
614 /// toplevelexpr ::= expression
615 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
616   if (auto E = ParseExpression()) {
617     // Make an anonymous proto.
618     auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
619                                                  std::vector<std::string>());
620     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
621   }
622   return nullptr;
623 }
624 
625 /// external ::= 'extern' prototype
626 static std::unique_ptr<PrototypeAST> ParseExtern() {
627   getNextToken(); // eat extern.
628   return ParsePrototype();
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Code Generation
633 //===----------------------------------------------------------------------===//
634 
// Global code generation state. The context, module, builder and the pass and
// analysis managers are all (re)created by InitializeModuleAndManagers().

// Context and module that code is currently being generated into.
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
// Builder used by the codegen() methods to emit instructions.
static std::unique_ptr<IRBuilder<>> Builder;
// Values currently in scope, by name. FunctionAST::codegen() clears this and
// fills it with the function's arguments; ForExprAST::codegen() temporarily
// adds the loop variable.
static std::map<std::string, Value *> NamedValues;
// The JIT that finished modules are added to.
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Function pass pipeline run on each function after it has been generated,
// together with the analysis managers and instrumentation that support it.
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
// Latest prototype seen for each function name; getFunction() uses it to
// emit a declaration into the current module when one is missing.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
// Used to unwrap results of JIT operations (see its uses on TheJIT calls).
static ExitOnError ExitOnErr;
649 
650 Value *LogErrorV(const char *Str) {
651   LogError(Str);
652   return nullptr;
653 }
654 
655 Function *getFunction(std::string Name) {
656   // First, see if the function has already been added to the current module.
657   if (auto *F = TheModule->getFunction(Name))
658     return F;
659 
660   // If not, check whether we can codegen the declaration from some existing
661   // prototype.
662   auto FI = FunctionProtos.find(Name);
663   if (FI != FunctionProtos.end())
664     return FI->second->codegen();
665 
666   // If no existing prototype exists, return null.
667   return nullptr;
668 }
669 
670 Value *NumberExprAST::codegen() {
671   return ConstantFP::get(*TheContext, APFloat(Val));
672 }
673 
674 Value *VariableExprAST::codegen() {
675   // Look this variable up in the function.
676   Value *V = NamedValues[Name];
677   if (!V)
678     return LogErrorV("Unknown variable name");
679   return V;
680 }
681 
682 Value *UnaryExprAST::codegen() {
683   Value *OperandV = Operand->codegen();
684   if (!OperandV)
685     return nullptr;
686 
687   Function *F = getFunction(std::string("unary") + Opcode);
688   if (!F)
689     return LogErrorV("Unknown unary operator");
690 
691   return Builder->CreateCall(F, OperandV, "unop");
692 }
693 
694 Value *BinaryExprAST::codegen() {
695   Value *L = LHS->codegen();
696   Value *R = RHS->codegen();
697   if (!L || !R)
698     return nullptr;
699 
700   switch (Op) {
701   case '+':
702     return Builder->CreateFAdd(L, R, "addtmp");
703   case '-':
704     return Builder->CreateFSub(L, R, "subtmp");
705   case '*':
706     return Builder->CreateFMul(L, R, "multmp");
707   case '<':
708     L = Builder->CreateFCmpULT(L, R, "cmptmp");
709     // Convert bool 0/1 to double 0.0 or 1.0
710     return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
711   default:
712     break;
713   }
714 
715   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
716   // a call to it.
717   Function *F = getFunction(std::string("binary") + Op);
718   assert(F && "binary operator not found!");
719 
720   Value *Ops[] = {L, R};
721   return Builder->CreateCall(F, Ops, "binop");
722 }
723 
724 Value *CallExprAST::codegen() {
725   // Look up the name in the global module table.
726   Function *CalleeF = getFunction(Callee);
727   if (!CalleeF)
728     return LogErrorV("Unknown function referenced");
729 
730   // If argument mismatch error.
731   if (CalleeF->arg_size() != Args.size())
732     return LogErrorV("Incorrect # arguments passed");
733 
734   std::vector<Value *> ArgsV;
735   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
736     ArgsV.push_back(Args[i]->codegen());
737     if (!ArgsV.back())
738       return nullptr;
739   }
740 
741   return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
742 }
743 
744 Value *IfExprAST::codegen() {
745   Value *CondV = Cond->codegen();
746   if (!CondV)
747     return nullptr;
748 
749   // Convert condition to a bool by comparing non-equal to 0.0.
750   CondV = Builder->CreateFCmpONE(
751       CondV, ConstantFP::get(*TheContext, APFloat(0.0)), "ifcond");
752 
753   Function *TheFunction = Builder->GetInsertBlock()->getParent();
754 
755   // Create blocks for the then and else cases.  Insert the 'then' block at the
756   // end of the function.
757   BasicBlock *ThenBB = BasicBlock::Create(*TheContext, "then", TheFunction);
758   BasicBlock *ElseBB = BasicBlock::Create(*TheContext, "else");
759   BasicBlock *MergeBB = BasicBlock::Create(*TheContext, "ifcont");
760 
761   Builder->CreateCondBr(CondV, ThenBB, ElseBB);
762 
763   // Emit then value.
764   Builder->SetInsertPoint(ThenBB);
765 
766   Value *ThenV = Then->codegen();
767   if (!ThenV)
768     return nullptr;
769 
770   Builder->CreateBr(MergeBB);
771   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
772   ThenBB = Builder->GetInsertBlock();
773 
774   // Emit else block.
775   TheFunction->insert(TheFunction->end(), ElseBB);
776   Builder->SetInsertPoint(ElseBB);
777 
778   Value *ElseV = Else->codegen();
779   if (!ElseV)
780     return nullptr;
781 
782   Builder->CreateBr(MergeBB);
783   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
784   ElseBB = Builder->GetInsertBlock();
785 
786   // Emit merge block.
787   TheFunction->insert(TheFunction->end(), MergeBB);
788   Builder->SetInsertPoint(MergeBB);
789   PHINode *PN = Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, "iftmp");
790 
791   PN->addIncoming(ThenV, ThenBB);
792   PN->addIncoming(ElseV, ElseBB);
793   return PN;
794 }
795 
796 // Output for-loop as:
797 //   ...
798 //   start = startexpr
799 //   goto loop
800 // loop:
801 //   variable = phi [start, loopheader], [nextvariable, loopend]
802 //   ...
803 //   bodyexpr
804 //   ...
805 // loopend:
806 //   step = stepexpr
807 //   nextvariable = variable + step
808 //   endcond = endexpr
809 //   br endcond, loop, endloop
810 // outloop:
811 Value *ForExprAST::codegen() {
812   // Emit the start code first, without 'variable' in scope.
813   Value *StartVal = Start->codegen();
814   if (!StartVal)
815     return nullptr;
816 
817   // Make the new basic block for the loop header, inserting after current
818   // block.
819   Function *TheFunction = Builder->GetInsertBlock()->getParent();
820   BasicBlock *PreheaderBB = Builder->GetInsertBlock();
821   BasicBlock *LoopBB = BasicBlock::Create(*TheContext, "loop", TheFunction);
822 
823   // Insert an explicit fall through from the current block to the LoopBB.
824   Builder->CreateBr(LoopBB);
825 
826   // Start insertion in LoopBB.
827   Builder->SetInsertPoint(LoopBB);
828 
829   // Start the PHI node with an entry for Start.
830   PHINode *Variable =
831       Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, VarName);
832   Variable->addIncoming(StartVal, PreheaderBB);
833 
834   // Within the loop, the variable is defined equal to the PHI node.  If it
835   // shadows an existing variable, we have to restore it, so save it now.
836   Value *OldVal = NamedValues[VarName];
837   NamedValues[VarName] = Variable;
838 
839   // Emit the body of the loop.  This, like any other expr, can change the
840   // current BB.  Note that we ignore the value computed by the body, but don't
841   // allow an error.
842   if (!Body->codegen())
843     return nullptr;
844 
845   // Emit the step value.
846   Value *StepVal = nullptr;
847   if (Step) {
848     StepVal = Step->codegen();
849     if (!StepVal)
850       return nullptr;
851   } else {
852     // If not specified, use 1.0.
853     StepVal = ConstantFP::get(*TheContext, APFloat(1.0));
854   }
855 
856   Value *NextVar = Builder->CreateFAdd(Variable, StepVal, "nextvar");
857 
858   // Compute the end condition.
859   Value *EndCond = End->codegen();
860   if (!EndCond)
861     return nullptr;
862 
863   // Convert condition to a bool by comparing non-equal to 0.0.
864   EndCond = Builder->CreateFCmpONE(
865       EndCond, ConstantFP::get(*TheContext, APFloat(0.0)), "loopcond");
866 
867   // Create the "after loop" block and insert it.
868   BasicBlock *LoopEndBB = Builder->GetInsertBlock();
869   BasicBlock *AfterBB =
870       BasicBlock::Create(*TheContext, "afterloop", TheFunction);
871 
872   // Insert the conditional branch into the end of LoopEndBB.
873   Builder->CreateCondBr(EndCond, LoopBB, AfterBB);
874 
875   // Any new code will be inserted in AfterBB.
876   Builder->SetInsertPoint(AfterBB);
877 
878   // Add a new entry to the PHI node for the backedge.
879   Variable->addIncoming(NextVar, LoopEndBB);
880 
881   // Restore the unshadowed variable.
882   if (OldVal)
883     NamedValues[VarName] = OldVal;
884   else
885     NamedValues.erase(VarName);
886 
887   // for expr always returns 0.0.
888   return Constant::getNullValue(Type::getDoubleTy(*TheContext));
889 }
890 
891 Function *PrototypeAST::codegen() {
892   // Make the function type:  double(double,double) etc.
893   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
894   FunctionType *FT =
895       FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
896 
897   Function *F =
898       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
899 
900   // Set names for all arguments.
901   unsigned Idx = 0;
902   for (auto &Arg : F->args())
903     Arg.setName(Args[Idx++]);
904 
905   return F;
906 }
907 
908 Function *FunctionAST::codegen() {
909   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
910   // reference to it for use below.
911   auto &P = *Proto;
912   FunctionProtos[Proto->getName()] = std::move(Proto);
913   Function *TheFunction = getFunction(P.getName());
914   if (!TheFunction)
915     return nullptr;
916 
917   // If this is an operator, install it.
918   if (P.isBinaryOp())
919     BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
920 
921   // Create a new basic block to start insertion into.
922   BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
923   Builder->SetInsertPoint(BB);
924 
925   // Record the function arguments in the NamedValues map.
926   NamedValues.clear();
927   for (auto &Arg : TheFunction->args())
928     NamedValues[std::string(Arg.getName())] = &Arg;
929 
930   if (Value *RetVal = Body->codegen()) {
931     // Finish off the function.
932     Builder->CreateRet(RetVal);
933 
934     // Validate the generated code, checking for consistency.
935     verifyFunction(*TheFunction);
936 
937     // Run the optimizer on the function.
938     TheFPM->run(*TheFunction, *TheFAM);
939 
940     return TheFunction;
941   }
942 
943   // Error reading body, remove function.
944   TheFunction->eraseFromParent();
945 
946   if (P.isBinaryOp())
947     BinopPrecedence.erase(P.getOperatorName());
948   return nullptr;
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // Top-Level parsing and JIT Driver
953 //===----------------------------------------------------------------------===//
954 
/// InitializeModuleAndManagers - Replace the global context, module, builder
/// and pass/analysis managers with fresh instances.
///
/// TheJIT is dereferenced for the data layout, so it must already be set.
/// HandleDefinition() and HandleTopLevelExpression() call this right after
/// moving TheModule and TheContext into the JIT.
static void InitializeModuleAndManagers() {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
  // The module uses the data layout reported by the JIT.
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create new pass and analysis managers.
  TheFPM = std::make_unique<FunctionPassManager>();
  TheLAM = std::make_unique<LoopAnalysisManager>();
  TheFAM = std::make_unique<FunctionAnalysisManager>();
  TheCGAM = std::make_unique<CGSCCAnalysisManager>();
  TheMAM = std::make_unique<ModuleAnalysisManager>();
  ThePIC = std::make_unique<PassInstrumentationCallbacks>();
  // Standard instrumentations, constructed with DebugLogging set to true.
  TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
                                                     /*DebugLogging*/ true);
  TheSI->registerCallbacks(*ThePIC, TheMAM.get());

  // Add transform passes.
  // Do simple "peephole" optimizations and bit-twiddling optzns.
  TheFPM->addPass(InstCombinePass());
  // Reassociate expressions.
  TheFPM->addPass(ReassociatePass());
  // Eliminate Common SubExpressions.
  TheFPM->addPass(GVNPass());
  // Simplify the control flow graph (deleting unreachable blocks, etc).
  TheFPM->addPass(SimplifyCFGPass());

  // Register analysis passes used in these transform passes.
  PassBuilder PB;
  PB.registerModuleAnalyses(*TheMAM);
  PB.registerFunctionAnalyses(*TheFAM);
  PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
991 
992 static void HandleDefinition() {
993   if (auto FnAST = ParseDefinition()) {
994     if (auto *FnIR = FnAST->codegen()) {
995       fprintf(stderr, "Read function definition:");
996       FnIR->print(errs());
997       fprintf(stderr, "\n");
998       ExitOnErr(TheJIT->addModule(
999           ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
1000       InitializeModuleAndManagers();
1001     }
1002   } else {
1003     // Skip token for error recovery.
1004     getNextToken();
1005   }
1006 }
1007 
1008 static void HandleExtern() {
1009   if (auto ProtoAST = ParseExtern()) {
1010     if (auto *FnIR = ProtoAST->codegen()) {
1011       fprintf(stderr, "Read extern: ");
1012       FnIR->print(errs());
1013       fprintf(stderr, "\n");
1014       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1015     }
1016   } else {
1017     // Skip token for error recovery.
1018     getNextToken();
1019   }
1020 }
1021 
1022 static void HandleTopLevelExpression() {
1023   // Evaluate a top-level expression into an anonymous function.
1024   if (auto FnAST = ParseTopLevelExpr()) {
1025     if (FnAST->codegen()) {
1026       // Create a ResourceTracker to track JIT'd memory allocated to our
1027       // anonymous expression -- that way we can free it after executing.
1028       auto RT = TheJIT->getMainJITDylib().createResourceTracker();
1029 
1030       auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1031       ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
1032       InitializeModuleAndManagers();
1033 
1034       // Search the JIT for the __anon_expr symbol.
1035       auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
1036 
1037       // Get the symbol's address and cast it to the right type (takes no
1038       // arguments, returns a double) so we can call it as a native function.
1039       double (*FP)() = ExprSymbol.toPtr<double (*)()>();
1040       fprintf(stderr, "Evaluated to %f\n", FP());
1041 
1042       // Delete the anonymous expression module from the JIT.
1043       ExitOnErr(RT->remove());
1044     }
1045   } else {
1046     // Skip token for error recovery.
1047     getNextToken();
1048   }
1049 }
1050 
1051 /// top ::= definition | external | expression | ';'
1052 static void MainLoop() {
1053   while (true) {
1054     fprintf(stderr, "ready> ");
1055     switch (CurTok) {
1056     case tok_eof:
1057       return;
1058     case ';': // ignore top-level semicolons.
1059       getNextToken();
1060       break;
1061     case tok_def:
1062       HandleDefinition();
1063       break;
1064     case tok_extern:
1065       HandleExtern();
1066       break;
1067     default:
1068       HandleTopLevelExpression();
1069       break;
1070     }
1071   }
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // "Library" functions that can be "extern'd" from user code.
1076 //===----------------------------------------------------------------------===//
1077 
// DLLEXPORT expands to __declspec(dllexport) when _WIN32 is defined and to
// nothing otherwise.
#if defined(_WIN32)
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
1083 
1084 /// putchard - putchar that takes a double and returns 0.
1085 extern "C" DLLEXPORT double putchard(double X) {
1086   fputc((char)X, stderr);
1087   return 0;
1088 }
1089 
1090 /// printd - printf that takes a double prints it as "%f\n", returning 0.
1091 extern "C" DLLEXPORT double printd(double X) {
1092   fprintf(stderr, "%f\n", X);
1093   return 0;
1094 }
1095 
1096 //===----------------------------------------------------------------------===//
1097 // Main driver code.
1098 //===----------------------------------------------------------------------===//
1099 
1100 int main() {
1101   InitializeNativeTarget();
1102   InitializeNativeTargetAsmPrinter();
1103   InitializeNativeTargetAsmParser();
1104 
1105   // Install standard binary operators.
1106   // 1 is lowest precedence.
1107   BinopPrecedence['<'] = 10;
1108   BinopPrecedence['+'] = 20;
1109   BinopPrecedence['-'] = 20;
1110   BinopPrecedence['*'] = 40; // highest.
1111 
1112   // Prime the first token.
1113   fprintf(stderr, "ready> ");
1114   getNextToken();
1115 
1116   TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
1117 
1118   InitializeModuleAndManagers();
1119 
1120   // Run the main "interpreter loop" now.
1121   MainLoop();
1122 
1123   return 0;
1124 }
1125