xref: /llvm-project/llvm/examples/Kaleidoscope/Chapter5/toy.cpp (revision 81c0f3023fc38e3ea720045407a17f47653ea2ac)
1 #include "../include/KaleidoscopeJIT.h"
2 #include "llvm/ADT/APFloat.h"
3 #include "llvm/ADT/STLExtras.h"
4 #include "llvm/IR/BasicBlock.h"
5 #include "llvm/IR/Constants.h"
6 #include "llvm/IR/DerivedTypes.h"
7 #include "llvm/IR/Function.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/Instructions.h"
10 #include "llvm/IR/LLVMContext.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/IR/Type.h"
14 #include "llvm/IR/Verifier.h"
15 #include "llvm/Passes/PassBuilder.h"
16 #include "llvm/Passes/StandardInstrumentations.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Target/TargetMachine.h"
19 #include "llvm/Transforms/InstCombine/InstCombine.h"
20 #include "llvm/Transforms/Scalar.h"
21 #include "llvm/Transforms/Scalar/GVN.h"
22 #include "llvm/Transforms/Scalar/Reassociate.h"
23 #include "llvm/Transforms/Scalar/SimplifyCFG.h"
24 #include <algorithm>
25 #include <cassert>
26 #include <cctype>
27 #include <cstdint>
28 #include <cstdio>
29 #include <cstdlib>
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <vector>
34 
35 using namespace llvm;
36 using namespace llvm::orc;
37 
38 //===----------------------------------------------------------------------===//
39 // Lexer
40 //===----------------------------------------------------------------------===//
41 
// Token codes produced by the lexer.  Recognized constructs use the negative
// values below; any other character is returned as its own value in [0-255].
enum Token {
  // End of input.
  tok_eof = -1,

  // Top-level command keywords.
  tok_def = -2,
  tok_extern = -3,

