# Chapter 7: Adding a Composite Type to Toy

[TOC]

In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation
flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy
language to support a new composite `struct` type.

## Defining a `struct` in Toy

The first thing we need to define is the interface of this type in our `toy`
source language. The general syntax of a `struct` type in Toy is as follows:

```toy
# A struct is defined by using the `struct` keyword followed by a name.
struct MyStruct {
  # Inside of the struct is a list of variable declarations without initializers
  # or shapes, which may also be other previously defined structs.
  var a;
  var b;
}
```

Structs may now be used in functions as variables or parameters by using the
name of the struct instead of `var`. The members of a struct are accessed via
the `.` access operator. Values of `struct` type may be initialized with a
composite initializer, i.e. a comma-separated list of other initializers
surrounded by `{}`. An example is shown below:

```toy
struct Struct {
  var a;
  var b;
}

# User defined generic function may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}
```

## Defining a `struct` in MLIR

In MLIR, we will also need a representation for our struct types. MLIR does not
provide a type that does exactly what we need, so we will need to define our
own. We will simply define our `struct` as an unnamed container of a set of
element types.
The names of the `struct` and its elements are only useful for the
AST of our `toy` compiler, so we don't need to encode them in the MLIR
representation.

### Defining the Type Class

As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md/#type-system)
objects in MLIR are value-typed and rely on having an internal storage object
that holds the actual data for the type. The `Type` class in itself acts as a
simple wrapper around an internal `TypeStorage` object that is uniqued within an
instance of an `MLIRContext`. When constructing a `Type`, we are internally just
constructing and uniquing an instance of a storage class.

When defining a new `Type` that contains parametric data (e.g. the `struct`
type, which requires additional information to hold the element types), we will
need to provide a derived storage class. Singleton types that don't have
any additional data (e.g. the [`index` type](../../Dialects/Builtin.md/#indextype)) don't
require a storage class and use the default `TypeStorage`.

#### Defining the Storage Class

Type storage objects contain all of the data necessary to construct and unique a
type instance. Derived storage classes must inherit from the base
`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by
the `MLIRContext` for uniquing. Below is the definition of the storage instance
for our `struct` type, with each of the necessary requirements detailed inline:

```c++
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself; see the `StructType::get` method further below.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internals.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it.
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
```

#### Defining the `StructType` Class

With the storage class defined, we can add the definition for the user-visible
`StructType` class. This is the class that we will actually interface with.

```c++
/// This class defines the Toy struct type. It represents a collection of
/// element types. All derived types in MLIR must inherit from the CRTP class
/// 'Type::TypeBase'. It takes as template parameters the concrete type
/// (StructType), the base class to use (Type), and the storage class
/// (StructTypeStorage).
class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
                                               StructTypeStorage> {
public:
  /// Inherit some necessary constructors from 'TypeBase'.
  using Base::Base;

  /// Create an instance of a `StructType` with the given element types. There
  /// *must* be at least one element type.
  static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
    assert(!elementTypes.empty() && "expected at least 1 element type");

    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
    // of this type. The first parameter is the context to unique in. The
    // parameters after are forwarded to the storage instance.
    mlir::MLIRContext *ctx = elementTypes.front().getContext();
    return Base::get(ctx, elementTypes);
  }

  /// Returns the element types of this struct type.
  llvm::ArrayRef<mlir::Type> getElementTypes() {
    // 'getImpl' returns a pointer to the internal storage instance.
    return getImpl()->elementTypes;
  }

  /// Returns the number of element types held by this struct.
  size_t getNumElementTypes() { return getElementTypes().size(); }
};
```

We register this type in the `ToyDialect` initializer in a similar way to how we
did with operations:

```c++
void ToyDialect::initialize() {
  addTypes<StructType>();
}
```

(An important note here is that when registering a type, the definition of the
storage class must be visible.)

With this we can now use our `StructType` when generating MLIR from Toy. See
`examples/toy/Ch7/mlir/MLIRGen.cpp` for more details.

### Exposing to ODS

After defining a new type, we should make the ODS framework aware of our type so
that we can use it in the operation definitions and auto-generate utilities
within the dialect. A simple example is shown below:

```tablegen
// Provide a definition for the Toy StructType for use in ODS. This allows for
// using StructType in a similar way to Tensor or MemRef. We use `DialectType`
// to demarcate the StructType as belonging to the Toy dialect.
def Toy_StructType :
    DialectType<Toy_Dialect, CPred<"$_self.isa<StructType>()">,
                "Toy struct type">;

// Provide a definition of the types that are used within the Toy dialect.
def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
```

### Parsing and Printing

At this point we can use our `StructType` during MLIR generation and
transformation, but we can't output or parse `.mlir`. For this we need to add
support for parsing and printing instances of the `StructType`. This can be done
by overriding the `parseType` and `printType` methods on the `ToyDialect`.
Declarations for these methods are automatically provided when the type is
exposed to ODS as detailed in the previous section.

```c++
class ToyDialect : public mlir::Dialect {
public:
  /// Parse an instance of a type registered to the toy dialect.
  mlir::Type parseType(mlir::DialectAsmParser &parser) const override;

  /// Print an instance of a type registered to the toy dialect.
  void printType(mlir::Type type,
                 mlir::DialectAsmPrinter &printer) const override;
};
```

These methods take an instance of a high-level parser or printer that allows for
easily implementing the necessary functionality. Before going into the
implementation, let's think about the syntax that we want for the `struct` type
in the printed IR. As described in the
[MLIR language reference](../../LangRef.md/#dialect-types), dialect types are
generally represented as: `! dialect-namespace < type-data >`, with a pretty
form available under certain circumstances. The responsibility of our `Toy`
parser and printer is to provide the `type-data` bits. We will define our
`StructType` as having the following form:

```
  struct-type ::= `struct` `<` type (`,` type)* `>`
```

#### Parsing

An implementation of the parser is shown below:

```c++
/// Parse an instance of a type registered to the toy dialect.
mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Parse a struct type in the following form:
  //   struct-type ::= `struct` `<` type (`,` type)* `>`

  // NOTE: All MLIR parser functions return a ParseResult. This is a
  // specialization of LogicalResult that auto-converts to a `true` boolean
  // value on failure to allow for chaining, but may be used with explicit
  // `mlir::failed/mlir::succeeded` as desired.

  // Parse: `struct` `<`
  if (parser.parseKeyword("struct") || parser.parseLess())
    return Type();

  // Parse the element types of the struct.
  SmallVector<mlir::Type, 1> elementTypes;
  do {
    // Parse the current element type.
    SMLoc typeLoc = parser.getCurrentLocation();
    mlir::Type elementType;
    if (parser.parseType(elementType))
      return Type();

    // Check that the type is either a TensorType or another StructType.
    if (!elementType.isa<mlir::TensorType, StructType>()) {
      parser.emitError(typeLoc, "element type for a struct must either "
                                "be a TensorType or a StructType, got: ")
          << elementType;
      return Type();
    }
    elementTypes.push_back(elementType);

    // Parse the optional: `,`
  } while (succeeded(parser.parseOptionalComma()));

  // Parse: `>`
  if (parser.parseGreater())
    return Type();
  return StructType::get(elementTypes);
}
```

#### Printing

An implementation of the printer is shown below:

```c++
/// Print an instance of a type registered to the toy dialect.
void ToyDialect::printType(mlir::Type type,
                           mlir::DialectAsmPrinter &printer) const {
  // Currently the only toy type is a struct type.
  StructType structType = type.cast<StructType>();

  // Print the struct type according to the parser format.
  printer << "struct<";
  llvm::interleaveComma(structType.getElementTypes(), printer);
  printer << '>';
}
```

Before moving on, let's look at a quick example showcasing the functionality
we have now:

```toy
struct Struct {
  var a;
  var b;
}

def multiply_transpose(Struct value) {
}
```

Which generates the following:

```mlir
module {
  toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
    toy.return
  }
}
```

### Operating on `StructType`

Now that the `struct` type has been defined and we can round-trip it through
the IR, the next step is to add support for using it within our operations.

#### Updating Existing Operations

A few of our existing operations, e.g. `ReturnOp`, will need to be updated to
handle `Toy_StructType`.

```tablegen
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  ...
  let arguments = (ins Variadic<Toy_Type>:$input);
  ...
}
```

#### Adding New `Toy` Operations

In addition to the existing operations, we will be adding a few new operations
that will provide more specific handling of `structs`.

##### `toy.struct_constant`

This new operation materializes a constant value for a struct. In our current
modeling, we just use an [array attribute](../../Dialects/Builtin.md/#arrayattr)
that contains a set of constant values for each of the `struct` elements.

```mlir
  %0 = toy.struct_constant [
    dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
  ] : !toy.struct<tensor<*xf64>>
```

##### `toy.struct_access`

This new operation materializes the Nth element of a `struct` value.

```mlir
  // Using %0 from above
  %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>> -> tensor<*xf64>
```

With these operations, we can revisit our original example:

```toy
struct Struct {
  var a;
  var b;
}

# User defined generic function may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}
```

and finally get a full MLIR module:

```mlir
module {
  toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
    %0 = toy.struct_access %arg0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
    %2 = toy.struct_access %arg0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %3 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
    %4 = toy.mul %1, %3 : tensor<*xf64>
    toy.return %4 : tensor<*xf64>
  }
  toy.func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.generic_call @multiply_transpose(%0) : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
    toy.print %1 : tensor<*xf64>
    toy.return
  }
}
```

#### Optimizing Operations on `StructType`

Now that we have a few operations operating on `StructType`, we also have many
new constant folding opportunities.
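
The opportunity can be shown without any MLIR machinery. Below is a minimal, self-contained C++17 sketch of the core rewrite. It assumes a deliberately simplified model: a struct constant is just a list of element attributes, represented here by strings, and an operand is either such a constant or a value unknown at compile time. `StructConstant`, `Operand`, and `foldStructAccess` are hypothetical names for illustration only, not Toy or MLIR APIs. The rule mirrors the one the folders in this section implement: an access into a known constant is replaced by the indexed element, and anything else is left alone.

```c++
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical, MLIR-independent model (not the Toy/MLIR API).
// Stand-in for the constant attribute of one struct element, e.g. the printed
// form of a `dense<...>` tensor literal.
using ElementAttr = std::string;

// Stand-in for the array of element constants held by `toy.struct_constant`.
struct StructConstant {
  std::vector<ElementAttr> elements;
};

// The operand of a struct access is either a known struct constant or a value
// we know nothing about at compile time (e.g. a function argument).
using Operand = std::variant<StructConstant, std::monostate>;

// Fold `struct_access operand[index]`: succeed only when the operand is a
// known constant, mirroring the early `return nullptr` in a C++ folder.
std::optional<ElementAttr> foldStructAccess(const Operand &operand,
                                            std::size_t index) {
  const auto *constant = std::get_if<StructConstant>(&operand);
  if (!constant)
    return std::nullopt; // Unknown operand: keep the access as-is.
  // The sketch checks bounds itself instead of relying on a verifier.
  if (index >= constant->elements.size())
    return std::nullopt;
  return constant->elements[index];
}
```

The bounds check is a choice of this sketch. The C++ folder shown later indexes the `ArrayAttr` directly, so it relies on the index having been validated elsewhere, for example by an operation verifier.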

After inlining, the MLIR module in the previous section looks something like:

```mlir
module {
  toy.func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %2 = toy.transpose(%1 : tensor<*xf64>) to tensor<*xf64>
    %3 = toy.struct_access %0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %4 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
    %5 = toy.mul %2, %4 : tensor<*xf64>
    toy.print %5 : tensor<*xf64>
    toy.return
  }
}
```

We have several `toy.struct_access` operations that access into a
`toy.struct_constant`. As detailed in [chapter 3](Ch-3.md) (FoldConstantReshape),
we can add folders for these `toy` operations by setting the `hasFolder` bit
on the operation definition and providing a definition of the `*Op::fold`
method.

```c++
/// Fold constants.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

/// Fold struct constants.
OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) {
  return getValue();
}

/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
  auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
  if (!structAttr)
    return nullptr;

  // The `I64Attr` index accessor returns the value as an integer.
  size_t elementIndex = getIndex();
  return structAttr[elementIndex];
}
```

To ensure that MLIR generates the proper constant operations when folding our
`Toy` operations, i.e.
`ConstantOp` for `TensorType` and `StructConstantOp` for
`StructType`, we will need to provide an override for the dialect hook
`materializeConstant`. This allows generic MLIR transformations, such as
folding, to create constants for the `Toy` dialect when necessary.

```c++
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  if (type.isa<StructType>())
    return builder.create<StructConstantOp>(loc, type,
                                            value.cast<mlir::ArrayAttr>());
  return builder.create<ConstantOp>(loc, type,
                                    value.cast<mlir::DenseElementsAttr>());
}
```

With this, we can now generate code that can be lowered to LLVM without any
changes to our pipeline.

```mlir
module {
  toy.func @main() {
    %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
    %2 = toy.mul %1, %1 : tensor<3x2xf64>
    toy.print %2 : tensor<3x2xf64>
    toy.return
  }
}
```

You can build `toyc-ch7` and try it yourself:
`toyc-ch7 test/Examples/Toy/Ch7/struct-codegen.toy -emit=mlir`. More details on
defining custom types can be found in
[DefiningAttributesAndTypes](../../DefiningDialects/AttributesAndTypes.md).