1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the dialect for the Toy IR: custom type parsing and 10 // operation verification. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "toy/Dialect.h" 15 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/OperationSupport.h" 25 #include "mlir/IR/TypeSupport.h" 26 #include "mlir/IR/ValueRange.h" 27 #include "mlir/Interfaces/CallInterfaces.h" 28 #include "mlir/Interfaces/FunctionImplementation.h" 29 #include "mlir/Support/LLVM.h" 30 #include "mlir/Transforms/InliningUtils.h" 31 #include "llvm/ADT/ArrayRef.h" 32 #include "llvm/ADT/Hashing.h" 33 #include "llvm/ADT/STLExtras.h" 34 #include "llvm/ADT/StringRef.h" 35 #include "llvm/Support/Casting.h" 36 #include <algorithm> 37 #include <cassert> 38 #include <cstddef> 39 #include <cstdint> 40 #include <string> 41 42 using namespace mlir; 43 using namespace mlir::toy; 44 45 #include "toy/Dialect.cpp.inc" 46 47 //===----------------------------------------------------------------------===// 48 // ToyInlinerInterface 49 //===----------------------------------------------------------------------===// 50 51 /// This class defines the interface for handling inlining with Toy 52 /// operations. 
53 struct ToyInlinerInterface : public DialectInlinerInterface { 54 using DialectInlinerInterface::DialectInlinerInterface; 55 56 //===--------------------------------------------------------------------===// 57 // Analysis Hooks 58 //===--------------------------------------------------------------------===// 59 60 /// All call operations within toy can be inlined. 61 bool isLegalToInline(Operation *call, Operation *callable, 62 bool wouldBeCloned) const final { 63 return true; 64 } 65 66 /// All operations within toy can be inlined. 67 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 68 return true; 69 } 70 71 // All functions within toy can be inlined. 72 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { 73 return true; 74 } 75 76 //===--------------------------------------------------------------------===// 77 // Transformation Hooks 78 //===--------------------------------------------------------------------===// 79 80 /// Handle the given inlined terminator(toy.return) by replacing it with a new 81 /// operation as necessary. 82 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 83 // Only "toy.return" needs to be handled here. 84 auto returnOp = cast<ReturnOp>(op); 85 86 // Replace the values directly with the return operands. 87 assert(returnOp.getNumOperands() == valuesToRepl.size()); 88 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 89 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 90 } 91 92 /// Attempts to materialize a conversion for a type mismatch between a call 93 /// from this dialect, and a callable region. This method should generate an 94 /// operation that takes 'input' as the only operand, and produces a single 95 /// result of 'resultType'. If a conversion can not be generated, nullptr 96 /// should be returned. 
97 Operation *materializeCallConversion(OpBuilder &builder, Value input, 98 Type resultType, 99 Location conversionLoc) const final { 100 return builder.create<CastOp>(conversionLoc, resultType, input); 101 } 102 }; 103 104 //===----------------------------------------------------------------------===// 105 // Toy Operations 106 //===----------------------------------------------------------------------===// 107 108 /// A generalized parser for binary operations. This parses the different forms 109 /// of 'printBinaryOp' below. 110 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, 111 mlir::OperationState &result) { 112 SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands; 113 SMLoc operandsLoc = parser.getCurrentLocation(); 114 Type type; 115 if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || 116 parser.parseOptionalAttrDict(result.attributes) || 117 parser.parseColonType(type)) 118 return mlir::failure(); 119 120 // If the type is a function type, it contains the input and result types of 121 // this operation. 122 if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) { 123 if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, 124 result.operands)) 125 return mlir::failure(); 126 result.addTypes(funcType.getResults()); 127 return mlir::success(); 128 } 129 130 // Otherwise, the parsed type is the type of both operands and results. 131 if (parser.resolveOperands(operands, type, result.operands)) 132 return mlir::failure(); 133 result.addTypes(type); 134 return mlir::success(); 135 } 136 137 /// A generalized printer for binary operations. It prints in two different 138 /// forms depending on if all of the types match. 139 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { 140 printer << " " << op->getOperands(); 141 printer.printOptionalAttrDict(op->getAttrs()); 142 printer << " : "; 143 144 // If all of the types are the same, print the type directly. 
145 Type resultType = *op->result_type_begin(); 146 if (llvm::all_of(op->getOperandTypes(), 147 [=](Type type) { return type == resultType; })) { 148 printer << resultType; 149 return; 150 } 151 152 // Otherwise, print a functional type. 153 printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); 154 } 155 156 //===----------------------------------------------------------------------===// 157 // ConstantOp 158 //===----------------------------------------------------------------------===// 159 160 /// Build a constant operation. 161 /// The builder is passed as an argument, so is the state that this method is 162 /// expected to fill in order to build the operation. 163 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 164 double value) { 165 auto dataType = RankedTensorType::get({}, builder.getF64Type()); 166 auto dataAttribute = DenseElementsAttr::get(dataType, value); 167 ConstantOp::build(builder, state, dataType, dataAttribute); 168 } 169 170 /// The 'OpAsmParser' class provides a collection of methods for parsing 171 /// various punctuation, as well as attributes, operands, types, etc. Each of 172 /// these methods returns a `ParseResult`. This class is a wrapper around 173 /// `LogicalResult` that can be converted to a boolean `true` value on failure, 174 /// or `false` on success. This allows for easily chaining together a set of 175 /// parser rules. These rules are used to populate an `mlir::OperationState` 176 /// similarly to the `build` methods described above. 
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
                                    mlir::OperationState &result) {
  mlir::DenseElementsAttr value;
  // Parse an optional attribute dictionary followed by the mandatory `value`
  // dense-elements attribute; both land in `result.attributes`.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(value, "value", result.attributes))
    return failure();

  // The op's single result type is exactly the type of the parsed attribute.
  result.addTypes(value.getType());
  return success();
}

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
  printer << " ";
  // `value` is printed on its own below, so elide it from the dictionary.
  printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
  printer << getValue();
}

