xref: /llvm-project/mlir/examples/toy/Ch7/mlir/Dialect.cpp (revision 2655ae54db6d7e9276a5ef4208cbeff1ae2ee72c)
1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/Dialect.h"
15 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/IR/TypeSupport.h"
26 #include "mlir/IR/ValueRange.h"
27 #include "mlir/Interfaces/CallInterfaces.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/InliningUtils.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/Hashing.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/StringRef.h"
35 #include "llvm/Support/Casting.h"
36 #include <algorithm>
37 #include <cassert>
38 #include <cstddef>
39 #include <cstdint>
40 #include <string>
41 
42 using namespace mlir;
43 using namespace mlir::toy;
44 
45 #include "toy/Dialect.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // ToyInlinerInterface
49 //===----------------------------------------------------------------------===//
50 
51 /// This class defines the interface for handling inlining with Toy
52 /// operations.
53 struct ToyInlinerInterface : public DialectInlinerInterface {
54   using DialectInlinerInterface::DialectInlinerInterface;
55 
56   //===--------------------------------------------------------------------===//
57   // Analysis Hooks
58   //===--------------------------------------------------------------------===//
59 
60   /// All call operations within toy can be inlined.
61   bool isLegalToInline(Operation *call, Operation *callable,
62                        bool wouldBeCloned) const final {
63     return true;
64   }
65 
66   /// All operations within toy can be inlined.
67   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
68     return true;
69   }
70 
71   // All functions within toy can be inlined.
72   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
73     return true;
74   }
75 
76   //===--------------------------------------------------------------------===//
77   // Transformation Hooks
78   //===--------------------------------------------------------------------===//
79 
80   /// Handle the given inlined terminator(toy.return) by replacing it with a new
81   /// operation as necessary.
82   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
83     // Only "toy.return" needs to be handled here.
84     auto returnOp = cast<ReturnOp>(op);
85 
86     // Replace the values directly with the return operands.
87     assert(returnOp.getNumOperands() == valuesToRepl.size());
88     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
89       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
90   }
91 
92   /// Attempts to materialize a conversion for a type mismatch between a call
93   /// from this dialect, and a callable region. This method should generate an
94   /// operation that takes 'input' as the only operand, and produces a single
95   /// result of 'resultType'. If a conversion can not be generated, nullptr
96   /// should be returned.
97   Operation *materializeCallConversion(OpBuilder &builder, Value input,
98                                        Type resultType,
99                                        Location conversionLoc) const final {
100     return builder.create<CastOp>(conversionLoc, resultType, input);
101   }
102 };
103 
104 //===----------------------------------------------------------------------===//
105 // Toy Operations
106 //===----------------------------------------------------------------------===//
107 
108 /// A generalized parser for binary operations. This parses the different forms
109 /// of 'printBinaryOp' below.
110 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
111                                        mlir::OperationState &result) {
112   SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
113   SMLoc operandsLoc = parser.getCurrentLocation();
114   Type type;
115   if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
116       parser.parseOptionalAttrDict(result.attributes) ||
117       parser.parseColonType(type))
118     return mlir::failure();
119 
120   // If the type is a function type, it contains the input and result types of
121   // this operation.
122   if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
123     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
124                                result.operands))
125       return mlir::failure();
126     result.addTypes(funcType.getResults());
127     return mlir::success();
128   }
129 
130   // Otherwise, the parsed type is the type of both operands and results.
131   if (parser.resolveOperands(operands, type, result.operands))
132     return mlir::failure();
133   result.addTypes(type);
134   return mlir::success();
135 }
136 
137 /// A generalized printer for binary operations. It prints in two different
138 /// forms depending on if all of the types match.
139 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
140   printer << " " << op->getOperands();
141   printer.printOptionalAttrDict(op->getAttrs());
142   printer << " : ";
143 
144   // If all of the types are the same, print the type directly.
145   Type resultType = *op->result_type_begin();
146   if (llvm::all_of(op->getOperandTypes(),
147                    [=](Type type) { return type == resultType; })) {
148     printer << resultType;
149     return;
150   }
151 
152   // Otherwise, print a functional type.
153   printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // ConstantOp
158 //===----------------------------------------------------------------------===//
159 
160 /// Build a constant operation.
161 /// The builder is passed as an argument, so is the state that this method is
162 /// expected to fill in order to build the operation.
163 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
164                        double value) {
165   auto dataType = RankedTensorType::get({}, builder.getF64Type());
166   auto dataAttribute = DenseElementsAttr::get(dataType, value);
167   ConstantOp::build(builder, state, dataType, dataAttribute);
168 }
169 
170 /// The 'OpAsmParser' class provides a collection of methods for parsing
171 /// various punctuation, as well as attributes, operands, types, etc. Each of
172 /// these methods returns a `ParseResult`. This class is a wrapper around
173 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
174 /// or `false` on success. This allows for easily chaining together a set of
175 /// parser rules. These rules are used to populate an `mlir::OperationState`
176 /// similarly to the `build` methods described above.
177 mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
178                                     mlir::OperationState &result) {
179   mlir::DenseElementsAttr value;
180   if (parser.parseOptionalAttrDict(result.attributes) ||
181       parser.parseAttribute(value, "value", result.attributes))
182     return failure();
183 
184   result.addTypes(value.getType());
185   return success();
186 }
187 
188 /// The 'OpAsmPrinter' class is a stream that allows for formatting
189 /// strings, attributes, operands, types, etc.
190 void ConstantOp::print(mlir::OpAsmPrinter &printer) {
191   printer << " ";
192   printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
193   printer << getValue();
194 }
195 
196 /// Verify that the given attribute value is valid for the given type.
197 static llvm::LogicalResult verifyConstantForType(mlir::Type type,
198                                                  mlir::Attribute opaqueValue,
199                                                  mlir::Operation *op) {
200   if (llvm::isa<mlir::TensorType>(type)) {
201     // Check that the value is an elements attribute.
202     auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
203     if (!attrValue)
204       return op->emitError("constant of TensorType must be initialized by "
205                            "a DenseFPElementsAttr, got ")
206              << opaqueValue;
207 
208     // If the return type of the constant is not an unranked tensor, the shape
209     // must match the shape of the attribute holding the data.
210     auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
211     if (!resultType)
212       return success();
213 
214     // Check that the rank of the attribute type matches the rank of the
215     // constant result type.
216     auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
217     if (attrType.getRank() != resultType.getRank()) {
218       return op->emitOpError("return type must match the one of the attached "
219                              "value attribute: ")
220              << attrType.getRank() << " != " << resultType.getRank();
221     }
222 
223     // Check that each of the dimensions match between the two types.
224     for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
225       if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
226         return op->emitOpError(
227                    "return type shape mismatches its attribute at dimension ")
228                << dim << ": " << attrType.getShape()[dim]
229                << " != " << resultType.getShape()[dim];
230       }
231     }
232     return mlir::success();
233   }
234   auto resultType = llvm::cast<StructType>(type);
235   llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
236 
237   // Verify that the initializer is an Array.
238   auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
239   if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
240     return op->emitError("constant of StructType must be initialized by an "
241                          "ArrayAttr with the same number of elements, got ")
242            << opaqueValue;
243 
244   // Check that each of the elements are valid.
245   llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
246   for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
247     if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
248       return mlir::failure();
249   return mlir::success();
250 }
251 
252 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
253 /// in the op definition.
254 llvm::LogicalResult ConstantOp::verify() {
255   return verifyConstantForType(getResult().getType(), getValue(), *this);
256 }
257 
258 llvm::LogicalResult StructConstantOp::verify() {
259   return verifyConstantForType(getResult().getType(), getValue(), *this);
260 }
261 
262 /// Infer the output shape of the ConstantOp, this is required by the shape
263 /// inference interface.
264 void ConstantOp::inferShapes() {
265   getResult().setType(cast<TensorType>(getValue().getType()));
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // AddOp
270 //===----------------------------------------------------------------------===//
271 
272 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
273                   mlir::Value lhs, mlir::Value rhs) {
274   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
275   state.addOperands({lhs, rhs});
276 }
277 
278 mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
279                                mlir::OperationState &result) {
280   return parseBinaryOp(parser, result);
281 }
282 
283 void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
284 
285 /// Infer the output shape of the AddOp, this is required by the shape inference
286 /// interface.
287 void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
288 
289 //===----------------------------------------------------------------------===//
290 // CastOp
291 //===----------------------------------------------------------------------===//
292 
293 /// Infer the output shape of the CastOp, this is required by the shape
294 /// inference interface.
295 void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
296 
297 /// Returns true if the given set of input and result types are compatible with
298 /// this cast operation. This is required by the `CastOpInterface` to verify
299 /// this operation and provide other additional utilities.
300 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
301   if (inputs.size() != 1 || outputs.size() != 1)
302     return false;
303   // The inputs must be Tensors with the same element type.
304   TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
305   TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
306   if (!input || !output || input.getElementType() != output.getElementType())
307     return false;
308   // The shape is required to match if both types are ranked.
309   return !input.hasRank() || !output.hasRank() || input == output;
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // FuncOp
314 //===----------------------------------------------------------------------===//
315 
316 void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
317                    llvm::StringRef name, mlir::FunctionType type,
318                    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
319   // FunctionOpInterface provides a convenient `build` method that will populate
320   // the state of our FuncOp, and create an entry block.
321   buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
322 }
323 
324 mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
325                                 mlir::OperationState &result) {
326   // Dispatch to the FunctionOpInterface provided utility method that parses the
327   // function operation.
328   auto buildFuncType =
329       [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
330          llvm::ArrayRef<mlir::Type> results,
331          mlir::function_interface_impl::VariadicFlag,
332          std::string &) { return builder.getFunctionType(argTypes, results); };
333 
334   return mlir::function_interface_impl::parseFunctionOp(
335       parser, result, /*allowVariadic=*/false,
336       getFunctionTypeAttrName(result.name), buildFuncType,
337       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
338 }
339 
340 void FuncOp::print(mlir::OpAsmPrinter &p) {
341   // Dispatch to the FunctionOpInterface provided utility method that prints the
342   // function operation.
343   mlir::function_interface_impl::printFunctionOp(
344       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
345       getArgAttrsAttrName(), getResAttrsAttrName());
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // GenericCallOp
350 //===----------------------------------------------------------------------===//
351 
352 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
353                           StringRef callee, ArrayRef<mlir::Value> arguments) {
354   // Generic call always returns an unranked Tensor initially.
355   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
356   state.addOperands(arguments);
357   state.addAttribute("callee",
358                      mlir::SymbolRefAttr::get(builder.getContext(), callee));
359 }
360 
361 /// Return the callee of the generic call operation, this is required by the
362 /// call interface.
363 CallInterfaceCallable GenericCallOp::getCallableForCallee() {
364   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
365 }
366 
367 /// Set the callee for the generic call operation, this is required by the call
368 /// interface.
369 void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
370   (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
371 }
372 
373 /// Get the argument operands to the called function, this is required by the
374 /// call interface.
375 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
376 
377 /// Get the argument operands to the called function as a mutable range, this is
378 /// required by the call interface.
379 MutableOperandRange GenericCallOp::getArgOperandsMutable() {
380   return getInputsMutable();
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // MulOp
385 //===----------------------------------------------------------------------===//
386 
387 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
388                   mlir::Value lhs, mlir::Value rhs) {
389   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
390   state.addOperands({lhs, rhs});
391 }
392 
393 mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
394                                mlir::OperationState &result) {
395   return parseBinaryOp(parser, result);
396 }
397 
398 void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
399 
400 /// Infer the output shape of the MulOp, this is required by the shape inference
401 /// interface.
402 void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
403 
404 //===----------------------------------------------------------------------===//
405 // ReturnOp
406 //===----------------------------------------------------------------------===//
407 
408 llvm::LogicalResult ReturnOp::verify() {
409   // We know that the parent operation is a function, because of the 'HasParent'
410   // trait attached to the operation definition.
411   auto function = cast<FuncOp>((*this)->getParentOp());
412 
413   /// ReturnOps can only have a single optional operand.
414   if (getNumOperands() > 1)
415     return emitOpError() << "expects at most 1 return operand";
416 
417   // The operand number and types must match the function signature.
418   const auto &results = function.getFunctionType().getResults();
419   if (getNumOperands() != results.size())
420     return emitOpError() << "does not return the same number of values ("
421                          << getNumOperands() << ") as the enclosing function ("
422                          << results.size() << ")";
423 
424   // If the operation does not have an input, we are done.
425   if (!hasOperand())
426     return mlir::success();
427 
428   auto inputType = *operand_type_begin();
429   auto resultType = results.front();
430 
431   // Check that the result type of the function matches the operand type.
432   if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
433       llvm::isa<mlir::UnrankedTensorType>(resultType))
434     return mlir::success();
435 
436   return emitError() << "type of return operand (" << inputType
437                      << ") doesn't match function result type (" << resultType
438                      << ")";
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // StructAccessOp
443 //===----------------------------------------------------------------------===//
444 
445 void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
446                            mlir::Value input, size_t index) {
447   // Extract the result type from the input type.
448   StructType structTy = llvm::cast<StructType>(input.getType());
449   assert(index < structTy.getNumElementTypes());
450   mlir::Type resultType = structTy.getElementTypes()[index];
451 
452   // Call into the auto-generated build method.
453   build(b, state, resultType, input, b.getI64IntegerAttr(index));
454 }
455 
456 llvm::LogicalResult StructAccessOp::verify() {
457   StructType structTy = llvm::cast<StructType>(getInput().getType());
458   size_t indexValue = getIndex();
459   if (indexValue >= structTy.getNumElementTypes())
460     return emitOpError()
461            << "index should be within the range of the input struct type";
462   mlir::Type resultType = getResult().getType();
463   if (resultType != structTy.getElementTypes()[indexValue])
464     return emitOpError() << "must have the same result type as the struct "
465                             "element referred to by the index";
466   return mlir::success();
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // TransposeOp
471 //===----------------------------------------------------------------------===//
472 
473 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
474                         mlir::Value value) {
475   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
476   state.addOperands(value);
477 }
478 
479 void TransposeOp::inferShapes() {
480   auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
481   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
482   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
483 }
484 
485 llvm::LogicalResult TransposeOp::verify() {
486   auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
487   auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
488   if (!inputType || !resultType)
489     return mlir::success();
490 
491   auto inputShape = inputType.getShape();
492   if (!std::equal(inputShape.begin(), inputShape.end(),
493                   resultType.getShape().rbegin())) {
494     return emitError()
495            << "expected result shape to be a transpose of the input";
496   }
497   return mlir::success();
498 }
499 
500 //===----------------------------------------------------------------------===//
501 // Toy Types
502 //===----------------------------------------------------------------------===//
503 
504 namespace mlir {
505 namespace toy {
506 namespace detail {
507 /// This class represents the internal storage of the Toy `StructType`.
508 struct StructTypeStorage : public mlir::TypeStorage {
509   /// The `KeyTy` is a required type that provides an interface for the storage
510   /// instance. This type will be used when uniquing an instance of the type
511   /// storage. For our struct type, we will unique each instance structurally on
512   /// the elements that it contains.
513   using KeyTy = llvm::ArrayRef<mlir::Type>;
514 
515   /// A constructor for the type storage instance.
516   StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
517       : elementTypes(elementTypes) {}
518 
519   /// Define the comparison function for the key type with the current storage
520   /// instance. This is used when constructing a new instance to ensure that we
521   /// haven't already uniqued an instance of the given key.
522   bool operator==(const KeyTy &key) const { return key == elementTypes; }
523 
524   /// Define a hash function for the key type. This is used when uniquing
525   /// instances of the storage, see the `StructType::get` method.
526   /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
527   /// have hash functions available, so we could just omit this entirely.
528   static llvm::hash_code hashKey(const KeyTy &key) {
529     return llvm::hash_value(key);
530   }
531 
532   /// Define a construction function for the key type from a set of parameters.
533   /// These parameters will be provided when constructing the storage instance
534   /// itself.
535   /// Note: This method isn't necessary because KeyTy can be directly
536   /// constructed with the given parameters.
537   static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
538     return KeyTy(elementTypes);
539   }
540 
541   /// Define a construction method for creating a new instance of this storage.
542   /// This method takes an instance of a storage allocator, and an instance of a
543   /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
544   /// allocations used to create the type storage and its internal.
545   static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
546                                       const KeyTy &key) {
547     // Copy the elements from the provided `KeyTy` into the allocator.
548     llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
549 
550     // Allocate the storage instance and construct it.
551     return new (allocator.allocate<StructTypeStorage>())
552         StructTypeStorage(elementTypes);
553   }
554 
555   /// The following field contains the element types of the struct.
556   llvm::ArrayRef<mlir::Type> elementTypes;
557 };
558 } // namespace detail
559 } // namespace toy
560 } // namespace mlir
561 
562 /// Create an instance of a `StructType` with the given element types. There
563 /// *must* be at least one element type.
564 StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
565   assert(!elementTypes.empty() && "expected at least 1 element type");
566 
567   // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
568   // of this type. The first parameter is the context to unique in. The
569   // parameters after the context are forwarded to the storage instance.
570   mlir::MLIRContext *ctx = elementTypes.front().getContext();
571   return Base::get(ctx, elementTypes);
572 }
573 
574 /// Returns the element types of this struct type.
575 llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
576   // 'getImpl' returns a pointer to the internal storage instance.
577   return getImpl()->elementTypes;
578 }
579 
580 /// Parse an instance of a type registered to the toy dialect.
581 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
582   // Parse a struct type in the following form:
583   //   struct-type ::= `struct` `<` type (`,` type)* `>`
584 
585   // NOTE: All MLIR parser function return a ParseResult. This is a
586   // specialization of LogicalResult that auto-converts to a `true` boolean
587   // value on failure to allow for chaining, but may be used with explicit
588   // `mlir::failed/mlir::succeeded` as desired.
589 
590   // Parse: `struct` `<`
591   if (parser.parseKeyword("struct") || parser.parseLess())
592     return Type();
593 
594   // Parse the element types of the struct.
595   SmallVector<mlir::Type, 1> elementTypes;
596   do {
597     // Parse the current element type.
598     SMLoc typeLoc = parser.getCurrentLocation();
599     mlir::Type elementType;
600     if (parser.parseType(elementType))
601       return nullptr;
602 
603     // Check that the type is either a TensorType or another StructType.
604     if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
605       parser.emitError(typeLoc, "element type for a struct must either "
606                                 "be a TensorType or a StructType, got: ")
607           << elementType;
608       return Type();
609     }
610     elementTypes.push_back(elementType);
611 
612     // Parse the optional: `,`
613   } while (succeeded(parser.parseOptionalComma()));
614 
615   // Parse: `>`
616   if (parser.parseGreater())
617     return Type();
618   return StructType::get(elementTypes);
619 }
620 
621 /// Print an instance of a type registered to the toy dialect.
622 void ToyDialect::printType(mlir::Type type,
623                            mlir::DialectAsmPrinter &printer) const {
624   // Currently the only toy type is a struct type.
625   StructType structType = llvm::cast<StructType>(type);
626 
627   // Print the struct type according to the parser format.
628   printer << "struct<";
629   llvm::interleaveComma(structType.getElementTypes(), printer);
630   printer << '>';
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // TableGen'd op method definitions
635 //===----------------------------------------------------------------------===//
636 
637 #define GET_OP_CLASSES
638 #include "toy/Ops.cpp.inc"
639 
640 //===----------------------------------------------------------------------===//
641 // ToyDialect
642 //===----------------------------------------------------------------------===//
643 
/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
///
/// Registers, in order:
///  - every op class listed by the GET_OP_LIST expansion of the generated
///    "toy/Ops.cpp.inc";
///  - the ToyInlinerInterface;
///  - the StructType.
void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
  addTypes<StructType>();
}
654 
655 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
656                                                  mlir::Attribute value,
657                                                  mlir::Type type,
658                                                  mlir::Location loc) {
659   if (llvm::isa<StructType>(type))
660     return builder.create<StructConstantOp>(loc, type,
661                                             llvm::cast<mlir::ArrayAttr>(value));
662   return builder.create<ConstantOp>(loc, type,
663                                     llvm::cast<mlir::DenseElementsAttr>(value));
664 }
665