1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/MLIRGen.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Value.h"
18 #include "toy/AST.h"
19 #include "toy/Dialect.h"
20
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/Verifier.h"
26 #include "toy/Lexer.h"
27
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/ScopedHashTable.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/Twine.h"
33 #include <cassert>
34 #include <cstdint>
35 #include <functional>
36 #include <numeric>
37 #include <optional>
38 #include <vector>
39
40 using namespace mlir::toy;
41 using namespace toy;
42
43 using llvm::ArrayRef;
44 using llvm::cast;
45 using llvm::dyn_cast;
46 using llvm::isa;
47 using llvm::ScopedHashTableScope;
48 using llvm::SmallVector;
49 using llvm::StringRef;
50 using llvm::Twine;
51
52 namespace {
53
54 /// Implementation of a simple MLIR emission from the Toy AST.
55 ///
56 /// This will emit operations that are specific to the Toy language, preserving
57 /// the semantics of the language and (hopefully) allow to perform accurate
58 /// analysis and transformation based on these high level semantics.
59 class MLIRGenImpl {
60 public:
MLIRGenImpl(mlir::MLIRContext & context)61 MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
62
63 /// Public API: convert the AST for a Toy module (source file) to an MLIR
64 /// Module operation.
mlirGen(ModuleAST & moduleAST)65 mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
66 // We create an empty MLIR module and codegen functions one at a time and
67 // add them to the module.
68 theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
69
70 for (FunctionAST &f : moduleAST)
71 mlirGen(f);
72
73 // Verify the module after we have finished constructing it, this will check
74 // the structural properties of the IR and invoke any specific verifiers we
75 // have on the Toy operations.
76 if (failed(mlir::verify(theModule))) {
77 theModule.emitError("module verification error");
78 return nullptr;
79 }
80
81 return theModule;
82 }
83
84 private:
85 /// A "module" matches a Toy source file: containing a list of functions.
86 mlir::ModuleOp theModule;
87
88 /// The builder is a helper class to create IR inside a function. The builder
89 /// is stateful, in particular it keeps an "insertion point": this is where
90 /// the next operations will be introduced.
91 mlir::OpBuilder builder;
92
93 /// The symbol table maps a variable name to a value in the current scope.
94 /// Entering a function creates a new scope, and the function arguments are
95 /// added to the mapping. When the processing of a function is terminated, the
96 /// scope is destroyed and the mappings created in this scope are dropped.
97 llvm::ScopedHashTable<StringRef, mlir::Value> symbolTable;
98
99 /// Helper conversion for a Toy AST location to an MLIR location.
loc(const Location & loc)100 mlir::Location loc(const Location &loc) {
101 return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
102 loc.col);
103 }
104
105 /// Declare a variable in the current scope, return success if the variable
106 /// wasn't declared yet.
declare(llvm::StringRef var,mlir::Value value)107 llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
108 if (symbolTable.count(var))
109 return mlir::failure();
110 symbolTable.insert(var, value);
111 return mlir::success();
112 }
113
114 /// Create the prototype for an MLIR function with as many arguments as the
115 /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)116 mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
117 auto location = loc(proto.loc());
118
119 // This is a generic function, the return type will be inferred later.
120 // Arguments type are uniformly unranked tensors.
121 llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
122 getType(VarType{}));
123 auto funcType = builder.getFunctionType(argTypes, std::nullopt);
124 return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
125 funcType);
126 }
127
128 /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)129 mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
130 // Create a scope in the symbol table to hold variable declarations.
131 ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
132
133 // Create an MLIR function for the given prototype.
134 builder.setInsertionPointToEnd(theModule.getBody());
135 mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
136 if (!function)
137 return nullptr;
138
139 // Let's start the body of the function now!
140 mlir::Block &entryBlock = function.front();
141 auto protoArgs = funcAST.getProto()->getArgs();
142
143 // Declare all the function arguments in the symbol table.
144 for (const auto nameValue :
145 llvm::zip(protoArgs, entryBlock.getArguments())) {
146 if (failed(declare(std::get<0>(nameValue)->getName(),
147 std::get<1>(nameValue))))
148 return nullptr;
149 }
150
151 // Set the insertion point in the builder to the beginning of the function
152 // body, it will be used throughout the codegen to create operations in this
153 // function.
154 builder.setInsertionPointToStart(&entryBlock);
155
156 // Emit the body of the function.
157 if (mlir::failed(mlirGen(*funcAST.getBody()))) {
158 function.erase();
159 return nullptr;
160 }
161
162 // Implicitly return void if no return statement was emitted.
163 // FIXME: we may fix the parser instead to always return the last expression
164 // (this would possibly help the REPL case later)
165 ReturnOp returnOp;
166 if (!entryBlock.empty())
167 returnOp = dyn_cast<ReturnOp>(entryBlock.back());
168 if (!returnOp) {
169 builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
170 } else if (returnOp.hasOperand()) {
171 // Otherwise, if this return operation has an operand then add a result to
172 // the function.
173 function.setType(builder.getFunctionType(
174 function.getFunctionType().getInputs(), getType(VarType{})));
175 }
176
177 // If this function isn't main, then set the visibility to private.
178 if (funcAST.getProto()->getName() != "main")
179 function.setPrivate();
180
181 return function;
182 }
183
184 /// Emit a binary operation
mlirGen(BinaryExprAST & binop)185 mlir::Value mlirGen(BinaryExprAST &binop) {
186 // First emit the operations for each side of the operation before emitting
187 // the operation itself. For example if the expression is `a + foo(a)`
188 // 1) First it will visiting the LHS, which will return a reference to the
189 // value holding `a`. This value should have been emitted at declaration
190 // time and registered in the symbol table, so nothing would be
191 // codegen'd. If the value is not in the symbol table, an error has been
192 // emitted and nullptr is returned.
193 // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
194 // and the result value is returned. If an error occurs we get a nullptr
195 // and propagate.
196 //
197 mlir::Value lhs = mlirGen(*binop.getLHS());
198 if (!lhs)
199 return nullptr;
200 mlir::Value rhs = mlirGen(*binop.getRHS());
201 if (!rhs)
202 return nullptr;
203 auto location = loc(binop.loc());
204
205 // Derive the operation name from the binary operator. At the moment we only
206 // support '+' and '*'.
207 switch (binop.getOp()) {
208 case '+':
209 return builder.create<AddOp>(location, lhs, rhs);
210 case '*':
211 return builder.create<MulOp>(location, lhs, rhs);
212 }
213
214 emitError(location, "invalid binary operator '") << binop.getOp() << "'";
215 return nullptr;
216 }
217
218 /// This is a reference to a variable in an expression. The variable is
219 /// expected to have been declared and so should have a value in the symbol
220 /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)221 mlir::Value mlirGen(VariableExprAST &expr) {
222 if (auto variable = symbolTable.lookup(expr.getName()))
223 return variable;
224
225 emitError(loc(expr.loc()), "error: unknown variable '")
226 << expr.getName() << "'";
227 return nullptr;
228 }
229
230 /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)231 llvm::LogicalResult mlirGen(ReturnExprAST &ret) {
232 auto location = loc(ret.loc());
233
234 // 'return' takes an optional expression, handle that case here.
235 mlir::Value expr = nullptr;
236 if (ret.getExpr().has_value()) {
237 if (!(expr = mlirGen(**ret.getExpr())))
238 return mlir::failure();
239 }
240
241 // Otherwise, this return operation has zero operands.
242 builder.create<ReturnOp>(location,
243 expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
244 return mlir::success();
245 }
246
247 /// Emit a literal/constant array. It will be emitted as a flattened array of
248 /// data in an Attribute attached to a `toy.constant` operation.
249 /// See documentation on [Attributes](LangRef.md#attributes) for more details.
250 /// Here is an excerpt:
251 ///
252 /// Attributes are the mechanism for specifying constant data in MLIR in
253 /// places where a variable is never allowed [...]. They consist of a name
254 /// and a concrete attribute value. The set of expected attributes, their
255 /// structure, and their interpretation are all contextually dependent on
256 /// what they are attached to.
257 ///
258 /// Example, the source level statement:
259 /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
260 /// will be converted to:
261 /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
262 /// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
263 /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
264 ///
mlirGen(LiteralExprAST & lit)265 mlir::Value mlirGen(LiteralExprAST &lit) {
266 auto type = getType(lit.getDims());
267
268 // The attribute is a vector with a floating point value per element
269 // (number) in the array, see `collectData()` below for more details.
270 std::vector<double> data;
271 data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
272 std::multiplies<int>()));
273 collectData(lit, data);
274
275 // The type of this attribute is tensor of 64-bit floating-point with the
276 // shape of the literal.
277 mlir::Type elementType = builder.getF64Type();
278 auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
279
280 // This is the actual attribute that holds the list of values for this
281 // tensor literal.
282 auto dataAttribute =
283 mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
284
285 // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
286 // method.
287 return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
288 }
289
290 /// Recursive helper function to accumulate the data that compose an array
291 /// literal. It flattens the nested structure in the supplied vector. For
292 /// example with this array:
293 /// [[1, 2], [3, 4]]
294 /// we will generate:
295 /// [ 1, 2, 3, 4 ]
296 /// Individual numbers are represented as doubles.
297 /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)298 void collectData(ExprAST &expr, std::vector<double> &data) {
299 if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
300 for (auto &value : lit->getValues())
301 collectData(*value, data);
302 return;
303 }
304
305 assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
306 data.push_back(cast<NumberExprAST>(expr).getValue());
307 }
308
309 /// Emit a call expression. It emits specific operations for the `transpose`
310 /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)311 mlir::Value mlirGen(CallExprAST &call) {
312 llvm::StringRef callee = call.getCallee();
313 auto location = loc(call.loc());
314
315 // Codegen the operands first.
316 SmallVector<mlir::Value, 4> operands;
317 for (auto &expr : call.getArgs()) {
318 auto arg = mlirGen(*expr);
319 if (!arg)
320 return nullptr;
321 operands.push_back(arg);
322 }
323
324 // Builtin calls have their custom operation, meaning this is a
325 // straightforward emission.
326 if (callee == "transpose") {
327 if (call.getArgs().size() != 1) {
328 emitError(location, "MLIR codegen encountered an error: toy.transpose "
329 "does not accept multiple arguments");
330 return nullptr;
331 }
332 return builder.create<TransposeOp>(location, operands[0]);
333 }
334
335 // Otherwise this is a call to a user-defined function. Calls to
336 // user-defined functions are mapped to a custom call that takes the callee
337 // name as an attribute.
338 return builder.create<GenericCallOp>(location, callee, operands);
339 }
340
341 /// Emit a print expression. It emits specific operations for two builtins:
342 /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)343 llvm::LogicalResult mlirGen(PrintExprAST &call) {
344 auto arg = mlirGen(*call.getArg());
345 if (!arg)
346 return mlir::failure();
347
348 builder.create<PrintOp>(loc(call.loc()), arg);
349 return mlir::success();
350 }
351
352 /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)353 mlir::Value mlirGen(NumberExprAST &num) {
354 return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
355 }
356
357 /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)358 mlir::Value mlirGen(ExprAST &expr) {
359 switch (expr.getKind()) {
360 case toy::ExprAST::Expr_BinOp:
361 return mlirGen(cast<BinaryExprAST>(expr));
362 case toy::ExprAST::Expr_Var:
363 return mlirGen(cast<VariableExprAST>(expr));
364 case toy::ExprAST::Expr_Literal:
365 return mlirGen(cast<LiteralExprAST>(expr));
366 case toy::ExprAST::Expr_Call:
367 return mlirGen(cast<CallExprAST>(expr));
368 case toy::ExprAST::Expr_Num:
369 return mlirGen(cast<NumberExprAST>(expr));
370 default:
371 emitError(loc(expr.loc()))
372 << "MLIR codegen encountered an unhandled expr kind '"
373 << Twine(expr.getKind()) << "'";
374 return nullptr;
375 }
376 }
377
378 /// Handle a variable declaration, we'll codegen the expression that forms the
379 /// initializer and record the value in the symbol table before returning it.
380 /// Future expressions will be able to reference this variable through symbol
381 /// table lookup.
mlirGen(VarDeclExprAST & vardecl)382 mlir::Value mlirGen(VarDeclExprAST &vardecl) {
383 auto *init = vardecl.getInitVal();
384 if (!init) {
385 emitError(loc(vardecl.loc()),
386 "missing initializer in variable declaration");
387 return nullptr;
388 }
389
390 mlir::Value value = mlirGen(*init);
391 if (!value)
392 return nullptr;
393
394 // We have the initializer value, but in case the variable was declared
395 // with specific shape, we emit a "reshape" operation. It will get
396 // optimized out later as needed.
397 if (!vardecl.getType().shape.empty()) {
398 value = builder.create<ReshapeOp>(loc(vardecl.loc()),
399 getType(vardecl.getType()), value);
400 }
401
402 // Register the value in the symbol table.
403 if (failed(declare(vardecl.getName(), value)))
404 return nullptr;
405 return value;
406 }
407
408 /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)409 llvm::LogicalResult mlirGen(ExprASTList &blockAST) {
410 ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
411 for (auto &expr : blockAST) {
412 // Specific handling for variable declarations, return statement, and
413 // print. These can only appear in block list and not in nested
414 // expressions.
415 if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
416 if (!mlirGen(*vardecl))
417 return mlir::failure();
418 continue;
419 }
420 if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
421 return mlirGen(*ret);
422 if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
423 if (mlir::failed(mlirGen(*print)))
424 return mlir::success();
425 continue;
426 }
427
428 // Generic expression dispatch codegen.
429 if (!mlirGen(*expr))
430 return mlir::failure();
431 }
432 return mlir::success();
433 }
434
435 /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)436 mlir::Type getType(ArrayRef<int64_t> shape) {
437 // If the shape is empty, then this type is unranked.
438 if (shape.empty())
439 return mlir::UnrankedTensorType::get(builder.getF64Type());
440
441 // Otherwise, we use the given shape.
442 return mlir::RankedTensorType::get(shape, builder.getF64Type());
443 }
444
445 /// Build an MLIR type from a Toy AST variable type (forward to the generic
446 /// getType above).
getType(const VarType & type)447 mlir::Type getType(const VarType &type) { return getType(type.shape); }
448 };
449
450 } // namespace
451
452 namespace toy {
453
454 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)455 mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
456 ModuleAST &moduleAST) {
457 return MLIRGenImpl(context).mlirGen(moduleAST);
458 }
459
460 } // namespace toy
461