/// Verify that the given attribute value is valid for the given type.
/// Tensor-typed constants must be backed by a DenseFPElementsAttr whose
/// ranked shape (if the result type is ranked) matches `type` exactly;
/// struct-typed constants are verified element-wise by recursion.
static llvm::LogicalResult verifyConstantForType(mlir::Type type,
                                                 mlir::Attribute opaqueValue,
                                                 mlir::Operation *op) {
  if (llvm::isa<mlir::TensorType>(type)) {
    // Check that the value is an elements attribute.
    auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
    if (!attrValue)
      return op->emitError("constant of TensorType must be initialized by "
                           "a DenseFPElementsAttr, got ")
             << opaqueValue;

    // If the return type of the constant is not an unranked tensor, the shape
    // must match the shape of the attribute holding the data.
    auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
    if (!resultType)
      return success();

    // Check that the rank of the attribute type matches the rank of the
    // constant result type.
    auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
    if (attrType.getRank() != resultType.getRank()) {
      return op->emitOpError("return type must match the one of the attached "
                             "value attribute: ")
             << attrType.getRank() << " != " << resultType.getRank();
    }

    // Check that each of the dimensions match between the two types.
    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        return op->emitOpError(
                   "return type shape mismatches its attribute at dimension ")
               << dim << ": " << attrType.getShape()[dim]
               << " != " << resultType.getShape()[dim];
      }
    }
    return mlir::success();
  }
  // Non-tensor constants must be struct-typed (the only other Toy type).
  auto resultType = llvm::cast<StructType>(type);
  llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();

  // Verify that the initializer is an ArrayAttr with one entry per element.
  auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
  if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
    return op->emitError("constant of StructType must be initialized by an "
                         "ArrayAttr with the same number of elements, got ")
           << opaqueValue;

  // Check that each of the elements is valid, recursing into nested structs.
  llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
  for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
    if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
      return mlir::failure();
  return mlir::success();
}

/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
llvm::LogicalResult ConstantOp::verify() {
  return verifyConstantForType(getResult().getType(), getValue(), *this);
}

llvm::LogicalResult StructConstantOp::verify() {
  return verifyConstantForType(getResult().getType(), getValue(), *this);
}

