1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/MLIRGen.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Value.h"
18 #include "toy/AST.h"
19 #include "toy/Dialect.h"
20
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Verifier.h"
27 #include "toy/Lexer.h"
28
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/ScopedHashTable.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/Twine.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include <cassert>
37 #include <cstddef>
38 #include <cstdint>
39 #include <functional>
40 #include <numeric>
41 #include <optional>
42 #include <tuple>
43 #include <utility>
44 #include <vector>
45
46 using namespace mlir::toy;
47 using namespace toy;
48
49 using llvm::ArrayRef;
50 using llvm::cast;
51 using llvm::dyn_cast;
52 using llvm::isa;
53 using llvm::ScopedHashTableScope;
54 using llvm::SmallVector;
55 using llvm::StringRef;
56 using llvm::Twine;
57
58 namespace {
59
60 /// Implementation of a simple MLIR emission from the Toy AST.
61 ///
62 /// This will emit operations that are specific to the Toy language, preserving
63 /// the semantics of the language and (hopefully) allow to perform accurate
64 /// analysis and transformation based on these high level semantics.
65 class MLIRGenImpl {
66 public:
MLIRGenImpl(mlir::MLIRContext & context)67 MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
68
69 /// Public API: convert the AST for a Toy module (source file) to an MLIR
70 /// Module operation.
mlirGen(ModuleAST & moduleAST)71 mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
72 // We create an empty MLIR module and codegen functions one at a time and
73 // add them to the module.
74 theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
75
76 for (auto &record : moduleAST) {
77 if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
78 mlir::toy::FuncOp func = mlirGen(*funcAST);
79 if (!func)
80 return nullptr;
81 functionMap.insert({func.getName(), func});
82 } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
83 if (failed(mlirGen(*str)))
84 return nullptr;
85 } else {
86 llvm_unreachable("unknown record type");
87 }
88 }
89
90 // Verify the module after we have finished constructing it, this will check
91 // the structural properties of the IR and invoke any specific verifiers we
92 // have on the Toy operations.
93 if (failed(mlir::verify(theModule))) {
94 theModule.emitError("module verification error");
95 return nullptr;
96 }
97
98 return theModule;
99 }
100
101 private:
102 /// A "module" matches a Toy source file: containing a list of functions.
103 mlir::ModuleOp theModule;
104
105 /// The builder is a helper class to create IR inside a function. The builder
106 /// is stateful, in particular it keeps an "insertion point": this is where
107 /// the next operations will be introduced.
108 mlir::OpBuilder builder;
109
110 /// The symbol table maps a variable name to a value in the current scope.
111 /// Entering a function creates a new scope, and the function arguments are
112 /// added to the mapping. When the processing of a function is terminated, the
113 /// scope is destroyed and the mappings created in this scope are dropped.
114 llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>>
115 symbolTable;
116 using SymbolTableScopeT =
117 llvm::ScopedHashTableScope<StringRef,
118 std::pair<mlir::Value, VarDeclExprAST *>>;
119
120 /// A mapping for the functions that have been code generated to MLIR.
121 llvm::StringMap<mlir::toy::FuncOp> functionMap;
122
123 /// A mapping for named struct types to the underlying MLIR type and the
124 /// original AST node.
125 llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
126
127 /// Helper conversion for a Toy AST location to an MLIR location.
loc(const Location & loc)128 mlir::Location loc(const Location &loc) {
129 return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
130 loc.col);
131 }
132
133 /// Declare a variable in the current scope, return success if the variable
134 /// wasn't declared yet.
declare(VarDeclExprAST & var,mlir::Value value)135 llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) {
136 if (symbolTable.count(var.getName()))
137 return mlir::failure();
138 symbolTable.insert(var.getName(), {value, &var});
139 return mlir::success();
140 }
141
142 /// Create an MLIR type for the given struct.
mlirGen(StructAST & str)143 llvm::LogicalResult mlirGen(StructAST &str) {
144 if (structMap.count(str.getName()))
145 return emitError(loc(str.loc())) << "error: struct type with name `"
146 << str.getName() << "' already exists";
147
148 auto variables = str.getVariables();
149 std::vector<mlir::Type> elementTypes;
150 elementTypes.reserve(variables.size());
151 for (auto &variable : variables) {
152 if (variable->getInitVal())
153 return emitError(loc(variable->loc()))
154 << "error: variables within a struct definition must not have "
155 "initializers";
156 if (!variable->getType().shape.empty())
157 return emitError(loc(variable->loc()))
158 << "error: variables within a struct definition must not have "
159 "initializers";
160
161 mlir::Type type = getType(variable->getType(), variable->loc());
162 if (!type)
163 return mlir::failure();
164 elementTypes.push_back(type);
165 }
166
167 structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
168 return mlir::success();
169 }
170
171 /// Create the prototype for an MLIR function with as many arguments as the
172 /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)173 mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
174 auto location = loc(proto.loc());
175
176 // This is a generic function, the return type will be inferred later.
177 llvm::SmallVector<mlir::Type, 4> argTypes;
178 argTypes.reserve(proto.getArgs().size());
179 for (auto &arg : proto.getArgs()) {
180 mlir::Type type = getType(arg->getType(), arg->loc());
181 if (!type)
182 return nullptr;
183 argTypes.push_back(type);
184 }
185 auto funcType = builder.getFunctionType(argTypes, std::nullopt);
186 return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
187 funcType);
188 }
189
190 /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)191 mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
192 // Create a scope in the symbol table to hold variable declarations.
193 SymbolTableScopeT varScope(symbolTable);
194
195 // Create an MLIR function for the given prototype.
196 builder.setInsertionPointToEnd(theModule.getBody());
197 mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
198 if (!function)
199 return nullptr;
200
201 // Let's start the body of the function now!
202 mlir::Block &entryBlock = function.front();
203 auto protoArgs = funcAST.getProto()->getArgs();
204
205 // Declare all the function arguments in the symbol table.
206 for (const auto nameValue :
207 llvm::zip(protoArgs, entryBlock.getArguments())) {
208 if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue))))
209 return nullptr;
210 }
211
212 // Set the insertion point in the builder to the beginning of the function
213 // body, it will be used throughout the codegen to create operations in this
214 // function.
215 builder.setInsertionPointToStart(&entryBlock);
216
217 // Emit the body of the function.
218 if (mlir::failed(mlirGen(*funcAST.getBody()))) {
219 function.erase();
220 return nullptr;
221 }
222
223 // Implicitly return void if no return statement was emitted.
224 // FIXME: we may fix the parser instead to always return the last expression
225 // (this would possibly help the REPL case later)
226 ReturnOp returnOp;
227 if (!entryBlock.empty())
228 returnOp = dyn_cast<ReturnOp>(entryBlock.back());
229 if (!returnOp) {
230 builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
231 } else if (returnOp.hasOperand()) {
232 // Otherwise, if this return operation has an operand then add a result to
233 // the function.
234 function.setType(
235 builder.getFunctionType(function.getFunctionType().getInputs(),
236 *returnOp.operand_type_begin()));
237 }
238
239 // If this function isn't main, then set the visibility to private.
240 if (funcAST.getProto()->getName() != "main")
241 function.setPrivate();
242
243 return function;
244 }
245
246 /// Return the struct type that is the result of the given expression, or null
247 /// if it cannot be inferred.
getStructFor(ExprAST * expr)248 StructAST *getStructFor(ExprAST *expr) {
249 llvm::StringRef structName;
250 if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
251 auto varIt = symbolTable.lookup(decl->getName());
252 if (!varIt.first)
253 return nullptr;
254 structName = varIt.second->getType().name;
255 } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
256 if (access->getOp() != '.')
257 return nullptr;
258 // The name being accessed should be in the RHS.
259 auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
260 if (!name)
261 return nullptr;
262 StructAST *parentStruct = getStructFor(access->getLHS());
263 if (!parentStruct)
264 return nullptr;
265
266 // Get the element within the struct corresponding to the name.
267 VarDeclExprAST *decl = nullptr;
268 for (auto &var : parentStruct->getVariables()) {
269 if (var->getName() == name->getName()) {
270 decl = var.get();
271 break;
272 }
273 }
274 if (!decl)
275 return nullptr;
276 structName = decl->getType().name;
277 }
278 if (structName.empty())
279 return nullptr;
280
281 // If the struct name was valid, check for an entry in the struct map.
282 auto structIt = structMap.find(structName);
283 if (structIt == structMap.end())
284 return nullptr;
285 return structIt->second.second;
286 }
287
288 /// Return the numeric member index of the given struct access expression.
getMemberIndex(BinaryExprAST & accessOp)289 std::optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
290 assert(accessOp.getOp() == '.' && "expected access operation");
291
292 // Lookup the struct node for the LHS.
293 StructAST *structAST = getStructFor(accessOp.getLHS());
294 if (!structAST)
295 return std::nullopt;
296
297 // Get the name from the RHS.
298 VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
299 if (!name)
300 return std::nullopt;
301
302 auto structVars = structAST->getVariables();
303 const auto *it = llvm::find_if(structVars, [&](auto &var) {
304 return var->getName() == name->getName();
305 });
306 if (it == structVars.end())
307 return std::nullopt;
308 return it - structVars.begin();
309 }
310
311 /// Emit a binary operation
mlirGen(BinaryExprAST & binop)312 mlir::Value mlirGen(BinaryExprAST &binop) {
313 // First emit the operations for each side of the operation before emitting
314 // the operation itself. For example if the expression is `a + foo(a)`
315 // 1) First it will visiting the LHS, which will return a reference to the
316 // value holding `a`. This value should have been emitted at declaration
317 // time and registered in the symbol table, so nothing would be
318 // codegen'd. If the value is not in the symbol table, an error has been
319 // emitted and nullptr is returned.
320 // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
321 // and the result value is returned. If an error occurs we get a nullptr
322 // and propagate.
323 //
324 mlir::Value lhs = mlirGen(*binop.getLHS());
325 if (!lhs)
326 return nullptr;
327 auto location = loc(binop.loc());
328
329 // If this is an access operation, handle it immediately.
330 if (binop.getOp() == '.') {
331 std::optional<size_t> accessIndex = getMemberIndex(binop);
332 if (!accessIndex) {
333 emitError(location, "invalid access into struct expression");
334 return nullptr;
335 }
336 return builder.create<StructAccessOp>(location, lhs, *accessIndex);
337 }
338
339 // Otherwise, this is a normal binary op.
340 mlir::Value rhs = mlirGen(*binop.getRHS());
341 if (!rhs)
342 return nullptr;
343
344 // Derive the operation name from the binary operator. At the moment we only
345 // support '+' and '*'.
346 switch (binop.getOp()) {
347 case '+':
348 return builder.create<AddOp>(location, lhs, rhs);
349 case '*':
350 return builder.create<MulOp>(location, lhs, rhs);
351 }
352
353 emitError(location, "invalid binary operator '") << binop.getOp() << "'";
354 return nullptr;
355 }
356
357 /// This is a reference to a variable in an expression. The variable is
358 /// expected to have been declared and so should have a value in the symbol
359 /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)360 mlir::Value mlirGen(VariableExprAST &expr) {
361 if (auto variable = symbolTable.lookup(expr.getName()).first)
362 return variable;
363
364 emitError(loc(expr.loc()), "error: unknown variable '")
365 << expr.getName() << "'";
366 return nullptr;
367 }
368
369 /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)370 llvm::LogicalResult mlirGen(ReturnExprAST &ret) {
371 auto location = loc(ret.loc());
372
373 // 'return' takes an optional expression, handle that case here.
374 mlir::Value expr = nullptr;
375 if (ret.getExpr().has_value()) {
376 if (!(expr = mlirGen(**ret.getExpr())))
377 return mlir::failure();
378 }
379
380 // Otherwise, this return operation has zero operands.
381 builder.create<ReturnOp>(location,
382 expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
383 return mlir::success();
384 }
385
386 /// Emit a constant for a literal/constant array. It will be emitted as a
387 /// flattened array of data in an Attribute attached to a `toy.constant`
388 /// operation. See documentation on [Attributes](LangRef.md#attributes) for
389 /// more details. Here is an excerpt:
390 ///
391 /// Attributes are the mechanism for specifying constant data in MLIR in
392 /// places where a variable is never allowed [...]. They consist of a name
393 /// and a concrete attribute value. The set of expected attributes, their
394 /// structure, and their interpretation are all contextually dependent on
395 /// what they are attached to.
396 ///
397 /// Example, the source level statement:
398 /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
399 /// will be converted to:
400 /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
401 /// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
402 /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
403 ///
getConstantAttr(LiteralExprAST & lit)404 mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
405 // The attribute is a vector with a floating point value per element
406 // (number) in the array, see `collectData()` below for more details.
407 std::vector<double> data;
408 data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
409 std::multiplies<int>()));
410 collectData(lit, data);
411
412 // The type of this attribute is tensor of 64-bit floating-point with the
413 // shape of the literal.
414 mlir::Type elementType = builder.getF64Type();
415 auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
416
417 // This is the actual attribute that holds the list of values for this
418 // tensor literal.
419 return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
420 }
getConstantAttr(NumberExprAST & lit)421 mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
422 // The type of this attribute is tensor of 64-bit floating-point with no
423 // shape.
424 mlir::Type elementType = builder.getF64Type();
425 auto dataType = mlir::RankedTensorType::get({}, elementType);
426
427 // This is the actual attribute that holds the list of values for this
428 // tensor literal.
429 return mlir::DenseElementsAttr::get(dataType,
430 llvm::ArrayRef(lit.getValue()));
431 }
432 /// Emit a constant for a struct literal. It will be emitted as an array of
433 /// other literals in an Attribute attached to a `toy.struct_constant`
434 /// operation. This function returns the generated constant, along with the
435 /// corresponding struct type.
436 std::pair<mlir::ArrayAttr, mlir::Type>
getConstantAttr(StructLiteralExprAST & lit)437 getConstantAttr(StructLiteralExprAST &lit) {
438 std::vector<mlir::Attribute> attrElements;
439 std::vector<mlir::Type> typeElements;
440
441 for (auto &var : lit.getValues()) {
442 if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
443 attrElements.push_back(getConstantAttr(*number));
444 typeElements.push_back(getType(std::nullopt));
445 } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
446 attrElements.push_back(getConstantAttr(*lit));
447 typeElements.push_back(getType(std::nullopt));
448 } else {
449 auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
450 auto attrTypePair = getConstantAttr(*structLit);
451 attrElements.push_back(attrTypePair.first);
452 typeElements.push_back(attrTypePair.second);
453 }
454 }
455 mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
456 mlir::Type dataType = StructType::get(typeElements);
457 return std::make_pair(dataAttr, dataType);
458 }
459
460 /// Emit an array literal.
mlirGen(LiteralExprAST & lit)461 mlir::Value mlirGen(LiteralExprAST &lit) {
462 mlir::Type type = getType(lit.getDims());
463 mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
464
465 // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
466 // method.
467 return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
468 }
469
470 /// Emit a struct literal. It will be emitted as an array of
471 /// other literals in an Attribute attached to a `toy.struct_constant`
472 /// operation.
mlirGen(StructLiteralExprAST & lit)473 mlir::Value mlirGen(StructLiteralExprAST &lit) {
474 mlir::ArrayAttr dataAttr;
475 mlir::Type dataType;
476 std::tie(dataAttr, dataType) = getConstantAttr(lit);
477
478 // Build the MLIR op `toy.struct_constant`. This invokes the
479 // `StructConstantOp::build` method.
480 return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr);
481 }
482
483 /// Recursive helper function to accumulate the data that compose an array
484 /// literal. It flattens the nested structure in the supplied vector. For
485 /// example with this array:
486 /// [[1, 2], [3, 4]]
487 /// we will generate:
488 /// [ 1, 2, 3, 4 ]
489 /// Individual numbers are represented as doubles.
490 /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)491 void collectData(ExprAST &expr, std::vector<double> &data) {
492 if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
493 for (auto &value : lit->getValues())
494 collectData(*value, data);
495 return;
496 }
497
498 assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
499 data.push_back(cast<NumberExprAST>(expr).getValue());
500 }
501
502 /// Emit a call expression. It emits specific operations for the `transpose`
503 /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)504 mlir::Value mlirGen(CallExprAST &call) {
505 llvm::StringRef callee = call.getCallee();
506 auto location = loc(call.loc());
507
508 // Codegen the operands first.
509 SmallVector<mlir::Value, 4> operands;
510 for (auto &expr : call.getArgs()) {
511 auto arg = mlirGen(*expr);
512 if (!arg)
513 return nullptr;
514 operands.push_back(arg);
515 }
516
517 // Builtin calls have their custom operation, meaning this is a
518 // straightforward emission.
519 if (callee == "transpose") {
520 if (call.getArgs().size() != 1) {
521 emitError(location, "MLIR codegen encountered an error: toy.transpose "
522 "does not accept multiple arguments");
523 return nullptr;
524 }
525 return builder.create<TransposeOp>(location, operands[0]);
526 }
527
528 // Otherwise this is a call to a user-defined function. Calls to
529 // user-defined functions are mapped to a custom call that takes the callee
530 // name as an attribute.
531 auto calledFuncIt = functionMap.find(callee);
532 if (calledFuncIt == functionMap.end()) {
533 emitError(location) << "no defined function found for '" << callee << "'";
534 return nullptr;
535 }
536 mlir::toy::FuncOp calledFunc = calledFuncIt->second;
537 return builder.create<GenericCallOp>(
538 location, calledFunc.getFunctionType().getResult(0),
539 mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
540 }
541
542 /// Emit a print expression. It emits specific operations for two builtins:
543 /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)544 llvm::LogicalResult mlirGen(PrintExprAST &call) {
545 auto arg = mlirGen(*call.getArg());
546 if (!arg)
547 return mlir::failure();
548
549 builder.create<PrintOp>(loc(call.loc()), arg);
550 return mlir::success();
551 }
552
553 /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)554 mlir::Value mlirGen(NumberExprAST &num) {
555 return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
556 }
557
558 /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)559 mlir::Value mlirGen(ExprAST &expr) {
560 switch (expr.getKind()) {
561 case toy::ExprAST::Expr_BinOp:
562 return mlirGen(cast<BinaryExprAST>(expr));
563 case toy::ExprAST::Expr_Var:
564 return mlirGen(cast<VariableExprAST>(expr));
565 case toy::ExprAST::Expr_Literal:
566 return mlirGen(cast<LiteralExprAST>(expr));
567 case toy::ExprAST::Expr_StructLiteral:
568 return mlirGen(cast<StructLiteralExprAST>(expr));
569 case toy::ExprAST::Expr_Call:
570 return mlirGen(cast<CallExprAST>(expr));
571 case toy::ExprAST::Expr_Num:
572 return mlirGen(cast<NumberExprAST>(expr));
573 default:
574 emitError(loc(expr.loc()))
575 << "MLIR codegen encountered an unhandled expr kind '"
576 << Twine(expr.getKind()) << "'";
577 return nullptr;
578 }
579 }
580
581 /// Handle a variable declaration, we'll codegen the expression that forms the
582 /// initializer and record the value in the symbol table before returning it.
583 /// Future expressions will be able to reference this variable through symbol
584 /// table lookup.
mlirGen(VarDeclExprAST & vardecl)585 mlir::Value mlirGen(VarDeclExprAST &vardecl) {
586 auto *init = vardecl.getInitVal();
587 if (!init) {
588 emitError(loc(vardecl.loc()),
589 "missing initializer in variable declaration");
590 return nullptr;
591 }
592
593 mlir::Value value = mlirGen(*init);
594 if (!value)
595 return nullptr;
596
597 // Handle the case where we are initializing a struct value.
598 VarType varType = vardecl.getType();
599 if (!varType.name.empty()) {
600 // Check that the initializer type is the same as the variable
601 // declaration.
602 mlir::Type type = getType(varType, vardecl.loc());
603 if (!type)
604 return nullptr;
605 if (type != value.getType()) {
606 emitError(loc(vardecl.loc()))
607 << "struct type of initializer is different than the variable "
608 "declaration. Got "
609 << value.getType() << ", but expected " << type;
610 return nullptr;
611 }
612
613 // Otherwise, we have the initializer value, but in case the variable was
614 // declared with specific shape, we emit a "reshape" operation. It will
615 // get optimized out later as needed.
616 } else if (!varType.shape.empty()) {
617 value = builder.create<ReshapeOp>(loc(vardecl.loc()),
618 getType(varType.shape), value);
619 }
620
621 // Register the value in the symbol table.
622 if (failed(declare(vardecl, value)))
623 return nullptr;
624 return value;
625 }
626
627 /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)628 llvm::LogicalResult mlirGen(ExprASTList &blockAST) {
629 SymbolTableScopeT varScope(symbolTable);
630 for (auto &expr : blockAST) {
631 // Specific handling for variable declarations, return statement, and
632 // print. These can only appear in block list and not in nested
633 // expressions.
634 if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
635 if (!mlirGen(*vardecl))
636 return mlir::failure();
637 continue;
638 }
639 if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
640 return mlirGen(*ret);
641 if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
642 if (mlir::failed(mlirGen(*print)))
643 return mlir::success();
644 continue;
645 }
646
647 // Generic expression dispatch codegen.
648 if (!mlirGen(*expr))
649 return mlir::failure();
650 }
651 return mlir::success();
652 }
653
654 /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)655 mlir::Type getType(ArrayRef<int64_t> shape) {
656 // If the shape is empty, then this type is unranked.
657 if (shape.empty())
658 return mlir::UnrankedTensorType::get(builder.getF64Type());
659
660 // Otherwise, we use the given shape.
661 return mlir::RankedTensorType::get(shape, builder.getF64Type());
662 }
663
664 /// Build an MLIR type from a Toy AST variable type (forward to the generic
665 /// getType above for non-struct types).
getType(const VarType & type,const Location & location)666 mlir::Type getType(const VarType &type, const Location &location) {
667 if (!type.name.empty()) {
668 auto it = structMap.find(type.name);
669 if (it == structMap.end()) {
670 emitError(loc(location))
671 << "error: unknown struct type '" << type.name << "'";
672 return nullptr;
673 }
674 return it->second.first;
675 }
676
677 return getType(type.shape);
678 }
679 };
680
681 } // namespace
682
683 namespace toy {
684
685 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)686 mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
687 ModuleAST &moduleAST) {
688 return MLIRGenImpl(context).mlirGen(moduleAST);
689 }
690
691 } // namespace toy
692