  // Primary-expression tokens.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10
};
62 
// Payload of the most recent token: the spelling when the lexer reports
// tok_identifier, the numeric value when it reports tok_number.
static std::string IdentifierStr;
static double NumVal = 0.0;
65 
66 /// gettok - Return the next token from standard input.
67 static int gettok() {
68   static int LastChar = ' ';
69 
70   // Skip any whitespace.
71   while (isspace(LastChar))
72     LastChar = getchar();
73 
74   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
75     IdentifierStr = LastChar;
76     while (isalnum((LastChar = getchar())))
77       IdentifierStr += LastChar;
78 
79     if (IdentifierStr == "def")
80       return tok_def;
81     if (IdentifierStr == "extern")
82       return tok_extern;
83     if (IdentifierStr == "if")
84       return tok_if;
85     if (IdentifierStr == "then")
86       return tok_then;
87     if (IdentifierStr == "else")
88       return tok_else;
89     if (IdentifierStr == "for")
90       return tok_for;
91     if (IdentifierStr == "in")
92       return tok_in;
93     return tok_identifier;
94   }
95 
96   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
97     std::string NumStr;
98     do {
99       NumStr += LastChar;
100       LastChar = getchar();
101     } while (isdigit(LastChar) || LastChar == '.');
102 
103     NumVal = strtod(NumStr.c_str(), nullptr);
104     return tok_number;
105   }
106 
107   if (LastChar == '#') {
108     // Comment until end of line.
109     do
110       LastChar = getchar();
111     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
112 
113     if (LastChar != EOF)
114       return gettok();
115   }
116 
117   // Check for end of file.  Don't eat the EOF.
118   if (LastChar == EOF)
119     return tok_eof;
120 
121   // Otherwise, just return the character as its ascii value.
122   int ThisChar = LastChar;
123   LastChar = getchar();
124   return ThisChar;
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // Abstract Syntax Tree (aka Parse Tree)
129 //===----------------------------------------------------------------------===//
130 
131 namespace {
132 
133 /// ExprAST - Base class for all expression nodes.
134 class ExprAST {
135 public:
136   virtual ~ExprAST() = default;
137 
138   virtual Value *codegen() = 0;
139 };
140 
141 /// NumberExprAST - Expression class for numeric literals like "1.0".
142 class NumberExprAST : public ExprAST {
143   double Val;
144 
145 public:
146   NumberExprAST(double Val) : Val(Val) {}
147 
148   Value *codegen() override;
149 };
150 
151 /// VariableExprAST - Expression class for referencing a variable, like "a".
152 class VariableExprAST : public ExprAST {
153   std::string Name;
154 
155 public:
156   VariableExprAST(const std::string &Name) : Name(Name) {}
157 
158   Value *codegen() override;
159 };
160 
161 /// BinaryExprAST - Expression class for a binary operator.
162 class BinaryExprAST : public ExprAST {
163   char Op;
164   std::unique_ptr<ExprAST> LHS, RHS;
165 
166 public:
167   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
168                 std::unique_ptr<ExprAST> RHS)
169       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
170 
171   Value *codegen() override;
172 };
173 
174 /// CallExprAST - Expression class for function calls.
175 class CallExprAST : public ExprAST {
176   std::string Callee;
177   std::vector<std::unique_ptr<ExprAST>> Args;
178 
179 public:
180   CallExprAST(const std::string &Callee,
181               std::vector<std::unique_ptr<ExprAST>> Args)
182       : Callee(Callee), Args(std::move(Args)) {}
183 
184   Value *codegen() override;
185 };
186 
187 /// IfExprAST - Expression class for if/then/else.
188 class IfExprAST : public ExprAST {
189   std::unique_ptr<ExprAST> Cond, Then, Else;
190 
191 public:
192   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
193             std::unique_ptr<ExprAST> Else)
194       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
195 
196   Value *codegen() override;
197 };
198 
199 /// ForExprAST - Expression class for for/in.
200 class ForExprAST : public ExprAST {
201   std::string VarName;
202   std::unique_ptr<ExprAST> Start, End, Step, Body;
203 
204 public:
205   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
206              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
207              std::unique_ptr<ExprAST> Body)
208       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
209         Step(std::move(Step)), Body(std::move(Body)) {}
210 
211   Value *codegen() override;
212 };
213 
214 /// PrototypeAST - This class represents the "prototype" for a function,
215 /// which captures its name, and its argument names (thus implicitly the number
216 /// of arguments the function takes).
217 class PrototypeAST {
218   std::string Name;
219   std::vector<std::string> Args;
220 
221 public:
222   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
223       : Name(Name), Args(std::move(Args)) {}
224 
225   Function *codegen();
226   const std::string &getName() const { return Name; }
227 };
228 
229 /// FunctionAST - This class represents a function definition itself.
230 class FunctionAST {
231   std::unique_ptr<PrototypeAST> Proto;
232   std::unique_ptr<ExprAST> Body;
233 
234 public:
235   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
236               std::unique_ptr<ExprAST> Body)
237       : Proto(std::move(Proto)), Body(std::move(Body)) {}
238 
239   Function *codegen();
240 };
241 
242 } // end anonymous namespace
243 
244 //===----------------------------------------------------------------------===//
245 // Parser
246 //===----------------------------------------------------------------------===//
247 
248 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
249 /// token the parser is looking at.  getNextToken reads another token from the
250 /// lexer and updates CurTok with its results.
251 static int CurTok;
252 static int getNextToken() { return CurTok = gettok(); }
253 
254 /// BinopPrecedence - This holds the precedence for each binary operator that is
255 /// defined.
256 static std::map<char, int> BinopPrecedence;
257 
258 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
259 static int GetTokPrecedence() {
260   if (!isascii(CurTok))
261     return -1;
262 
263   // Make sure it's a declared binop.
264   int TokPrec = BinopPrecedence[CurTok];
265   if (TokPrec <= 0)
266     return -1;
267   return TokPrec;
268 }
269 
270 /// LogError* - These are little helper functions for error handling.
271 std::unique_ptr<ExprAST> LogError(const char *Str) {
272   fprintf(stderr, "Error: %s\n", Str);
273   return nullptr;
274 }
275 
276 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
277   LogError(Str);
278   return nullptr;
279 }
280 
281 static std::unique_ptr<ExprAST> ParseExpression();
282 
283 /// numberexpr ::= number
284 static std::unique_ptr<ExprAST> ParseNumberExpr() {
285   auto Result = std::make_unique<NumberExprAST>(NumVal);
286   getNextToken(); // consume the number
287   return std::move(Result);
288 }
289 
290 /// parenexpr ::= '(' expression ')'
291 static std::unique_ptr<ExprAST> ParseParenExpr() {
292   getNextToken(); // eat (.
293   auto V = ParseExpression();
294   if (!V)
295     return nullptr;
296 
297   if (CurTok != ')')
298     return LogError("expected ')'");
299   getNextToken(); // eat ).
300   return V;
301 }
302 
303 /// identifierexpr
304 ///   ::= identifier
305 ///   ::= identifier '(' expression* ')'
306 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
307   std::string IdName = IdentifierStr;
308 
309   getNextToken(); // eat identifier.
310 
311   if (CurTok != '(') // Simple variable ref.
312     return std::make_unique<VariableExprAST>(IdName);
313 
314   // Call.
315   getNextToken(); // eat (
316   std::vector<std::unique_ptr<ExprAST>> Args;
317   if (CurTok != ')') {
318     while (true) {
319       if (auto Arg = ParseExpression())
320         Args.push_back(std::move(Arg));
321       else
322         return nullptr;
323 
324       if (CurTok == ')')
325         break;
326 
327       if (CurTok != ',')
328         return LogError("Expected ')' or ',' in argument list");
329       getNextToken();
330     }
331   }
332 
333   // Eat the ')'.
334   getNextToken();
335 
336   return std::make_unique<CallExprAST>(IdName, std::move(Args));
337 }
338 
339 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
340 static std::unique_ptr<ExprAST> ParseIfExpr() {
341   getNextToken(); // eat the if.
342 
343   // condition.
344   auto Cond = ParseExpression();
345   if (!Cond)
346     return nullptr;
347 
348   if (CurTok != tok_then)
349     return LogError("expected then");
350   getNextToken(); // eat the then
351 
352   auto Then = ParseExpression();
353   if (!Then)
354     return nullptr;
355 
356   if (CurTok != tok_else)
357     return LogError("expected else");
358 
359   getNextToken();
360 
361   auto Else = ParseExpression();
362   if (!Else)
363     return nullptr;
364 
365   return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
366                                       std::move(Else));
367 }
368 
369 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
370 static std::unique_ptr<ExprAST> ParseForExpr() {
371   getNextToken(); // eat the for.
372 
373   if (CurTok != tok_identifier)
374     return LogError("expected identifier after for");
375 
376   std::string IdName = IdentifierStr;
377   getNextToken(); // eat identifier.
378 
379   if (CurTok != '=')
380     return LogError("expected '=' after for");
381   getNextToken(); // eat '='.
382 
383   auto Start = ParseExpression();
384   if (!Start)
385     return nullptr;
386   if (CurTok != ',')
387     return LogError("expected ',' after for start value");
388   getNextToken();
389 
390   auto End = ParseExpression();
391   if (!End)
392     return nullptr;
393 
394   // The step value is optional.
395   std::unique_ptr<ExprAST> Step;
396   if (CurTok == ',') {
397     getNextToken();
398     Step = ParseExpression();
399     if (!Step)
400       return nullptr;
401   }
402 
403   if (CurTok != tok_in)
404     return LogError("expected 'in' after for");
405   getNextToken(); // eat 'in'.
406 
407   auto Body = ParseExpression();
408   if (!Body)
409     return nullptr;
410 
411   return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
412                                        std::move(Step), std::move(Body));
413 }
414 
415 /// primary
416 ///   ::= identifierexpr
417 ///   ::= numberexpr
418 ///   ::= parenexpr
419 ///   ::= ifexpr
420 ///   ::= forexpr
421 static std::unique_ptr<ExprAST> ParsePrimary() {
422   switch (CurTok) {
423   default:
424     return LogError("unknown token when expecting an expression");
425   case tok_identifier:
426     return ParseIdentifierExpr();
427   case tok_number:
428     return ParseNumberExpr();
429   case '(':
430     return ParseParenExpr();
431   case tok_if:
432     return ParseIfExpr();
433   case tok_for:
434     return ParseForExpr();
435   }
436 }
437 
438 /// binoprhs
439 ///   ::= ('+' primary)*
440 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
441                                               std::unique_ptr<ExprAST> LHS) {
442   // If this is a binop, find its precedence.
443   while (true) {
444     int TokPrec = GetTokPrecedence();
445 
446     // If this is a binop that binds at least as tightly as the current binop,
447     // consume it, otherwise we are done.
448     if (TokPrec < ExprPrec)
449       return LHS;
450 
451     // Okay, we know this is a binop.
452     int BinOp = CurTok;
453     getNextToken(); // eat binop
454 
455     // Parse the primary expression after the binary operator.
456     auto RHS = ParsePrimary();
457     if (!RHS)
458       return nullptr;
459 
460     // If BinOp binds less tightly with RHS than the operator after RHS, let
461     // the pending operator take RHS as its LHS.
462     int NextPrec = GetTokPrecedence();
463     if (TokPrec < NextPrec) {
464       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
465       if (!RHS)
466         return nullptr;
467     }
468 
469     // Merge LHS/RHS.
470     LHS =
471         std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
472   }
473 }
474 
475 /// expression
476 ///   ::= primary binoprhs
477 ///
478 static std::unique_ptr<ExprAST> ParseExpression() {
479   auto LHS = ParsePrimary();
480   if (!LHS)
481     return nullptr;
482 
483   return ParseBinOpRHS(0, std::move(LHS));
484 }
485 
486 /// prototype
487 ///   ::= id '(' id* ')'
488 static std::unique_ptr<PrototypeAST> ParsePrototype() {
489   if (CurTok != tok_identifier)
490     return LogErrorP("Expected function name in prototype");
491 
492   std::string FnName = IdentifierStr;
493   getNextToken();
494 
495   if (CurTok != '(')
496     return LogErrorP("Expected '(' in prototype");
497 
498   std::vector<std::string> ArgNames;
499   while (getNextToken() == tok_identifier)
500     ArgNames.push_back(IdentifierStr);
501   if (CurTok != ')')
502     return LogErrorP("Expected ')' in prototype");
503 
504   // success.
505   getNextToken(); // eat ')'.
506 
507   return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
508 }
509 
510 /// definition ::= 'def' prototype expression
511 static std::unique_ptr<FunctionAST> ParseDefinition() {
512   getNextToken(); // eat def.
513   auto Proto = ParsePrototype();
514   if (!Proto)
515     return nullptr;
516 
517   if (auto E = ParseExpression())
518     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
519   return nullptr;
520 }
521 
522 /// toplevelexpr ::= expression
523 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
524   if (auto E = ParseExpression()) {
525     // Make an anonymous proto.
526     auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
527                                                  std::vector<std::string>());
528     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
529   }
530   return nullptr;
531 }
532 
533 /// external ::= 'extern' prototype
534 static std::unique_ptr<PrototypeAST> ParseExtern() {
535   getNextToken(); // eat extern.
536   return ParsePrototype();
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // Code Generation
541 //===----------------------------------------------------------------------===//
542 
543 static std::unique_ptr<LLVMContext> TheContext;
544 static std::unique_ptr<Module> TheModule;
545 static std::unique_ptr<IRBuilder<>> Builder;
546 static std::map<std::string, Value *> NamedValues;
547 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
548 static std::unique_ptr<FunctionPassManager> TheFPM;
549 static std::unique_ptr<LoopAnalysisManager> TheLAM;
550 static std::unique_ptr<FunctionAnalysisManager> TheFAM;
551 static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
552 static std::unique_ptr<ModuleAnalysisManager> TheMAM;
553 static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
554 static std::unique_ptr<StandardInstrumentations> TheSI;
555 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
556 static ExitOnError ExitOnErr;
557 
558 Value *LogErrorV(const char *Str) {
559   LogError(Str);
560   return nullptr;
561 }
562 
563 Function *getFunction(std::string Name) {
564   // First, see if the function has already been added to the current module.
565   if (auto *F = TheModule->getFunction(Name))
566     return F;
567 
568   // If not, check whether we can codegen the declaration from some existing
569   // prototype.
570   auto FI = FunctionProtos.find(Name);
571   if (FI != FunctionProtos.end())
572     return FI->second->codegen();
573 
574   // If no existing prototype exists, return null.
575   return nullptr;
576 }
577 
578 Value *NumberExprAST::codegen() {
579   return ConstantFP::get(*TheContext, APFloat(Val));
580 }
581 
582 Value *VariableExprAST::codegen() {
583   // Look this variable up in the function.
584   Value *V = NamedValues[Name];
585   if (!V)
586     return LogErrorV("Unknown variable name");
587   return V;
588 }
589 
590 Value *BinaryExprAST::codegen() {
591   Value *L = LHS->codegen();
592   Value *R = RHS->codegen();
593   if (!L || !R)
594     return nullptr;
595 
596   switch (Op) {
597   case '+':
598     return Builder->CreateFAdd(L, R, "addtmp");
599   case '-':
600     return Builder->CreateFSub(L, R, "subtmp");
601   case '*':
602     return Builder->CreateFMul(L, R, "multmp");
603   case '<':
604     L = Builder->CreateFCmpULT(L, R, "cmptmp");
605     // Convert bool 0/1 to double 0.0 or 1.0
606     return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
607   default:
608     return LogErrorV("invalid binary operator");
609   }
610 }
611 
612 Value *CallExprAST::codegen() {
613   // Look up the name in the global module table.
614   Function *CalleeF = getFunction(Callee);
615   if (!CalleeF)
616     return LogErrorV("Unknown function referenced");
617 
618   // If argument mismatch error.
619   if (CalleeF->arg_size() != Args.size())
620     return LogErrorV("Incorrect # arguments passed");
621 
622   std::vector<Value *> ArgsV;
623   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
624     ArgsV.push_back(Args[i]->codegen());
625     if (!ArgsV.back())
626       return nullptr;
627   }
628 
629   return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
630 }
631 
632 Value *IfExprAST::codegen() {
633   Value *CondV = Cond->codegen();
634   if (!CondV)
635     return nullptr;
636 
637   // Convert condition to a bool by comparing non-equal to 0.0.
638   CondV = Builder->CreateFCmpONE(
639       CondV, ConstantFP::get(*TheContext, APFloat(0.0)), "ifcond");
640 
641   Function *TheFunction = Builder->GetInsertBlock()->getParent();
642 
643   // Create blocks for the then and else cases.  Insert the 'then' block at the
644   // end of the function.
645   BasicBlock *ThenBB = BasicBlock::Create(*TheContext, "then", TheFunction);
646   BasicBlock *ElseBB = BasicBlock::Create(*TheContext, "else");
647   BasicBlock *MergeBB = BasicBlock::Create(*TheContext, "ifcont");
648 
649   Builder->CreateCondBr(CondV, ThenBB, ElseBB);
650 
651   // Emit then value.
652   Builder->SetInsertPoint(ThenBB);
653 
654   Value *ThenV = Then->codegen();
655   if (!ThenV)
656     return nullptr;
657 
658   Builder->CreateBr(MergeBB);
659   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
660   ThenBB = Builder->GetInsertBlock();
661 
662   // Emit else block.
663   TheFunction->insert(TheFunction->end(), ElseBB);
664   Builder->SetInsertPoint(ElseBB);
665 
666   Value *ElseV = Else->codegen();
667   if (!ElseV)
668     return nullptr;
669 
670   Builder->CreateBr(MergeBB);
671   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
672   ElseBB = Builder->GetInsertBlock();
673 
674   // Emit merge block.
675   TheFunction->insert(TheFunction->end(), MergeBB);
676   Builder->SetInsertPoint(MergeBB);
677   PHINode *PN = Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, "iftmp");
678 
679   PN->addIncoming(ThenV, ThenBB);
680   PN->addIncoming(ElseV, ElseBB);
681   return PN;
682 }
683 
684 // Output for-loop as:
685 //   ...
686 //   start = startexpr
687 //   goto loop
688 // loop:
689 //   variable = phi [start, loopheader], [nextvariable, loopend]
690 //   ...
691 //   bodyexpr
692 //   ...
693 // loopend:
694 //   step = stepexpr
695 //   nextvariable = variable + step
696 //   endcond = endexpr
697 //   br endcond, loop, endloop
698 // outloop:
699 Value *ForExprAST::codegen() {
700   // Emit the start code first, without 'variable' in scope.
701   Value *StartVal = Start->codegen();
702   if (!StartVal)
703     return nullptr;
704 
705   // Make the new basic block for the loop header, inserting after current
706   // block.
707   Function *TheFunction = Builder->GetInsertBlock()->getParent();
708   BasicBlock *PreheaderBB = Builder->GetInsertBlock();
709   BasicBlock *LoopBB = BasicBlock::Create(*TheContext, "loop", TheFunction);
710 
711   // Insert an explicit fall through from the current block to the LoopBB.
712   Builder->CreateBr(LoopBB);
713 
714   // Start insertion in LoopBB.
715   Builder->SetInsertPoint(LoopBB);
716 
717   // Start the PHI node with an entry for Start.
718   PHINode *Variable =
719       Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, VarName);
720   Variable->addIncoming(StartVal, PreheaderBB);
721 
722   // Within the loop, the variable is defined equal to the PHI node.  If it
723   // shadows an existing variable, we have to restore it, so save it now.
724   Value *OldVal = NamedValues[VarName];
725   NamedValues[VarName] = Variable;
726 
727   // Emit the body of the loop.  This, like any other expr, can change the
728   // current BB.  Note that we ignore the value computed by the body, but don't
729   // allow an error.
730   if (!Body->codegen())
731     return nullptr;
732 
733   // Emit the step value.
734   Value *StepVal = nullptr;
735   if (Step) {
736     StepVal = Step->codegen();
737     if (!StepVal)
738       return nullptr;
739   } else {
740     // If not specified, use 1.0.
741     StepVal = ConstantFP::get(*TheContext, APFloat(1.0));
742   }
743 
744   Value *NextVar = Builder->CreateFAdd(Variable, StepVal, "nextvar");
745 
746   // Compute the end condition.
747   Value *EndCond = End->codegen();
748   if (!EndCond)
749     return nullptr;
750 
751   // Convert condition to a bool by comparing non-equal to 0.0.
752   EndCond = Builder->CreateFCmpONE(
753       EndCond, ConstantFP::get(*TheContext, APFloat(0.0)), "loopcond");
754 
755   // Create the "after loop" block and insert it.
756   BasicBlock *LoopEndBB = Builder->GetInsertBlock();
757   BasicBlock *AfterBB =
758       BasicBlock::Create(*TheContext, "afterloop", TheFunction);
759 
760   // Insert the conditional branch into the end of LoopEndBB.
761   Builder->CreateCondBr(EndCond, LoopBB, AfterBB);
762 
763   // Any new code will be inserted in AfterBB.
764   Builder->SetInsertPoint(AfterBB);
765 
766   // Add a new entry to the PHI node for the backedge.
767   Variable->addIncoming(NextVar, LoopEndBB);
768 
769   // Restore the unshadowed variable.
770   if (OldVal)
771     NamedValues[VarName] = OldVal;
772   else
773     NamedValues.erase(VarName);
774 
775   // for expr always returns 0.0.
776   return Constant::getNullValue(Type::getDoubleTy(*TheContext));
777 }
778 
779 Function *PrototypeAST::codegen() {
780   // Make the function type:  double(double,double) etc.
781   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
782   FunctionType *FT =
783       FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
784 
785   Function *F =
786       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
787 
788   // Set names for all arguments.
789   unsigned Idx = 0;
790   for (auto &Arg : F->args())
791     Arg.setName(Args[Idx++]);
792 
793   return F;
794 }
795 
796 Function *FunctionAST::codegen() {
797   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
798   // reference to it for use below.
799   auto &P = *Proto;
800   FunctionProtos[Proto->getName()] = std::move(Proto);
801   Function *TheFunction = getFunction(P.getName());
802   if (!TheFunction)
803     return nullptr;
804 
805   // Create a new basic block to start insertion into.
806   BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
807   Builder->SetInsertPoint(BB);
808 
809   // Record the function arguments in the NamedValues map.
810   NamedValues.clear();
811   for (auto &Arg : TheFunction->args())
812     NamedValues[std::string(Arg.getName())] = &Arg;
813 
814   if (Value *RetVal = Body->codegen()) {
815     // Finish off the function.
816     Builder->CreateRet(RetVal);
817 
818     // Validate the generated code, checking for consistency.
819     verifyFunction(*TheFunction);
820 
821     // Run the optimizer on the function.
822     TheFPM->run(*TheFunction, *TheFAM);
823 
824     return TheFunction;
825   }
826 
827   // Error reading body, remove function.
828   TheFunction->eraseFromParent();
829   return nullptr;
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // Top-Level parsing and JIT Driver
834 //===----------------------------------------------------------------------===//
835 
836 static void InitializeModuleAndManagers() {
837   // Open a new context and module.
838   TheContext = std::make_unique<LLVMContext>();
839   TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
840   TheModule->setDataLayout(TheJIT->getDataLayout());
841 
842   // Create a new builder for the module.
843   Builder = std::make_unique<IRBuilder<>>(*TheContext);
844 
845   // Create new pass and analysis managers.
846   TheFPM = std::make_unique<FunctionPassManager>();
847   TheLAM = std::make_unique<LoopAnalysisManager>();
848   TheFAM = std::make_unique<FunctionAnalysisManager>();
849   TheCGAM = std::make_unique<CGSCCAnalysisManager>();
850   TheMAM = std::make_unique<ModuleAnalysisManager>();
851   ThePIC = std::make_unique<PassInstrumentationCallbacks>();
852   TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
853                                                      /*DebugLogging*/ true);
854   TheSI->registerCallbacks(*ThePIC, TheMAM.get());
855 
856   // Add transform passes.
857   // Do simple "peephole" optimizations and bit-twiddling optzns.
858   TheFPM->addPass(InstCombinePass());
859   // Reassociate expressions.
860   TheFPM->addPass(ReassociatePass());
861   // Eliminate Common SubExpressions.
862   TheFPM->addPass(GVNPass());
863   // Simplify the control flow graph (deleting unreachable blocks, etc).
864   TheFPM->addPass(SimplifyCFGPass());
865 
866   // Register analysis passes used in these transform passes.
867   PassBuilder PB;
868   PB.registerModuleAnalyses(*TheMAM);
869   PB.registerFunctionAnalyses(*TheFAM);
870   PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
871 }
872 
873 static void HandleDefinition() {
874   if (auto FnAST = ParseDefinition()) {
875     if (auto *FnIR = FnAST->codegen()) {
876       fprintf(stderr, "Read function definition:");
877       FnIR->print(errs());
878       fprintf(stderr, "\n");
879       ExitOnErr(TheJIT->addModule(
880           ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
881       InitializeModuleAndManagers();
882     }
883   } else {
884     // Skip token for error recovery.
885     getNextToken();
886   }
887 }
888 
889 static void HandleExtern() {
890   if (auto ProtoAST = ParseExtern()) {
891     if (auto *FnIR = ProtoAST->codegen()) {
892       fprintf(stderr, "Read extern: ");
893       FnIR->print(errs());
894       fprintf(stderr, "\n");
895       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
896     }
897   } else {
898     // Skip token for error recovery.
899     getNextToken();
900   }
901 }
902 
903 static void HandleTopLevelExpression() {
904   // Evaluate a top-level expression into an anonymous function.
905   if (auto FnAST = ParseTopLevelExpr()) {
906     if (FnAST->codegen()) {
907       // Create a ResourceTracker to track JIT'd memory allocated to our
908       // anonymous expression -- that way we can free it after executing.
909       auto RT = TheJIT->getMainJITDylib().createResourceTracker();
910 
911       auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
912       ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
913       InitializeModuleAndManagers();
914 
915       // Search the JIT for the __anon_expr symbol.
916       auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
917 
918       // Get the symbol's address and cast it to the right type (takes no
919       // arguments, returns a double) so we can call it as a native function.
920       double (*FP)() = ExprSymbol.toPtr<double (*)()>();
921       fprintf(stderr, "Evaluated to %f\n", FP());
922 
923       // Delete the anonymous expression module from the JIT.
924       ExitOnErr(RT->remove());
925     }
926   } else {
927     // Skip token for error recovery.
928     getNextToken();
929   }
930 }
931 
932 /// top ::= definition | external | expression | ';'
933 static void MainLoop() {
934   while (true) {
935     fprintf(stderr, "ready> ");
936     switch (CurTok) {
937     case tok_eof:
938       return;
939     case ';': // ignore top-level semicolons.
940       getNextToken();
941       break;
942     case tok_def:
943       HandleDefinition();
944       break;
945     case tok_extern:
946       HandleExtern();
947       break;
948     default:
949       HandleTopLevelExpression();
950       break;
951     }
952   }
953 }
954 
955 //===----------------------------------------------------------------------===//
956 // "Library" functions that can be "extern'd" from user code.
957 //===----------------------------------------------------------------------===//
958 
// DLLEXPORT marks the library functions below as exported when building for
// Windows; on every other platform it expands to nothing.
#if defined(_WIN32)
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
964 
965 /// putchard - putchar that takes a double and returns 0.
966 extern "C" DLLEXPORT double putchard(double X) {
967   fputc((char)X, stderr);
968   return 0;
969 }
970 
971 /// printd - printf that takes a double prints it as "%f\n", returning 0.
972 extern "C" DLLEXPORT double printd(double X) {
973   fprintf(stderr, "%f\n", X);
974   return 0;
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // Main driver code.
979 //===----------------------------------------------------------------------===//
980 
981 int main() {
982   InitializeNativeTarget();
983   InitializeNativeTargetAsmPrinter();
984   InitializeNativeTargetAsmParser();
985 
986   // Install standard binary operators.
987   // 1 is lowest precedence.
988   BinopPrecedence['<'] = 10;
989   BinopPrecedence['+'] = 20;
990   BinopPrecedence['-'] = 20;
991   BinopPrecedence['*'] = 40; // highest.
992 
993   // Prime the first token.
994   fprintf(stderr, "ready> ");
995   getNextToken();
996 
997   TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
998 
999   InitializeModuleAndManagers();
1000 
1001   // Run the main "interpreter loop" now.
1002   MainLoop();
1003 
1004   return 0;
1005 }
1006