/// Infer the output shape of the ConstantOp, this is required by the shape
/// inference interface.
264 void ConstantOp::inferShapes() { 265 getResult().setType(cast<TensorType>(getValue().getType())); 266 } 267 268 //===----------------------------------------------------------------------===// 269 // AddOp 270 //===----------------------------------------------------------------------===// 271 272 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 273 mlir::Value lhs, mlir::Value rhs) { 274 state.addTypes(UnrankedTensorType::get(builder.getF64Type())); 275 state.addOperands({lhs, rhs}); 276 } 277 278 mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, 279 mlir::OperationState &result) { 280 return parseBinaryOp(parser, result); 281 } 282 283 void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } 284 285 /// Infer the output shape of the AddOp, this is required by the shape inference 286 /// interface. 287 void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } 288 289 //===----------------------------------------------------------------------===// 290 // CastOp 291 //===----------------------------------------------------------------------===// 292 293 /// Infer the output shape of the CastOp, this is required by the shape 294 /// inference interface. 295 void CastOp::inferShapes() { getResult().setType(getInput().getType()); } 296 297 /// Returns true if the given set of input and result types are compatible with 298 /// this cast operation. This is required by the `CastOpInterface` to verify 299 /// this operation and provide other additional utilities. 300 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 301 if (inputs.size() != 1 || outputs.size() != 1) 302 return false; 303 // The inputs must be Tensors with the same element type. 
304 TensorType input = llvm::dyn_cast<TensorType>(inputs.front()); 305 TensorType output = llvm::dyn_cast<TensorType>(outputs.front()); 306 if (!input || !output || input.getElementType() != output.getElementType()) 307 return false; 308 // The shape is required to match if both types are ranked. 309 return !input.hasRank() || !output.hasRank() || input == output; 310 } 311 312 //===----------------------------------------------------------------------===// 313 // FuncOp 314 //===----------------------------------------------------------------------===// 315 316 void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 317 llvm::StringRef name, mlir::FunctionType type, 318 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 319 // FunctionOpInterface provides a convenient `build` method that will populate 320 // the state of our FuncOp, and create an entry block. 321 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); 322 } 323 324 mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, 325 mlir::OperationState &result) { 326 // Dispatch to the FunctionOpInterface provided utility method that parses the 327 // function operation. 328 auto buildFuncType = 329 [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes, 330 llvm::ArrayRef<mlir::Type> results, 331 mlir::function_interface_impl::VariadicFlag, 332 std::string &) { return builder.getFunctionType(argTypes, results); }; 333 334 return mlir::function_interface_impl::parseFunctionOp( 335 parser, result, /*allowVariadic=*/false, 336 getFunctionTypeAttrName(result.name), buildFuncType, 337 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 338 } 339 340 void FuncOp::print(mlir::OpAsmPrinter &p) { 341 // Dispatch to the FunctionOpInterface provided utility method that prints the 342 // function operation. 
343 mlir::function_interface_impl::printFunctionOp( 344 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 345 getArgAttrsAttrName(), getResAttrsAttrName()); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // GenericCallOp 350 //===----------------------------------------------------------------------===// 351 352 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 353 StringRef callee, ArrayRef<mlir::Value> arguments) { 354 // Generic call always returns an unranked Tensor initially. 355 state.addTypes(UnrankedTensorType::get(builder.getF64Type())); 356 state.addOperands(arguments); 357 state.addAttribute("callee", 358 mlir::SymbolRefAttr::get(builder.getContext(), callee)); 359 } 360 361 /// Return the callee of the generic call operation, this is required by the 362 /// call interface. 363 CallInterfaceCallable GenericCallOp::getCallableForCallee() { 364 return (*this)->getAttrOfType<SymbolRefAttr>("callee"); 365 } 366 367 /// Set the callee for the generic call operation, this is required by the call 368 /// interface. 369 void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { 370 (*this)->setAttr("callee", cast<SymbolRefAttr>(callee)); 371 } 372 373 /// Get the argument operands to the called function, this is required by the 374 /// call interface. 375 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } 376 377 /// Get the argument operands to the called function as a mutable range, this is 378 /// required by the call interface. 
379 MutableOperandRange GenericCallOp::getArgOperandsMutable() { 380 return getInputsMutable(); 381 } 382 383 //===----------------------------------------------------------------------===// 384 // MulOp 385 //===----------------------------------------------------------------------===// 386 387 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 388 mlir::Value lhs, mlir::Value rhs) { 389 state.addTypes(UnrankedTensorType::get(builder.getF64Type())); 390 state.addOperands({lhs, rhs}); 391 } 392 393 mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, 394 mlir::OperationState &result) { 395 return parseBinaryOp(parser, result); 396 } 397 398 void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } 399 400 /// Infer the output shape of the MulOp, this is required by the shape inference 401 /// interface. 402 void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } 403 404 //===----------------------------------------------------------------------===// 405 // ReturnOp 406 //===----------------------------------------------------------------------===// 407 408 llvm::LogicalResult ReturnOp::verify() { 409 // We know that the parent operation is a function, because of the 'HasParent' 410 // trait attached to the operation definition. 411 auto function = cast<FuncOp>((*this)->getParentOp()); 412 413 /// ReturnOps can only have a single optional operand. 414 if (getNumOperands() > 1) 415 return emitOpError() << "expects at most 1 return operand"; 416 417 // The operand number and types must match the function signature. 418 const auto &results = function.getFunctionType().getResults(); 419 if (getNumOperands() != results.size()) 420 return emitOpError() << "does not return the same number of values (" 421 << getNumOperands() << ") as the enclosing function (" 422 << results.size() << ")"; 423 424 // If the operation does not have an input, we are done. 
425 if (!hasOperand()) 426 return mlir::success(); 427 428 auto inputType = *operand_type_begin(); 429 auto resultType = results.front(); 430 431 // Check that the result type of the function matches the operand type. 432 if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || 433 llvm::isa<mlir::UnrankedTensorType>(resultType)) 434 return mlir::success(); 435 436 return emitError() << "type of return operand (" << inputType 437 << ") doesn't match function result type (" << resultType 438 << ")"; 439 } 440 441 //===----------------------------------------------------------------------===// 442 // StructAccessOp 443 //===----------------------------------------------------------------------===// 444 445 void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, 446 mlir::Value input, size_t index) { 447 // Extract the result type from the input type. 448 StructType structTy = llvm::cast<StructType>(input.getType()); 449 assert(index < structTy.getNumElementTypes()); 450 mlir::Type resultType = structTy.getElementTypes()[index]; 451 452 // Call into the auto-generated build method. 
453 build(b, state, resultType, input, b.getI64IntegerAttr(index)); 454 } 455 456 llvm::LogicalResult StructAccessOp::verify() { 457 StructType structTy = llvm::cast<StructType>(getInput().getType()); 458 size_t indexValue = getIndex(); 459 if (indexValue >= structTy.getNumElementTypes()) 460 return emitOpError() 461 << "index should be within the range of the input struct type"; 462 mlir::Type resultType = getResult().getType(); 463 if (resultType != structTy.getElementTypes()[indexValue]) 464 return emitOpError() << "must have the same result type as the struct " 465 "element referred to by the index"; 466 return mlir::success(); 467 } 468 469 //===----------------------------------------------------------------------===// 470 // TransposeOp 471 //===----------------------------------------------------------------------===// 472 473 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, 474 mlir::Value value) { 475 state.addTypes(UnrankedTensorType::get(builder.getF64Type())); 476 state.addOperands(value); 477 } 478 479 void TransposeOp::inferShapes() { 480 auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType()); 481 SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); 482 getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); 483 } 484 485 llvm::LogicalResult TransposeOp::verify() { 486 auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType()); 487 auto resultType = llvm::dyn_cast<RankedTensorType>(getType()); 488 if (!inputType || !resultType) 489 return mlir::success(); 490 491 auto inputShape = inputType.getShape(); 492 if (!std::equal(inputShape.begin(), inputShape.end(), 493 resultType.getShape().rbegin())) { 494 return emitError() 495 << "expected result shape to be a transpose of the input"; 496 } 497 return mlir::success(); 498 } 499 500 //===----------------------------------------------------------------------===// 501 // Toy Types 502 
//===----------------------------------------------------------------------===//

