xref: /llvm-project/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/MLIRGen.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Value.h"
18 #include "toy/AST.h"
19 #include "toy/Dialect.h"
20 
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Verifier.h"
27 #include "toy/Lexer.h"
28 
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/ScopedHashTable.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/Twine.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include <cassert>
37 #include <cstddef>
38 #include <cstdint>
39 #include <functional>
40 #include <numeric>
41 #include <optional>
42 #include <tuple>
43 #include <utility>
44 #include <vector>
45 
46 using namespace mlir::toy;
47 using namespace toy;
48 
49 using llvm::ArrayRef;
50 using llvm::cast;
51 using llvm::dyn_cast;
52 using llvm::isa;
53 using llvm::ScopedHashTableScope;
54 using llvm::SmallVector;
55 using llvm::StringRef;
56 using llvm::Twine;
57 
58 namespace {
59 
60 /// Implementation of a simple MLIR emission from the Toy AST.
61 ///
62 /// This will emit operations that are specific to the Toy language, preserving
63 /// the semantics of the language and (hopefully) allow to perform accurate
64 /// analysis and transformation based on these high level semantics.
65 class MLIRGenImpl {
66 public:
MLIRGenImpl(mlir::MLIRContext & context)67   MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
68 
69   /// Public API: convert the AST for a Toy module (source file) to an MLIR
70   /// Module operation.
mlirGen(ModuleAST & moduleAST)71   mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
72     // We create an empty MLIR module and codegen functions one at a time and
73     // add them to the module.
74     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
75 
76     for (auto &record : moduleAST) {
77       if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
78         mlir::toy::FuncOp func = mlirGen(*funcAST);
79         if (!func)
80           return nullptr;
81         functionMap.insert({func.getName(), func});
82       } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
83         if (failed(mlirGen(*str)))
84           return nullptr;
85       } else {
86         llvm_unreachable("unknown record type");
87       }
88     }
89 
90     // Verify the module after we have finished constructing it, this will check
91     // the structural properties of the IR and invoke any specific verifiers we
92     // have on the Toy operations.
93     if (failed(mlir::verify(theModule))) {
94       theModule.emitError("module verification error");
95       return nullptr;
96     }
97 
98     return theModule;
99   }
100 
101 private:
102   /// A "module" matches a Toy source file: containing a list of functions.
103   mlir::ModuleOp theModule;
104 
105   /// The builder is a helper class to create IR inside a function. The builder
106   /// is stateful, in particular it keeps an "insertion point": this is where
107   /// the next operations will be introduced.
108   mlir::OpBuilder builder;
109 
110   /// The symbol table maps a variable name to a value in the current scope.
111   /// Entering a function creates a new scope, and the function arguments are
112   /// added to the mapping. When the processing of a function is terminated, the
113   /// scope is destroyed and the mappings created in this scope are dropped.
114   llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>>
115       symbolTable;
116   using SymbolTableScopeT =
117       llvm::ScopedHashTableScope<StringRef,
118                                  std::pair<mlir::Value, VarDeclExprAST *>>;
119 
120   /// A mapping for the functions that have been code generated to MLIR.
121   llvm::StringMap<mlir::toy::FuncOp> functionMap;
122 
123   /// A mapping for named struct types to the underlying MLIR type and the
124   /// original AST node.
125   llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
126 
127   /// Helper conversion for a Toy AST location to an MLIR location.
loc(const Location & loc)128   mlir::Location loc(const Location &loc) {
129     return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
130                                      loc.col);
131   }
132 
133   /// Declare a variable in the current scope, return success if the variable
134   /// wasn't declared yet.
declare(VarDeclExprAST & var,mlir::Value value)135   llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) {
136     if (symbolTable.count(var.getName()))
137       return mlir::failure();
138     symbolTable.insert(var.getName(), {value, &var});
139     return mlir::success();
140   }
141 
142   /// Create an MLIR type for the given struct.
mlirGen(StructAST & str)143   llvm::LogicalResult mlirGen(StructAST &str) {
144     if (structMap.count(str.getName()))
145       return emitError(loc(str.loc())) << "error: struct type with name `"
146                                        << str.getName() << "' already exists";
147 
148     auto variables = str.getVariables();
149     std::vector<mlir::Type> elementTypes;
150     elementTypes.reserve(variables.size());
151     for (auto &variable : variables) {
152       if (variable->getInitVal())
153         return emitError(loc(variable->loc()))
154                << "error: variables within a struct definition must not have "
155                   "initializers";
156       if (!variable->getType().shape.empty())
157         return emitError(loc(variable->loc()))
158                << "error: variables within a struct definition must not have "
159                   "initializers";
160 
161       mlir::Type type = getType(variable->getType(), variable->loc());
162       if (!type)
163         return mlir::failure();
164       elementTypes.push_back(type);
165     }
166 
167     structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
168     return mlir::success();
169   }
170 
171   /// Create the prototype for an MLIR function with as many arguments as the
172   /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)173   mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
174     auto location = loc(proto.loc());
175 
176     // This is a generic function, the return type will be inferred later.
177     llvm::SmallVector<mlir::Type, 4> argTypes;
178     argTypes.reserve(proto.getArgs().size());
179     for (auto &arg : proto.getArgs()) {
180       mlir::Type type = getType(arg->getType(), arg->loc());
181       if (!type)
182         return nullptr;
183       argTypes.push_back(type);
184     }
185     auto funcType = builder.getFunctionType(argTypes, std::nullopt);
186     return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
187                                              funcType);
188   }
189 
190   /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)191   mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
192     // Create a scope in the symbol table to hold variable declarations.
193     SymbolTableScopeT varScope(symbolTable);
194 
195     // Create an MLIR function for the given prototype.
196     builder.setInsertionPointToEnd(theModule.getBody());
197     mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
198     if (!function)
199       return nullptr;
200 
201     // Let's start the body of the function now!
202     mlir::Block &entryBlock = function.front();
203     auto protoArgs = funcAST.getProto()->getArgs();
204 
205     // Declare all the function arguments in the symbol table.
206     for (const auto nameValue :
207          llvm::zip(protoArgs, entryBlock.getArguments())) {
208       if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue))))
209         return nullptr;
210     }
211 
212     // Set the insertion point in the builder to the beginning of the function
213     // body, it will be used throughout the codegen to create operations in this
214     // function.
215     builder.setInsertionPointToStart(&entryBlock);
216 
217     // Emit the body of the function.
218     if (mlir::failed(mlirGen(*funcAST.getBody()))) {
219       function.erase();
220       return nullptr;
221     }
222 
223     // Implicitly return void if no return statement was emitted.
224     // FIXME: we may fix the parser instead to always return the last expression
225     // (this would possibly help the REPL case later)
226     ReturnOp returnOp;
227     if (!entryBlock.empty())
228       returnOp = dyn_cast<ReturnOp>(entryBlock.back());
229     if (!returnOp) {
230       builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
231     } else if (returnOp.hasOperand()) {
232       // Otherwise, if this return operation has an operand then add a result to
233       // the function.
234       function.setType(
235           builder.getFunctionType(function.getFunctionType().getInputs(),
236                                   *returnOp.operand_type_begin()));
237     }
238 
239     // If this function isn't main, then set the visibility to private.
240     if (funcAST.getProto()->getName() != "main")
241       function.setPrivate();
242 
243     return function;
244   }
245 
246   /// Return the struct type that is the result of the given expression, or null
247   /// if it cannot be inferred.
getStructFor(ExprAST * expr)248   StructAST *getStructFor(ExprAST *expr) {
249     llvm::StringRef structName;
250     if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
251       auto varIt = symbolTable.lookup(decl->getName());
252       if (!varIt.first)
253         return nullptr;
254       structName = varIt.second->getType().name;
255     } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
256       if (access->getOp() != '.')
257         return nullptr;
258       // The name being accessed should be in the RHS.
259       auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
260       if (!name)
261         return nullptr;
262       StructAST *parentStruct = getStructFor(access->getLHS());
263       if (!parentStruct)
264         return nullptr;
265 
266       // Get the element within the struct corresponding to the name.
267       VarDeclExprAST *decl = nullptr;
268       for (auto &var : parentStruct->getVariables()) {
269         if (var->getName() == name->getName()) {
270           decl = var.get();
271           break;
272         }
273       }
274       if (!decl)
275         return nullptr;
276       structName = decl->getType().name;
277     }
278     if (structName.empty())
279       return nullptr;
280 
281     // If the struct name was valid, check for an entry in the struct map.
282     auto structIt = structMap.find(structName);
283     if (structIt == structMap.end())
284       return nullptr;
285     return structIt->second.second;
286   }
287 
288   /// Return the numeric member index of the given struct access expression.
getMemberIndex(BinaryExprAST & accessOp)289   std::optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
290     assert(accessOp.getOp() == '.' && "expected access operation");
291 
292     // Lookup the struct node for the LHS.
293     StructAST *structAST = getStructFor(accessOp.getLHS());
294     if (!structAST)
295       return std::nullopt;
296 
297     // Get the name from the RHS.
298     VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
299     if (!name)
300       return std::nullopt;
301 
302     auto structVars = structAST->getVariables();
303     const auto *it = llvm::find_if(structVars, [&](auto &var) {
304       return var->getName() == name->getName();
305     });
306     if (it == structVars.end())
307       return std::nullopt;
308     return it - structVars.begin();
309   }
310 
311   /// Emit a binary operation
mlirGen(BinaryExprAST & binop)312   mlir::Value mlirGen(BinaryExprAST &binop) {
313     // First emit the operations for each side of the operation before emitting
314     // the operation itself. For example if the expression is `a + foo(a)`
315     // 1) First it will visiting the LHS, which will return a reference to the
316     //    value holding `a`. This value should have been emitted at declaration
317     //    time and registered in the symbol table, so nothing would be
318     //    codegen'd. If the value is not in the symbol table, an error has been
319     //    emitted and nullptr is returned.
320     // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
321     //    and the result value is returned. If an error occurs we get a nullptr
322     //    and propagate.
323     //
324     mlir::Value lhs = mlirGen(*binop.getLHS());
325     if (!lhs)
326       return nullptr;
327     auto location = loc(binop.loc());
328 
329     // If this is an access operation, handle it immediately.
330     if (binop.getOp() == '.') {
331       std::optional<size_t> accessIndex = getMemberIndex(binop);
332       if (!accessIndex) {
333         emitError(location, "invalid access into struct expression");
334         return nullptr;
335       }
336       return builder.create<StructAccessOp>(location, lhs, *accessIndex);
337     }
338 
339     // Otherwise, this is a normal binary op.
340     mlir::Value rhs = mlirGen(*binop.getRHS());
341     if (!rhs)
342       return nullptr;
343 
344     // Derive the operation name from the binary operator. At the moment we only
345     // support '+' and '*'.
346     switch (binop.getOp()) {
347     case '+':
348       return builder.create<AddOp>(location, lhs, rhs);
349     case '*':
350       return builder.create<MulOp>(location, lhs, rhs);
351     }
352 
353     emitError(location, "invalid binary operator '") << binop.getOp() << "'";
354     return nullptr;
355   }
356 
357   /// This is a reference to a variable in an expression. The variable is
358   /// expected to have been declared and so should have a value in the symbol
359   /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)360   mlir::Value mlirGen(VariableExprAST &expr) {
361     if (auto variable = symbolTable.lookup(expr.getName()).first)
362       return variable;
363 
364     emitError(loc(expr.loc()), "error: unknown variable '")
365         << expr.getName() << "'";
366     return nullptr;
367   }
368 
369   /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)370   llvm::LogicalResult mlirGen(ReturnExprAST &ret) {
371     auto location = loc(ret.loc());
372 
373     // 'return' takes an optional expression, handle that case here.
374     mlir::Value expr = nullptr;
375     if (ret.getExpr().has_value()) {
376       if (!(expr = mlirGen(**ret.getExpr())))
377         return mlir::failure();
378     }
379 
380     // Otherwise, this return operation has zero operands.
381     builder.create<ReturnOp>(location,
382                              expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
383     return mlir::success();
384   }
385 
386   /// Emit a constant for a literal/constant array. It will be emitted as a
387   /// flattened array of data in an Attribute attached to a `toy.constant`
388   /// operation. See documentation on [Attributes](LangRef.md#attributes) for
389   /// more details. Here is an excerpt:
390   ///
391   ///   Attributes are the mechanism for specifying constant data in MLIR in
392   ///   places where a variable is never allowed [...]. They consist of a name
393   ///   and a concrete attribute value. The set of expected attributes, their
394   ///   structure, and their interpretation are all contextually dependent on
395   ///   what they are attached to.
396   ///
397   /// Example, the source level statement:
398   ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
399   /// will be converted to:
400   ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
401   ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
402   ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
403   ///
getConstantAttr(LiteralExprAST & lit)404   mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
405     // The attribute is a vector with a floating point value per element
406     // (number) in the array, see `collectData()` below for more details.
407     std::vector<double> data;
408     data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
409                                  std::multiplies<int>()));
410     collectData(lit, data);
411 
412     // The type of this attribute is tensor of 64-bit floating-point with the
413     // shape of the literal.
414     mlir::Type elementType = builder.getF64Type();
415     auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
416 
417     // This is the actual attribute that holds the list of values for this
418     // tensor literal.
419     return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
420   }
getConstantAttr(NumberExprAST & lit)421   mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
422     // The type of this attribute is tensor of 64-bit floating-point with no
423     // shape.
424     mlir::Type elementType = builder.getF64Type();
425     auto dataType = mlir::RankedTensorType::get({}, elementType);
426 
427     // This is the actual attribute that holds the list of values for this
428     // tensor literal.
429     return mlir::DenseElementsAttr::get(dataType,
430                                         llvm::ArrayRef(lit.getValue()));
431   }
432   /// Emit a constant for a struct literal. It will be emitted as an array of
433   /// other literals in an Attribute attached to a `toy.struct_constant`
434   /// operation. This function returns the generated constant, along with the
435   /// corresponding struct type.
436   std::pair<mlir::ArrayAttr, mlir::Type>
getConstantAttr(StructLiteralExprAST & lit)437   getConstantAttr(StructLiteralExprAST &lit) {
438     std::vector<mlir::Attribute> attrElements;
439     std::vector<mlir::Type> typeElements;
440 
441     for (auto &var : lit.getValues()) {
442       if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
443         attrElements.push_back(getConstantAttr(*number));
444         typeElements.push_back(getType(std::nullopt));
445       } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
446         attrElements.push_back(getConstantAttr(*lit));
447         typeElements.push_back(getType(std::nullopt));
448       } else {
449         auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
450         auto attrTypePair = getConstantAttr(*structLit);
451         attrElements.push_back(attrTypePair.first);
452         typeElements.push_back(attrTypePair.second);
453       }
454     }
455     mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
456     mlir::Type dataType = StructType::get(typeElements);
457     return std::make_pair(dataAttr, dataType);
458   }
459 
460   /// Emit an array literal.
mlirGen(LiteralExprAST & lit)461   mlir::Value mlirGen(LiteralExprAST &lit) {
462     mlir::Type type = getType(lit.getDims());
463     mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
464 
465     // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
466     // method.
467     return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
468   }
469 
470   /// Emit a struct literal. It will be emitted as an array of
471   /// other literals in an Attribute attached to a `toy.struct_constant`
472   /// operation.
mlirGen(StructLiteralExprAST & lit)473   mlir::Value mlirGen(StructLiteralExprAST &lit) {
474     mlir::ArrayAttr dataAttr;
475     mlir::Type dataType;
476     std::tie(dataAttr, dataType) = getConstantAttr(lit);
477 
478     // Build the MLIR op `toy.struct_constant`. This invokes the
479     // `StructConstantOp::build` method.
480     return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr);
481   }
482 
483   /// Recursive helper function to accumulate the data that compose an array
484   /// literal. It flattens the nested structure in the supplied vector. For
485   /// example with this array:
486   ///  [[1, 2], [3, 4]]
487   /// we will generate:
488   ///  [ 1, 2, 3, 4 ]
489   /// Individual numbers are represented as doubles.
490   /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)491   void collectData(ExprAST &expr, std::vector<double> &data) {
492     if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
493       for (auto &value : lit->getValues())
494         collectData(*value, data);
495       return;
496     }
497 
498     assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
499     data.push_back(cast<NumberExprAST>(expr).getValue());
500   }
501 
502   /// Emit a call expression. It emits specific operations for the `transpose`
503   /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)504   mlir::Value mlirGen(CallExprAST &call) {
505     llvm::StringRef callee = call.getCallee();
506     auto location = loc(call.loc());
507 
508     // Codegen the operands first.
509     SmallVector<mlir::Value, 4> operands;
510     for (auto &expr : call.getArgs()) {
511       auto arg = mlirGen(*expr);
512       if (!arg)
513         return nullptr;
514       operands.push_back(arg);
515     }
516 
517     // Builtin calls have their custom operation, meaning this is a
518     // straightforward emission.
519     if (callee == "transpose") {
520       if (call.getArgs().size() != 1) {
521         emitError(location, "MLIR codegen encountered an error: toy.transpose "
522                             "does not accept multiple arguments");
523         return nullptr;
524       }
525       return builder.create<TransposeOp>(location, operands[0]);
526     }
527 
528     // Otherwise this is a call to a user-defined function. Calls to
529     // user-defined functions are mapped to a custom call that takes the callee
530     // name as an attribute.
531     auto calledFuncIt = functionMap.find(callee);
532     if (calledFuncIt == functionMap.end()) {
533       emitError(location) << "no defined function found for '" << callee << "'";
534       return nullptr;
535     }
536     mlir::toy::FuncOp calledFunc = calledFuncIt->second;
537     return builder.create<GenericCallOp>(
538         location, calledFunc.getFunctionType().getResult(0),
539         mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
540   }
541 
542   /// Emit a print expression. It emits specific operations for two builtins:
543   /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)544   llvm::LogicalResult mlirGen(PrintExprAST &call) {
545     auto arg = mlirGen(*call.getArg());
546     if (!arg)
547       return mlir::failure();
548 
549     builder.create<PrintOp>(loc(call.loc()), arg);
550     return mlir::success();
551   }
552 
553   /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)554   mlir::Value mlirGen(NumberExprAST &num) {
555     return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
556   }
557 
558   /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)559   mlir::Value mlirGen(ExprAST &expr) {
560     switch (expr.getKind()) {
561     case toy::ExprAST::Expr_BinOp:
562       return mlirGen(cast<BinaryExprAST>(expr));
563     case toy::ExprAST::Expr_Var:
564       return mlirGen(cast<VariableExprAST>(expr));
565     case toy::ExprAST::Expr_Literal:
566       return mlirGen(cast<LiteralExprAST>(expr));
567     case toy::ExprAST::Expr_StructLiteral:
568       return mlirGen(cast<StructLiteralExprAST>(expr));
569     case toy::ExprAST::Expr_Call:
570       return mlirGen(cast<CallExprAST>(expr));
571     case toy::ExprAST::Expr_Num:
572       return mlirGen(cast<NumberExprAST>(expr));
573     default:
574       emitError(loc(expr.loc()))
575           << "MLIR codegen encountered an unhandled expr kind '"
576           << Twine(expr.getKind()) << "'";
577       return nullptr;
578     }
579   }
580 
581   /// Handle a variable declaration, we'll codegen the expression that forms the
582   /// initializer and record the value in the symbol table before returning it.
583   /// Future expressions will be able to reference this variable through symbol
584   /// table lookup.
mlirGen(VarDeclExprAST & vardecl)585   mlir::Value mlirGen(VarDeclExprAST &vardecl) {
586     auto *init = vardecl.getInitVal();
587     if (!init) {
588       emitError(loc(vardecl.loc()),
589                 "missing initializer in variable declaration");
590       return nullptr;
591     }
592 
593     mlir::Value value = mlirGen(*init);
594     if (!value)
595       return nullptr;
596 
597     // Handle the case where we are initializing a struct value.
598     VarType varType = vardecl.getType();
599     if (!varType.name.empty()) {
600       // Check that the initializer type is the same as the variable
601       // declaration.
602       mlir::Type type = getType(varType, vardecl.loc());
603       if (!type)
604         return nullptr;
605       if (type != value.getType()) {
606         emitError(loc(vardecl.loc()))
607             << "struct type of initializer is different than the variable "
608                "declaration. Got "
609             << value.getType() << ", but expected " << type;
610         return nullptr;
611       }
612 
613       // Otherwise, we have the initializer value, but in case the variable was
614       // declared with specific shape, we emit a "reshape" operation. It will
615       // get optimized out later as needed.
616     } else if (!varType.shape.empty()) {
617       value = builder.create<ReshapeOp>(loc(vardecl.loc()),
618                                         getType(varType.shape), value);
619     }
620 
621     // Register the value in the symbol table.
622     if (failed(declare(vardecl, value)))
623       return nullptr;
624     return value;
625   }
626 
627   /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)628   llvm::LogicalResult mlirGen(ExprASTList &blockAST) {
629     SymbolTableScopeT varScope(symbolTable);
630     for (auto &expr : blockAST) {
631       // Specific handling for variable declarations, return statement, and
632       // print. These can only appear in block list and not in nested
633       // expressions.
634       if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
635         if (!mlirGen(*vardecl))
636           return mlir::failure();
637         continue;
638       }
639       if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
640         return mlirGen(*ret);
641       if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
642         if (mlir::failed(mlirGen(*print)))
643           return mlir::success();
644         continue;
645       }
646 
647       // Generic expression dispatch codegen.
648       if (!mlirGen(*expr))
649         return mlir::failure();
650     }
651     return mlir::success();
652   }
653 
654   /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)655   mlir::Type getType(ArrayRef<int64_t> shape) {
656     // If the shape is empty, then this type is unranked.
657     if (shape.empty())
658       return mlir::UnrankedTensorType::get(builder.getF64Type());
659 
660     // Otherwise, we use the given shape.
661     return mlir::RankedTensorType::get(shape, builder.getF64Type());
662   }
663 
664   /// Build an MLIR type from a Toy AST variable type (forward to the generic
665   /// getType above for non-struct types).
getType(const VarType & type,const Location & location)666   mlir::Type getType(const VarType &type, const Location &location) {
667     if (!type.name.empty()) {
668       auto it = structMap.find(type.name);
669       if (it == structMap.end()) {
670         emitError(loc(location))
671             << "error: unknown struct type '" << type.name << "'";
672         return nullptr;
673       }
674       return it->second.first;
675     }
676 
677     return getType(type.shape);
678   }
679 };
680 
681 } // namespace
682 
683 namespace toy {
684 
685 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)686 mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
687                                           ModuleAST &moduleAST) {
688   return MLIRGenImpl(context).mlirGen(moduleAST);
689 }
690 
691 } // namespace toy
692