namespace mlir {
namespace toy {
namespace detail {
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  /// Note: `elementTypes` is stored as-is; `construct` below guarantees it
  /// points into allocator-owned memory with the right lifetime.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage, see the `StructType::get` method.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internal.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator, so the
    // stored ArrayRef outlives the caller's buffer.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it (placement new into
    // allocator-owned memory).
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
} // namespace detail
} // namespace toy
} // namespace mlir

/// Create an instance of a `StructType` with the given element types. There
/// *must* be at least one element type.
StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
  assert(!elementTypes.empty() && "expected at least 1 element type");

  // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
  // of this type. The first parameter is the context to unique in. The
  // parameters after the context are forwarded to the storage instance.
  mlir::MLIRContext *ctx = elementTypes.front().getContext();
  return Base::get(ctx, elementTypes);
}

/// Returns the element types of this struct type.
llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
  // 'getImpl' returns a pointer to the internal storage instance.
  return getImpl()->elementTypes;
}

/// Parse an instance of a type registered to the toy dialect.
581 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { 582 // Parse a struct type in the following form: 583 // struct-type ::= `struct` `<` type (`,` type)* `>` 584 585 // NOTE: All MLIR parser function return a ParseResult. This is a 586 // specialization of LogicalResult that auto-converts to a `true` boolean 587 // value on failure to allow for chaining, but may be used with explicit 588 // `mlir::failed/mlir::succeeded` as desired. 589 590 // Parse: `struct` `<` 591 if (parser.parseKeyword("struct") || parser.parseLess()) 592 return Type(); 593 594 // Parse the element types of the struct. 595 SmallVector<mlir::Type, 1> elementTypes; 596 do { 597 // Parse the current element type. 598 SMLoc typeLoc = parser.getCurrentLocation(); 599 mlir::Type elementType; 600 if (parser.parseType(elementType)) 601 return nullptr; 602 603 // Check that the type is either a TensorType or another StructType. 604 if (!llvm::isa<mlir::TensorType, StructType>(elementType)) { 605 parser.emitError(typeLoc, "element type for a struct must either " 606 "be a TensorType or a StructType, got: ") 607 << elementType; 608 return Type(); 609 } 610 elementTypes.push_back(elementType); 611 612 // Parse the optional: `,` 613 } while (succeeded(parser.parseOptionalComma())); 614 615 // Parse: `>` 616 if (parser.parseGreater()) 617 return Type(); 618 return StructType::get(elementTypes); 619 } 620 621 /// Print an instance of a type registered to the toy dialect. 622 void ToyDialect::printType(mlir::Type type, 623 mlir::DialectAsmPrinter &printer) const { 624 // Currently the only toy type is a struct type. 625 StructType structType = llvm::cast<StructType>(type); 626 627 // Print the struct type according to the parser format. 
628 printer << "struct<"; 629 llvm::interleaveComma(structType.getElementTypes(), printer); 630 printer << '>'; 631 } 632 633 //===----------------------------------------------------------------------===// 634 // TableGen'd op method definitions 635 //===----------------------------------------------------------------------===// 636 637 #define GET_OP_CLASSES 638 #include "toy/Ops.cpp.inc" 639 640 //===----------------------------------------------------------------------===// 641 // ToyDialect 642 //===----------------------------------------------------------------------===// 643 644 /// Dialect initialization, the instance will be owned by the context. This is 645 /// the point of registration of types and operations for the dialect. 646 void ToyDialect::initialize() { 647 addOperations< 648 #define GET_OP_LIST 649 #include "toy/Ops.cpp.inc" 650 >(); 651 addInterfaces<ToyInlinerInterface>(); 652 addTypes<StructType>(); 653 } 654 655 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, 656 mlir::Attribute value, 657 mlir::Type type, 658 mlir::Location loc) { 659 if (llvm::isa<StructType>(type)) 660 return builder.create<StructConstantOp>(loc, type, 661 llvm::cast<mlir::ArrayAttr>(value)); 662 return builder.create<ConstantOp>(loc, type, 663 llvm::cast<mlir::DenseElementsAttr>(value)); 664 } 665