//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//

#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

#define DEBUG_TYPE "mlir-tblgen-opdefgen"

using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;

/// Identifier strings spliced into the generated C++ code. For example,
/// `generatedArgName` is the prefix used to synthesize a name (`odsArg_<N>`)
/// for an operand that has no name in ODS.
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
static const char *const propertyStorage = "propStorage";
static const char *const propertyValue = "propValue";
static const char *const propertyAttr = "propAttr";
static const char *const propertyDiag = "emitError";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static const char *const operandSegmentAttrName = "operandSegmentSizes";
static const char *const resultSegmentAttrName = "resultSegmentSizes";

/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup: because attributes are kept sorted by name, the search can skip the
/// required attributes known to sort before ({1}) and after ({2}) this one.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The emitted snippet expects a parameter named `index` (the static
/// operand/result index) to be in scope and returns a `{start, size}` pair.
/// `{{` is the formatv escape for a literal `{`.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The snippet expects `index` and an indexable `sizeAttr` to be in scope. It
/// contains no `{N}` placeholders and uses single braces, so presumably it is
/// emitted verbatim rather than through formatv.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation
/// in an adaptor, where the segment sizes come from an attribute.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
/// Same as above, for adaptors whose segment sizes are stored as a property.
///
/// {0}: Code yielding the sizes as an `::llvm::ArrayRef<int32_t>`.
static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// The code snippet to initialize the sizes for the value range calculation.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
/// {2}: The range type of adaptor.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// Parse operand/result segment_size property from its textual form, a
/// square-bracketed, comma-separated list of integers such as `[1, 0, 2]`.
/// Exactly {0} elements are accepted; fewer or more is a parse error.
///
/// The delimiter enumerator is `::mlir::AsmParser::Delimiter` (the previous
/// spelling `Delimeter` does not exist and made the generated code fail to
/// compile). All LogicalResult helpers are fully qualified so the generated
/// code does not depend on the namespace it is emitted into.
///
/// {0}: Number of elements in the segment array
static const char *const parseTextualSegmentSizeFormat = R"(
  size_t i = 0;
  auto parseElem = [&]() -> ::mlir::ParseResult {
    if (i >= {0})
      return $_parser.emitError($_parser.getCurrentLocation(),
        "expected `]` after {0} segment sizes");
    if (::mlir::failed($_parser.parseInteger($_storage[i])))
      return ::mlir::failure();
    i += 1;
    return ::mlir::success();
  };
  if (::mlir::failed($_parser.parseCommaSeparatedList(
      ::mlir::AsmParser::Delimiter::Square, parseElem)))
    return ::mlir::failure();
  if (i < {0})
    return $_parser.emitError($_parser.getCurrentLocation(),
      "expected {0} segment sizes, found only ") << i;
  return ::mlir::success();
)";

/// Print operand/result segment_size property as `[a, b, ...]`, the form read
/// back by `parseTextualSegmentSizeFormat`.
static const char *const printTextualSegmentSize = R"(
  [&]() {
    $_printer << '[';
    ::llvm::interleaveComma($_storage, $_printer);
    $_printer << ']';
  }()
)";

/// Read operand/result segment_size from bytecode.
///
/// Native encoding, used from bytecode version 6
/// (`kNativePropertiesODSSegmentSize`) onward: the sizes are read as a sparse
/// array directly into `$_storage`. For older versions this snippet does
/// nothing; presumably the legacy reader below is emitted alongside it.
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Read operand/result segment_size from bytecode in the legacy encoding used
/// before version 6: a `DenseI32ArrayAttr` whose elements are copied into
/// `prop.$_propName`. Reading fails if the attribute holds more elements than
/// the fixed-size storage can accept.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Write operand/result segment_size to bytecode using the native encoding
/// (a sparse array), for bytecode version 6 and newer.
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to bytecode using the legacy encoding
/// (a `DenseI32ArrayAttr` built from `prop.$_propName`), for bytecode versions
/// older than 6.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
  if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
  }
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
225 static const char *const opCommentHeader = R"( 226 //===----------------------------------------------------------------------===// 227 // {0} {1} 228 //===----------------------------------------------------------------------===// 229 230 )"; 231 232 //===----------------------------------------------------------------------===// 233 // Utility structs and functions 234 //===----------------------------------------------------------------------===// 235 236 // Replaces all occurrences of `match` in `str` with `substitute`. 237 static std::string replaceAllSubstrs(std::string str, const std::string &match, 238 const std::string &substitute) { 239 std::string::size_type scanLoc = 0, matchLoc = std::string::npos; 240 while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) { 241 str = str.replace(matchLoc, match.size(), substitute); 242 scanLoc = matchLoc + substitute.size(); 243 } 244 return str; 245 } 246 247 // Returns whether the record has a value of the given name that can be returned 248 // via getValueAsString. 249 static inline bool hasStringAttribute(const Record &record, 250 StringRef fieldName) { 251 auto *valueInit = record.getValueInit(fieldName); 252 return isa<StringInit>(valueInit); 253 } 254 255 static std::string getArgumentName(const Operator &op, int index) { 256 const auto &operand = op.getOperand(index); 257 if (!operand.name.empty()) 258 return std::string(operand.name); 259 return std::string(formatv("{0}_{1}", generatedArgName, index)); 260 } 261 262 // Returns true if we can use unwrapped value for the given `attr` in builders. 263 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { 264 return attr.getReturnType() != attr.getStorageType() && 265 // We need to wrap the raw value into an attribute in the builder impl 266 // so we need to make sure that the attribute specifies how to do that. 
267 !attr.getConstBuilderTemplate().empty(); 268 } 269 270 /// Build an attribute from a parameter value using the constant builder. 271 static std::string constBuildAttrFromParam(const tblgen::Attribute &attr, 272 FmtContext &fctx, 273 StringRef paramName) { 274 std::string builderTemplate = attr.getConstBuilderTemplate().str(); 275 276 // For StringAttr, its constant builder call will wrap the input in 277 // quotes, which is correct for normal string literals, but incorrect 278 // here given we use function arguments. So we need to strip the 279 // wrapping quotes. 280 if (StringRef(builderTemplate).contains("\"$0\"")) 281 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); 282 283 return tgfmt(builderTemplate, &fctx, paramName).str(); 284 } 285 286 namespace { 287 /// Metadata on a registered attribute. Given that attributes are stored in 288 /// sorted order on operations, we can use information from ODS to deduce the 289 /// number of required attributes less and and greater than each attribute, 290 /// allowing us to search only a subrange of the attributes in ODS-generated 291 /// getters. 292 struct AttributeMetadata { 293 /// The attribute name. 294 StringRef attrName; 295 /// Whether the attribute is required. 296 bool isRequired; 297 /// The ODS attribute constraint. Not present for implicit attributes. 298 std::optional<Attribute> constraint; 299 /// The number of required attributes less than this attribute. 300 unsigned lowerBound = 0; 301 /// The number of required attributes greater than this attribute. 302 unsigned upperBound = 0; 303 }; 304 305 /// Helper class to select between OpAdaptor and Op code templates. 306 class OpOrAdaptorHelper { 307 public: 308 OpOrAdaptorHelper(const Operator &op, bool emitForOp) 309 : op(op), emitForOp(emitForOp) { 310 computeAttrMetadata(); 311 } 312 313 /// Object that wraps a functor in a stream operator for interop with 314 /// llvm::formatv. 
315 class Formatter { 316 public: 317 template <typename Functor> 318 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {} 319 320 std::string str() const { 321 std::string result; 322 llvm::raw_string_ostream os(result); 323 os << *this; 324 return os.str(); 325 } 326 327 private: 328 std::function<raw_ostream &(raw_ostream &)> func; 329 330 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) { 331 return fmt.func(os); 332 } 333 }; 334 335 // Generate code for getting an attribute. 336 Formatter getAttr(StringRef attrName, bool isNamed = false) const { 337 assert(attrMetadata.count(attrName) && "expected attribute metadata"); 338 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & { 339 const AttributeMetadata &attr = attrMetadata.find(attrName)->second; 340 if (hasProperties()) { 341 assert(!isNamed); 342 return os << "getProperties()." << attrName; 343 } 344 return os << formatv(subrangeGetAttr, getAttrName(attrName), 345 attr.lowerBound, attr.upperBound, getAttrRange(), 346 isNamed ? "Named" : ""); 347 }; 348 } 349 350 // Generate code for getting the name of an attribute. 351 Formatter getAttrName(StringRef attrName) const { 352 return [this, attrName](raw_ostream &os) -> raw_ostream & { 353 if (emitForOp) 354 return os << op.getGetterName(attrName) << "AttrName()"; 355 return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(), 356 op.getGetterName(attrName)); 357 }; 358 } 359 360 // Get the code snippet for getting the named attribute range. 361 StringRef getAttrRange() const { 362 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs"; 363 } 364 365 // Get the prefix code for emitting an error. 366 Formatter emitErrorPrefix() const { 367 return [this](raw_ostream &os) -> raw_ostream & { 368 if (emitForOp) 369 return os << "emitOpError("; 370 return os << formatv("emitError(loc, \"'{0}' op \"", 371 op.getOperationName()); 372 }; 373 } 374 375 // Get the call to get an operand or segment of operands. 
376 Formatter getOperand(unsigned index) const { 377 return [this, index](raw_ostream &os) -> raw_ostream & { 378 return os << formatv(op.getOperand(index).isVariadic() 379 ? "this->getODSOperands({0})" 380 : "(*this->getODSOperands({0}).begin())", 381 index); 382 }; 383 } 384 385 // Get the call to get a result of segment of results. 386 Formatter getResult(unsigned index) const { 387 return [this, index](raw_ostream &os) -> raw_ostream & { 388 if (!emitForOp) 389 return os << "<no results should be generated>"; 390 return os << formatv(op.getResult(index).isVariadic() 391 ? "this->getODSResults({0})" 392 : "(*this->getODSResults({0}).begin())", 393 index); 394 }; 395 } 396 397 // Return whether an op instance is available. 398 bool isEmittingForOp() const { return emitForOp; } 399 400 // Return the ODS operation wrapper. 401 const Operator &getOp() const { return op; } 402 403 // Get the attribute metadata sorted by name. 404 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const { 405 return attrMetadata; 406 } 407 408 /// Returns whether to emit a `Properties` struct for this operation or not. 
409 bool hasProperties() const { 410 if (!op.getProperties().empty()) 411 return true; 412 if (!op.getDialect().usePropertiesForAttributes()) 413 return false; 414 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || 415 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) 416 return true; 417 return llvm::any_of(getAttrMetadata(), 418 [](const std::pair<StringRef, AttributeMetadata> &it) { 419 return !it.second.constraint || 420 !it.second.constraint->isDerivedAttr(); 421 }); 422 } 423 424 std::optional<NamedProperty> &getOperandSegmentsSize() { 425 return operandSegmentsSize; 426 } 427 428 std::optional<NamedProperty> &getResultSegmentsSize() { 429 return resultSegmentsSize; 430 } 431 432 uint32_t getOperandSegmentSizesLegacyIndex() { 433 return operandSegmentSizesLegacyIndex; 434 } 435 436 uint32_t getResultSegmentSizesLegacyIndex() { 437 return resultSegmentSizesLegacyIndex; 438 } 439 440 private: 441 // Compute the attribute metadata. 442 void computeAttrMetadata(); 443 444 // The operation ODS wrapper. 445 const Operator &op; 446 // True if code is being generate for an op. False for an adaptor. 447 const bool emitForOp; 448 449 // The attribute metadata, mapped by name. 450 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata; 451 452 // Property 453 std::optional<NamedProperty> operandSegmentsSize; 454 std::string operandSegmentsSizeStorage; 455 std::string operandSegmentsSizeParser; 456 std::optional<NamedProperty> resultSegmentsSize; 457 std::string resultSegmentsSizeStorage; 458 std::string resultSegmentsSizeParser; 459 460 // Indices to store the position in the emission order of the operand/result 461 // segment sizes attribute if emitted as part of the properties for legacy 462 // bytecode encodings, i.e. versions less than 6. 463 uint32_t operandSegmentSizesLegacyIndex = 0; 464 uint32_t resultSegmentSizesLegacyIndex = 0; 465 466 // The number of required attributes. 
467 unsigned numRequired; 468 }; 469 470 } // namespace 471 472 void OpOrAdaptorHelper::computeAttrMetadata() { 473 // Enumerate the attribute names of this op, ensuring the attribute names are 474 // unique in case implicit attributes are explicitly registered. 475 for (const NamedAttribute &namedAttr : op.getAttributes()) { 476 Attribute attr = namedAttr.attr; 477 bool isOptional = 478 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr(); 479 attrMetadata.insert( 480 {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); 481 } 482 483 auto makeProperty = [&](StringRef storageType, StringRef parserCall) { 484 return Property( 485 /*summary=*/"", 486 /*description=*/"", 487 /*storageType=*/storageType, 488 /*interfaceType=*/"::llvm::ArrayRef<int32_t>", 489 /*convertFromStorageCall=*/"$_storage", 490 /*assignToStorageCall=*/ 491 "::llvm::copy($_value, $_storage.begin())", 492 /*convertToAttributeCall=*/ 493 "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);", 494 /*convertFromAttributeCall=*/ 495 "return convertFromAttribute($_storage, $_attr, $_diag);", 496 /*parserCall=*/parserCall, 497 /*optionalParserCall=*/"", 498 /*printerCall=*/printTextualSegmentSize, 499 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative, 500 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative, 501 /*hashPropertyCall=*/ 502 "::llvm::hash_combine_range(std::begin($_storage), " 503 "std::end($_storage));", 504 /*StringRef defaultValue=*/"", 505 /*storageTypeValueOverride=*/""); 506 }; 507 // Include key attributes from several traits as implicitly registered. 
508 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 509 if (op.getDialect().usePropertiesForAttributes()) { 510 operandSegmentsSizeStorage = 511 llvm::formatv("std::array<int32_t, {0}>", op.getNumOperands()); 512 operandSegmentsSizeParser = 513 llvm::formatv(parseTextualSegmentSizeFormat, op.getNumOperands()); 514 operandSegmentsSize = { 515 "operandSegmentSizes", 516 makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)}; 517 } else { 518 attrMetadata.insert( 519 {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, 520 /*isRequired=*/true, 521 /*attr=*/std::nullopt}}); 522 } 523 } 524 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 525 if (op.getDialect().usePropertiesForAttributes()) { 526 resultSegmentsSizeStorage = 527 llvm::formatv("std::array<int32_t, {0}>", op.getNumResults()); 528 resultSegmentsSizeParser = 529 llvm::formatv(parseTextualSegmentSizeFormat, op.getNumResults()); 530 resultSegmentsSize = { 531 "resultSegmentSizes", 532 makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)}; 533 } else { 534 attrMetadata.insert( 535 {resultSegmentAttrName, 536 AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, 537 /*attr=*/std::nullopt}}); 538 } 539 } 540 541 // Store the metadata in sorted order. 542 SmallVector<AttributeMetadata> sortedAttrMetadata = 543 llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector())); 544 llvm::sort(sortedAttrMetadata, 545 [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) { 546 return lhs.attrName < rhs.attrName; 547 }); 548 549 // Store the position of the legacy operand_segment_sizes / 550 // result_segment_sizes so we can emit a backward compatible property readers 551 // and writers. 
552 StringRef legacyOperandSegmentSizeName = 553 StringLiteral("operand_segment_sizes"); 554 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes"); 555 operandSegmentSizesLegacyIndex = 0; 556 resultSegmentSizesLegacyIndex = 0; 557 for (auto item : sortedAttrMetadata) { 558 if (item.attrName < legacyOperandSegmentSizeName) 559 ++operandSegmentSizesLegacyIndex; 560 if (item.attrName < legacyResultSegmentSizeName) 561 ++resultSegmentSizesLegacyIndex; 562 } 563 564 // Compute the subrange bounds for each attribute. 565 numRequired = 0; 566 for (AttributeMetadata &attr : sortedAttrMetadata) { 567 attr.lowerBound = numRequired; 568 numRequired += attr.isRequired; 569 }; 570 for (AttributeMetadata &attr : sortedAttrMetadata) 571 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired; 572 573 // Store the results back into the map. 574 for (const AttributeMetadata &attr : sortedAttrMetadata) 575 attrMetadata.insert({attr.attrName, attr}); 576 } 577 578 //===----------------------------------------------------------------------===// 579 // Op emitter 580 //===----------------------------------------------------------------------===// 581 582 namespace { 583 // Helper class to emit a record into the given output stream. 584 class OpEmitter { 585 using ConstArgument = 586 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>; 587 588 public: 589 static void 590 emitDecl(const Operator &op, raw_ostream &os, 591 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 592 static void 593 emitDef(const Operator &op, raw_ostream &os, 594 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 595 596 private: 597 OpEmitter(const Operator &op, 598 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 599 600 void emitDecl(raw_ostream &os); 601 void emitDef(raw_ostream &os); 602 603 // Generate methods for accessing the attribute names of this operation. 
604 void genAttrNameGetters(); 605 606 // Generates the OpAsmOpInterface for this operation if possible. 607 void genOpAsmInterface(); 608 609 // Generates the `getOperationName` method for this op. 610 void genOpNameGetter(); 611 612 // Generates code to manage the properties, if any! 613 void genPropertiesSupport(); 614 615 // Generates code to manage the encoding of properties to bytecode. 616 void 617 genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties); 618 619 // Generates getters for the properties. 620 void genPropGetters(); 621 622 // Generates seters for the properties. 623 void genPropSetters(); 624 625 // Generates getters for the attributes. 626 void genAttrGetters(); 627 628 // Generates setter for the attributes. 629 void genAttrSetters(); 630 631 // Generates removers for optional attributes. 632 void genOptionalAttrRemovers(); 633 634 // Generates getters for named operands. 635 void genNamedOperandGetters(); 636 637 // Generates setters for named operands. 638 void genNamedOperandSetters(); 639 640 // Generates getters for named results. 641 void genNamedResultGetters(); 642 643 // Generates getters for named regions. 644 void genNamedRegionGetters(); 645 646 // Generates getters for named successors. 647 void genNamedSuccessorGetters(); 648 649 // Generates the method to populate default attributes. 650 void genPopulateDefaultAttributes(); 651 652 // Generates builder methods for the operation. 653 void genBuilder(); 654 655 // Generates the build() method that takes each operand/attribute 656 // as a stand-alone parameter. 657 void genSeparateArgParamBuilder(); 658 659 // Generates the build() method that takes each operand/attribute as a 660 // stand-alone parameter. The generated build() method uses first operand's 661 // type as all results' types. 662 void genUseOperandAsResultTypeSeparateParamBuilder(); 663 664 // Generates the build() method that takes all operands/attributes 665 // collectively as one parameter. 
The generated build() method uses first 666 // operand's type as all results' types. 667 void genUseOperandAsResultTypeCollectiveParamBuilder(); 668 669 // Generates the build() method that takes aggregate operands/attributes 670 // parameters. This build() method uses inferred types as result types. 671 // Requires: The type needs to be inferable via InferTypeOpInterface. 672 void genInferredTypeCollectiveParamBuilder(); 673 674 // Generates the build() method that takes each operand/attribute as a 675 // stand-alone parameter. The generated build() method uses first attribute's 676 // type as all result's types. 677 void genUseAttrAsResultTypeBuilder(); 678 679 // Generates the build() method that takes all result types collectively as 680 // one parameter. Similarly for operands and attributes. 681 void genCollectiveParamBuilder(); 682 683 // The kind of parameter to generate for result types in builders. 684 enum class TypeParamKind { 685 None, // No result type in parameter list. 686 Separate, // A separate parameter for each result type. 687 Collective, // An ArrayRef<Type> for all result types. 688 }; 689 690 // The kind of parameter to generate for attributes in builders. 691 enum class AttrParamKind { 692 WrappedAttr, // A wrapped MLIR Attribute instance. 693 UnwrappedValue, // A raw value without MLIR Attribute wrapper. 694 }; 695 696 // Builds the parameter list for build() method of this op. This method writes 697 // to `paramList` the comma-separated parameter list and updates 698 // `resultTypeNames` with the names for parameters for specifying result 699 // types. `inferredAttributes` is populated with any attributes that are 700 // elided from the build list. The given `typeParamKind` and `attrParamKind` 701 // controls how result types and attributes are placed in the parameter list. 
702 void buildParamList(SmallVectorImpl<MethodParameter> ¶mList, 703 llvm::StringSet<> &inferredAttributes, 704 SmallVectorImpl<std::string> &resultTypeNames, 705 TypeParamKind typeParamKind, 706 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); 707 708 // Adds op arguments and regions into operation state for build() methods. 709 void 710 genCodeForAddingArgAndRegionForBuilder(MethodBody &body, 711 llvm::StringSet<> &inferredAttributes, 712 bool isRawValueAttr = false); 713 714 // Generates canonicalizer declaration for the operation. 715 void genCanonicalizerDecls(); 716 717 // Generates the folder declaration for the operation. 718 void genFolderDecls(); 719 720 // Generates the parser for the operation. 721 void genParser(); 722 723 // Generates the printer for the operation. 724 void genPrinter(); 725 726 // Generates verify method for the operation. 727 void genVerifier(); 728 729 // Generates custom verify methods for the operation. 730 void genCustomVerifier(); 731 732 // Generates verify statements for operands and results in the operation. 733 // The generated code will be attached to `body`. 734 void genOperandResultVerifier(MethodBody &body, 735 Operator::const_value_range values, 736 StringRef valueKind); 737 738 // Generates verify statements for regions in the operation. 739 // The generated code will be attached to `body`. 740 void genRegionVerifier(MethodBody &body); 741 742 // Generates verify statements for successors in the operation. 743 // The generated code will be attached to `body`. 744 void genSuccessorVerifier(MethodBody &body); 745 746 // Generates the traits used by the object. 747 void genTraits(); 748 749 // Generate the OpInterface methods for all interfaces. 750 void genOpInterfaceMethods(); 751 752 // Generate op interface methods for the given interface. 753 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait); 754 755 // Generate op interface method for the given interface method. 
If 756 // 'declaration' is true, generates a declaration, else a definition. 757 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, 758 bool declaration = true); 759 760 // Generate the side effect interface methods. 761 void genSideEffectInterfaceMethods(); 762 763 // Generate the type inference interface methods. 764 void genTypeInterfaceMethods(); 765 766 private: 767 // The TableGen record for this op. 768 // TODO: OpEmitter should not have a Record directly, 769 // it should rather go through the Operator for better abstraction. 770 const Record &def; 771 772 // The wrapper operator class for querying information from this op. 773 const Operator &op; 774 775 // The C++ code builder for this op 776 OpClass opClass; 777 778 // The format context for verification code generation. 779 FmtContext verifyCtx; 780 781 // The emitter containing all of the locally emitted verification functions. 782 const StaticVerifierFunctionEmitter &staticVerifierEmitter; 783 784 // Helper for emitting op code. 785 OpOrAdaptorHelper emitHelper; 786 }; 787 788 } // namespace 789 790 // Populate the format context `ctx` with substitutions of attributes, operands 791 // and results. 792 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper, 793 FmtContext &ctx) { 794 // Populate substitutions for attributes. 795 auto &op = emitHelper.getOp(); 796 for (const auto &namedAttr : op.getAttributes()) 797 ctx.addSubst(namedAttr.name, 798 emitHelper.getOp().getGetterName(namedAttr.name) + "()"); 799 800 // Populate substitutions for named operands. 801 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 802 auto &value = op.getOperand(i); 803 if (!value.name.empty()) 804 ctx.addSubst(value.name, emitHelper.getOperand(i).str()); 805 } 806 807 // Populate substitutions for results. 
808 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 809 auto &value = op.getResult(i); 810 if (!value.name.empty()) 811 ctx.addSubst(value.name, emitHelper.getResult(i).str()); 812 } 813 } 814 815 /// Generate verification on native traits requiring attributes. 816 static void genNativeTraitAttrVerifier(MethodBody &body, 817 const OpOrAdaptorHelper &emitHelper) { 818 // Check that the variadic segment sizes attribute exists and contains the 819 // expected number of elements. 820 // 821 // {0}: Attribute name. 822 // {1}: Expected number of elements. 823 // {2}: "operand" or "result". 824 // {3}: Emit error prefix. 825 const char *const checkAttrSizedValueSegmentsCode = R"( 826 { 827 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0}); 828 auto numElements = sizeAttr.asArrayRef().size(); 829 if (numElements != {1}) 830 return {3}"'{0}' attribute for specifying {2} segments must have {1} " 831 "elements, but got ") << numElements; 832 } 833 )"; 834 835 // Verify a few traits first so that we can use getODSOperands() and 836 // getODSResults() in the rest of the verifier. 837 auto &op = emitHelper.getOp(); 838 if (!op.getDialect().usePropertiesForAttributes()) { 839 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 840 body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, 841 op.getNumOperands(), "operand", 842 emitHelper.emitErrorPrefix()); 843 } 844 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 845 body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, 846 op.getNumResults(), "result", 847 emitHelper.emitErrorPrefix()); 848 } 849 } 850 } 851 852 // Return true if a verifier can be emitted for the attribute: it is not a 853 // derived attribute, it has a predicate, its condition is not empty, and, for 854 // adaptors, the condition does not reference the op. 
855 static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) { 856 if (attr.isDerivedAttr()) 857 return false; 858 Pred pred = attr.getPredicate(); 859 if (pred.isNull()) 860 return false; 861 std::string condition = pred.getCondition(); 862 return !condition.empty() && 863 (!StringRef(condition).contains("$_op") || isEmittingForOp); 864 } 865 866 // Generate attribute verification. If an op instance is not available, then 867 // attribute checks that require one will not be emitted. 868 // 869 // Attribute verification is performed as follows: 870 // 871 // 1. Verify that all required attributes are present in sorted order. This 872 // ensures that we can use subrange lookup even with potentially missing 873 // attributes. 874 // 2. Verify native trait attributes so that other attributes may call methods 875 // that depend on the validity of these attributes, e.g. segment size attributes 876 // and operand or result getters. 877 // 3. Verify the constraints on all present attributes. 878 static void 879 genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, 880 MethodBody &body, 881 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 882 bool useProperties) { 883 if (emitHelper.getAttrMetadata().empty()) 884 return; 885 886 // Verify the attribute if it is present. This assumes that default values 887 // are valid. This code snippet pastes the condition inline. 888 // 889 // TODO: verify the default value is valid (perhaps in debug mode only). 890 // 891 // {0}: Attribute variable name. 892 // {1}: Attribute condition code. 893 // {2}: Emit error prefix. 894 // {3}: Attribute name. 895 // {4}: Attribute/constraint description. 896 const char *const verifyAttrInline = R"( 897 if ({0} && !({1})) 898 return {2}"attribute '{3}' failed to satisfy constraint: {4}"); 899 )"; 900 // Verify the attribute using a uniqued constraint. Can only be used within 901 // the context of an op. 902 // 903 // {0}: Unique constraint name. 
904 // {1}: Attribute variable name. 905 // {2}: Attribute name. 906 const char *const verifyAttrUnique = R"( 907 if (::mlir::failed({0}(*this, {1}, "{2}"))) 908 return ::mlir::failure(); 909 )"; 910 911 // Traverse the array until the required attribute is found. Return an error 912 // if the traversal reached the end. 913 // 914 // {0}: Code to get the name of the attribute. 915 // {1}: The emit error prefix. 916 // {2}: The name of the attribute. 917 const char *const findRequiredAttr = R"( 918 while (true) {{ 919 if (namedAttrIt == namedAttrRange.end()) 920 return {1}"requires attribute '{2}'"); 921 if (namedAttrIt->getName() == {0}) {{ 922 tblgen_{2} = namedAttrIt->getValue(); 923 break; 924 })"; 925 926 // Emit a check to see if the iteration has encountered an optional attribute. 927 // 928 // {0}: Code to get the name of the attribute. 929 // {1}: The name of the attribute. 930 const char *const checkOptionalAttr = R"( 931 else if (namedAttrIt->getName() == {0}) {{ 932 tblgen_{1} = namedAttrIt->getValue(); 933 })"; 934 935 // Emit the start of the loop for checking trailing attributes. 936 const char *const checkTrailingAttrs = R"(while (true) { 937 if (namedAttrIt == namedAttrRange.end()) { 938 break; 939 })"; 940 941 // Emit the verifier for the attribute. 942 const auto emitVerifier = [&](Attribute attr, StringRef attrName, 943 StringRef varName) { 944 std::string condition = attr.getPredicate().getCondition(); 945 946 std::optional<StringRef> constraintFn; 947 if (emitHelper.isEmittingForOp() && 948 (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { 949 body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); 950 } else { 951 body << formatv(verifyAttrInline, varName, 952 tgfmt(condition, &ctx.withSelf(varName)), 953 emitHelper.emitErrorPrefix(), attrName, 954 escapeString(attr.getSummary())); 955 } 956 }; 957 958 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor. 
959 const auto getVarName = [&](StringRef attrName) { 960 return (tblgenNamePrefix + attrName).str(); 961 }; 962 963 body.indent(); 964 if (useProperties) { 965 for (const std::pair<StringRef, AttributeMetadata> &it : 966 emitHelper.getAttrMetadata()) { 967 const AttributeMetadata &metadata = it.second; 968 if (metadata.constraint && metadata.constraint->isDerivedAttr()) 969 continue; 970 body << formatv( 971 "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n", 972 it.first); 973 if (metadata.isRequired) 974 body << formatv( 975 "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n", 976 it.first, emitHelper.emitErrorPrefix()); 977 } 978 } else { 979 body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange()); 980 body << "auto namedAttrIt = namedAttrRange.begin();\n"; 981 982 // Iterate over the attributes in sorted order. Keep track of the optional 983 // attributes that may be encountered along the way. 984 SmallVector<const AttributeMetadata *> optionalAttrs; 985 986 for (const std::pair<StringRef, AttributeMetadata> &it : 987 emitHelper.getAttrMetadata()) { 988 const AttributeMetadata &metadata = it.second; 989 if (!metadata.isRequired) { 990 optionalAttrs.push_back(&metadata); 991 continue; 992 } 993 994 body << formatv("::mlir::Attribute {0};\n", getVarName(it.first)); 995 for (const AttributeMetadata *optional : optionalAttrs) { 996 body << formatv("::mlir::Attribute {0};\n", 997 getVarName(optional->attrName)); 998 } 999 body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first), 1000 emitHelper.emitErrorPrefix(), it.first); 1001 for (const AttributeMetadata *optional : optionalAttrs) { 1002 body << formatv(checkOptionalAttr, 1003 emitHelper.getAttrName(optional->attrName), 1004 optional->attrName); 1005 } 1006 body << "\n ++namedAttrIt;\n}\n"; 1007 optionalAttrs.clear(); 1008 } 1009 // Get trailing optional attributes. 
1010 if (!optionalAttrs.empty()) { 1011 for (const AttributeMetadata *optional : optionalAttrs) { 1012 body << formatv("::mlir::Attribute {0};\n", 1013 getVarName(optional->attrName)); 1014 } 1015 body << checkTrailingAttrs; 1016 for (const AttributeMetadata *optional : optionalAttrs) { 1017 body << formatv(checkOptionalAttr, 1018 emitHelper.getAttrName(optional->attrName), 1019 optional->attrName); 1020 } 1021 body << "\n ++namedAttrIt;\n}\n"; 1022 } 1023 } 1024 body.unindent(); 1025 1026 // Emit the checks for segment attributes first so that the other 1027 // constraints can call operand and result getters. 1028 genNativeTraitAttrVerifier(body, emitHelper); 1029 1030 bool isEmittingForOp = emitHelper.isEmittingForOp(); 1031 for (const auto &namedAttr : emitHelper.getOp().getAttributes()) 1032 if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp)) 1033 emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); 1034 } 1035 1036 static void genPropertyVerifier(const OpOrAdaptorHelper &emitHelper, 1037 FmtContext &ctx, MethodBody &body) { 1038 1039 // Code to get a reference to a property into a variable to avoid multiple 1040 // evaluations while verifying a property. 1041 // {0}: Property variable name. 1042 // {1}: Property name, with the first letter capitalized, to find the getter. 1043 // {2}: Property interface type. 1044 const char *const fetchProperty = R"( 1045 [[maybe_unused]] {2} {0} = this->get{1}(); 1046 )"; 1047 1048 // Code to verify that the predicate of a property holds. Embeds the 1049 // condition inline. 1050 // {0}: Property condition code, with tgfmt() applied. 1051 // {1}: Emit error prefix. 1052 // {2}: Property name. 1053 // {3}: Property description. 1054 const char *const verifyProperty = R"( 1055 if (!({0})) 1056 return {1}"property '{2}' failed to satisfy constraint: {3}"); 1057 )"; 1058 1059 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor. 
1060 const auto getVarName = [&](const NamedProperty &prop) { 1061 std::string varName = 1062 convertToCamelFromSnakeCase(prop.name, /*capitalizeFirst=*/false); 1063 return (tblgenNamePrefix + Twine(varName)).str(); 1064 }; 1065 1066 for (const NamedProperty &prop : emitHelper.getOp().getProperties()) { 1067 Pred predicate = prop.prop.getPredicate(); 1068 // Null predicate, nothing to verify. 1069 if (predicate == Pred()) 1070 continue; 1071 1072 std::string rawCondition = predicate.getCondition(); 1073 if (rawCondition == "true") 1074 continue; 1075 bool needsOp = StringRef(rawCondition).contains("$_op"); 1076 if (needsOp && !emitHelper.isEmittingForOp()) 1077 continue; 1078 1079 auto scope = body.scope("{\n", "}\n", /*indent=*/true); 1080 std::string varName = getVarName(prop); 1081 std::string getterName = 1082 convertToCamelFromSnakeCase(prop.name, /*capitalizeFirst=*/true); 1083 body << formatv(fetchProperty, varName, getterName, 1084 prop.prop.getInterfaceType()); 1085 body << formatv(verifyProperty, tgfmt(rawCondition, &ctx.withSelf(varName)), 1086 emitHelper.emitErrorPrefix(), prop.name, 1087 prop.prop.getSummary()); 1088 } 1089 } 1090 1091 /// Include declarations specified on NativeTrait 1092 static std::string formatExtraDeclarations(const Operator &op) { 1093 SmallVector<StringRef> extraDeclarations; 1094 // Include extra class declarations from NativeTrait 1095 for (const auto &trait : op.getTraits()) { 1096 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 1097 StringRef value = opTrait->getExtraConcreteClassDeclaration(); 1098 if (value.empty()) 1099 continue; 1100 extraDeclarations.push_back(value); 1101 } 1102 } 1103 extraDeclarations.push_back(op.getExtraClassDeclaration()); 1104 return llvm::join(extraDeclarations, "\n"); 1105 } 1106 1107 /// Op extra class definitions have a `$cppClass` substitution that is to be 1108 /// replaced by the C++ class name. 
1109 /// Include declarations specified on NativeTrait 1110 static std::string formatExtraDefinitions(const Operator &op) { 1111 SmallVector<StringRef> extraDefinitions; 1112 // Include extra class definitions from NativeTrait 1113 for (const auto &trait : op.getTraits()) { 1114 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 1115 StringRef value = opTrait->getExtraConcreteClassDefinition(); 1116 if (value.empty()) 1117 continue; 1118 extraDefinitions.push_back(value); 1119 } 1120 } 1121 extraDefinitions.push_back(op.getExtraClassDefinition()); 1122 FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); 1123 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str(); 1124 } 1125 1126 OpEmitter::OpEmitter(const Operator &op, 1127 const StaticVerifierFunctionEmitter &staticVerifierEmitter) 1128 : def(op.getDef()), op(op), 1129 opClass(op.getCppClassName(), formatExtraDeclarations(op), 1130 formatExtraDefinitions(op)), 1131 staticVerifierEmitter(staticVerifierEmitter), 1132 emitHelper(op, /*emitForOp=*/true) { 1133 verifyCtx.addSubst("_op", "(*this->getOperation())"); 1134 verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); 1135 1136 genTraits(); 1137 1138 // Generate C++ code for various op methods. The order here determines the 1139 // methods in the generated file. 
1140 genAttrNameGetters(); 1141 genOpAsmInterface(); 1142 genOpNameGetter(); 1143 genNamedOperandGetters(); 1144 genNamedOperandSetters(); 1145 genNamedResultGetters(); 1146 genNamedRegionGetters(); 1147 genNamedSuccessorGetters(); 1148 genPropertiesSupport(); 1149 genPropGetters(); 1150 genPropSetters(); 1151 genAttrGetters(); 1152 genAttrSetters(); 1153 genOptionalAttrRemovers(); 1154 genBuilder(); 1155 genPopulateDefaultAttributes(); 1156 genParser(); 1157 genPrinter(); 1158 genVerifier(); 1159 genCustomVerifier(); 1160 genCanonicalizerDecls(); 1161 genFolderDecls(); 1162 genTypeInterfaceMethods(); 1163 genOpInterfaceMethods(); 1164 generateOpFormat(op, opClass, emitHelper.hasProperties()); 1165 genSideEffectInterfaceMethods(); 1166 } 1167 void OpEmitter::emitDecl( 1168 const Operator &op, raw_ostream &os, 1169 const StaticVerifierFunctionEmitter &staticVerifierEmitter) { 1170 OpEmitter(op, staticVerifierEmitter).emitDecl(os); 1171 } 1172 1173 void OpEmitter::emitDef( 1174 const Operator &op, raw_ostream &os, 1175 const StaticVerifierFunctionEmitter &staticVerifierEmitter) { 1176 OpEmitter(op, staticVerifierEmitter).emitDef(os); 1177 } 1178 1179 void OpEmitter::emitDecl(raw_ostream &os) { 1180 opClass.finalize(); 1181 opClass.writeDeclTo(os); 1182 } 1183 1184 void OpEmitter::emitDef(raw_ostream &os) { 1185 opClass.finalize(); 1186 opClass.writeDefTo(os); 1187 } 1188 1189 static void errorIfPruned(size_t line, Method *m, const Twine &methodName, 1190 const Operator &op) { 1191 if (m) 1192 return; 1193 PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" + 1194 methodName + "` for " + 1195 op.getOperationName() + " (from line " + 1196 Twine(line) + ")"); 1197 } 1198 1199 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) 1200 1201 void OpEmitter::genAttrNameGetters() { 1202 const llvm::MapVector<StringRef, AttributeMetadata> &attributes = 1203 emitHelper.getAttrMetadata(); 1204 bool hasOperandSegmentsSize = 1205 
op.getDialect().usePropertiesForAttributes() && 1206 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 1207 // Emit the getAttributeNames method. 1208 { 1209 auto *method = opClass.addStaticInlineMethod( 1210 "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); 1211 ERROR_IF_PRUNED(method, "getAttributeNames", op); 1212 auto &body = method->body(); 1213 if (!hasOperandSegmentsSize && attributes.empty()) { 1214 body << " return {};"; 1215 // Nothing else to do if there are no registered attributes. Exit early. 1216 return; 1217 } 1218 body << " static ::llvm::StringRef attrNames[] = {"; 1219 llvm::interleaveComma(llvm::make_first_range(attributes), body, 1220 [&](StringRef attrName) { 1221 body << "::llvm::StringRef(\"" << attrName << "\")"; 1222 }); 1223 if (hasOperandSegmentsSize) { 1224 if (!attributes.empty()) 1225 body << ", "; 1226 body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")"; 1227 } 1228 body << "};\n return ::llvm::ArrayRef(attrNames);"; 1229 } 1230 1231 // Emit the getAttributeNameForIndex methods. 
1232 { 1233 auto *method = opClass.addInlineMethod<Method::Private>( 1234 "::mlir::StringAttr", "getAttributeNameForIndex", 1235 MethodParameter("unsigned", "index")); 1236 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); 1237 method->body() 1238 << " return getAttributeNameForIndex((*this)->getName(), index);"; 1239 } 1240 { 1241 auto *method = opClass.addStaticInlineMethod<Method::Private>( 1242 "::mlir::StringAttr", "getAttributeNameForIndex", 1243 MethodParameter("::mlir::OperationName", "name"), 1244 MethodParameter("unsigned", "index")); 1245 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); 1246 1247 if (attributes.empty()) { 1248 method->body() << " return {};"; 1249 } else { 1250 const char *const getAttrName = R"( 1251 assert(index < {0} && "invalid attribute index"); 1252 assert(name.getStringRef() == getOperationName() && "invalid operation name"); 1253 assert(name.isRegistered() && "Operation isn't registered, missing a " 1254 "dependent dialect loading?"); 1255 return name.getAttributeNames()[index]; 1256 )"; 1257 method->body() << formatv(getAttrName, attributes.size()); 1258 } 1259 } 1260 1261 // Generate the <attr>AttrName methods, that expose the attribute names to 1262 // users. 1263 const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; 1264 for (auto [index, attr] : 1265 llvm::enumerate(llvm::make_first_range(attributes))) { 1266 std::string name = op.getGetterName(attr); 1267 std::string methodName = name + "AttrName"; 1268 1269 // Generate the non-static variant. 1270 { 1271 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); 1272 ERROR_IF_PRUNED(method, methodName, op); 1273 method->body() << llvm::formatv(attrNameMethodBody, index); 1274 } 1275 1276 // Generate the static variant. 
1277 { 1278 auto *method = opClass.addStaticInlineMethod( 1279 "::mlir::StringAttr", methodName, 1280 MethodParameter("::mlir::OperationName", "name")); 1281 ERROR_IF_PRUNED(method, methodName, op); 1282 method->body() << llvm::formatv(attrNameMethodBody, 1283 "name, " + Twine(index)); 1284 } 1285 } 1286 if (hasOperandSegmentsSize) { 1287 std::string name = op.getGetterName(operandSegmentAttrName); 1288 std::string methodName = name + "AttrName"; 1289 // Generate the non-static variant. 1290 { 1291 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); 1292 ERROR_IF_PRUNED(method, methodName, op); 1293 method->body() 1294 << " return (*this)->getName().getAttributeNames().back();"; 1295 } 1296 1297 // Generate the static variant. 1298 { 1299 auto *method = opClass.addStaticInlineMethod( 1300 "::mlir::StringAttr", methodName, 1301 MethodParameter("::mlir::OperationName", "name")); 1302 ERROR_IF_PRUNED(method, methodName, op); 1303 method->body() << " return name.getAttributeNames().back();"; 1304 } 1305 } 1306 } 1307 1308 // Emit the getter for a named property. 1309 // It is templated to be shared between the Op and the adaptor class. 1310 template <typename OpClassOrAdaptor> 1311 static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op, 1312 StringRef name, const Property &prop) { 1313 auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name); 1314 ERROR_IF_PRUNED(method, name, op); 1315 method->body() << formatv(" return getProperties().{0}();", name); 1316 } 1317 1318 // Emit the getter for an attribute with the return type specified. 1319 // It is templated to be shared between the Op and the adaptor class. 
1320 template <typename OpClassOrAdaptor> 1321 static void emitAttrGetterWithReturnType(FmtContext &fctx, 1322 OpClassOrAdaptor &opClass, 1323 const Operator &op, StringRef name, 1324 Attribute attr) { 1325 auto *method = opClass.addMethod(attr.getReturnType(), name); 1326 ERROR_IF_PRUNED(method, name, op); 1327 auto &body = method->body(); 1328 body << " auto attr = " << name << "Attr();\n"; 1329 if (attr.hasDefaultValue() && attr.isOptional()) { 1330 // Returns the default value if not set. 1331 // TODO: this is inefficient, we are recreating the attribute for every 1332 // call. This should be set instead. 1333 if (!attr.isConstBuildable()) { 1334 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() + 1335 " must have a constBuilder"); 1336 } 1337 std::string defaultValue = std::string( 1338 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 1339 body << " if (!attr)\n return " 1340 << tgfmt(attr.getConvertFromStorageCall(), 1341 &fctx.withSelf(defaultValue)) 1342 << ";\n"; 1343 } 1344 body << " return " 1345 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) 1346 << ";\n"; 1347 } 1348 1349 void OpEmitter::genPropertiesSupport() { 1350 if (!emitHelper.hasProperties()) 1351 return; 1352 1353 SmallVector<ConstArgument> attrOrProperties; 1354 for (const std::pair<StringRef, AttributeMetadata> &it : 1355 emitHelper.getAttrMetadata()) { 1356 if (!it.second.constraint || !it.second.constraint->isDerivedAttr()) 1357 attrOrProperties.push_back(&it.second); 1358 } 1359 for (const NamedProperty &prop : op.getProperties()) 1360 attrOrProperties.push_back(&prop); 1361 if (emitHelper.getOperandSegmentsSize()) 1362 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); 1363 if (emitHelper.getResultSegmentsSize()) 1364 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); 1365 if (attrOrProperties.empty()) 1366 return; 1367 auto &setPropMethod = 1368 opClass 1369 .addStaticMethod( 1370 
"::llvm::LogicalResult", "setPropertiesFromAttr", 1371 MethodParameter("Properties &", "prop"), 1372 MethodParameter("::mlir::Attribute", "attr"), 1373 MethodParameter( 1374 "::llvm::function_ref<::mlir::InFlightDiagnostic()>", 1375 "emitError")) 1376 ->body(); 1377 auto &getPropMethod = 1378 opClass 1379 .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr", 1380 MethodParameter("::mlir::MLIRContext *", "ctx"), 1381 MethodParameter("const Properties &", "prop")) 1382 ->body(); 1383 auto &hashMethod = 1384 opClass 1385 .addStaticMethod("llvm::hash_code", "computePropertiesHash", 1386 MethodParameter("const Properties &", "prop")) 1387 ->body(); 1388 auto &getInherentAttrMethod = 1389 opClass 1390 .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr", 1391 MethodParameter("::mlir::MLIRContext *", "ctx"), 1392 MethodParameter("const Properties &", "prop"), 1393 MethodParameter("llvm::StringRef", "name")) 1394 ->body(); 1395 auto &setInherentAttrMethod = 1396 opClass 1397 .addStaticMethod("void", "setInherentAttr", 1398 MethodParameter("Properties &", "prop"), 1399 MethodParameter("llvm::StringRef", "name"), 1400 MethodParameter("mlir::Attribute", "value")) 1401 ->body(); 1402 auto &populateInherentAttrsMethod = 1403 opClass 1404 .addStaticMethod("void", "populateInherentAttrs", 1405 MethodParameter("::mlir::MLIRContext *", "ctx"), 1406 MethodParameter("const Properties &", "prop"), 1407 MethodParameter("::mlir::NamedAttrList &", "attrs")) 1408 ->body(); 1409 auto &verifyInherentAttrsMethod = 1410 opClass 1411 .addStaticMethod( 1412 "::llvm::LogicalResult", "verifyInherentAttrs", 1413 MethodParameter("::mlir::OperationName", "opName"), 1414 MethodParameter("::mlir::NamedAttrList &", "attrs"), 1415 MethodParameter( 1416 "llvm::function_ref<::mlir::InFlightDiagnostic()>", 1417 "emitError")) 1418 ->body(); 1419 1420 opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties"); 1421 1422 // Convert the property to the attribute form. 
1423 1424 setPropMethod << R"decl( 1425 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); 1426 if (!dict) { 1427 emitError() << "expected DictionaryAttr to set properties"; 1428 return ::mlir::failure(); 1429 } 1430 )decl"; 1431 const char *propFromAttrFmt = R"decl( 1432 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, 1433 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ 1434 {0} 1435 }; 1436 {1}; 1437 )decl"; 1438 const char *attrGetNoDefaultFmt = R"decl(; 1439 if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError))) 1440 return ::mlir::failure(); 1441 )decl"; 1442 const char *attrGetDefaultFmt = R"decl(; 1443 if (attr) {{ 1444 if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError))) 1445 return ::mlir::failure(); 1446 } else {{ 1447 prop.{0} = {1}; 1448 } 1449 )decl"; 1450 1451 for (const auto &attrOrProp : attrOrProperties) { 1452 if (const auto *namedProperty = 1453 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { 1454 StringRef name = namedProperty->name; 1455 auto &prop = namedProperty->prop; 1456 FmtContext fctx; 1457 1458 std::string getAttr; 1459 llvm::raw_string_ostream os(getAttr); 1460 os << " auto attr = dict.get(\"" << name << "\");"; 1461 if (name == operandSegmentAttrName) { 1462 // Backward compat for now, TODO: Remove at some point. 1463 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");"; 1464 } 1465 if (name == resultSegmentAttrName) { 1466 // Backward compat for now, TODO: Remove at some point. 
1467 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");"; 1468 } 1469 1470 setPropMethod << "{\n" 1471 << formatv(propFromAttrFmt, 1472 tgfmt(prop.getConvertFromAttributeCall(), 1473 &fctx.addSubst("_attr", propertyAttr) 1474 .addSubst("_storage", propertyStorage) 1475 .addSubst("_diag", propertyDiag)), 1476 getAttr); 1477 if (prop.hasStorageTypeValueOverride()) { 1478 setPropMethod << formatv(attrGetDefaultFmt, name, 1479 prop.getStorageTypeValueOverride()); 1480 } else if (prop.hasDefaultValue()) { 1481 setPropMethod << formatv(attrGetDefaultFmt, name, 1482 prop.getDefaultValue()); 1483 } else { 1484 setPropMethod << formatv(attrGetNoDefaultFmt, name); 1485 } 1486 setPropMethod << " }\n"; 1487 } else { 1488 const auto *namedAttr = 1489 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp); 1490 StringRef name = namedAttr->attrName; 1491 std::string getAttr; 1492 llvm::raw_string_ostream os(getAttr); 1493 os << " auto attr = dict.get(\"" << name << "\");"; 1494 if (name == operandSegmentAttrName) { 1495 // Backward compat for now 1496 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");"; 1497 } 1498 if (name == resultSegmentAttrName) { 1499 // Backward compat for now 1500 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");"; 1501 } 1502 1503 setPropMethod << formatv(R"decl( 1504 {{ 1505 auto &propStorage = prop.{0}; 1506 {1} 1507 if (attr) {{ 1508 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr); 1509 if (convertedAttr) {{ 1510 propStorage = convertedAttr; 1511 } else {{ 1512 emitError() << "Invalid attribute `{0}` in property conversion: " << attr; 1513 return ::mlir::failure(); 1514 } 1515 } 1516 } 1517 )decl", 1518 name, getAttr); 1519 } 1520 } 1521 setPropMethod << " return ::mlir::success();\n"; 1522 1523 // Convert the attribute form to the property. 
1524 1525 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n" 1526 << " ::mlir::Builder odsBuilder{ctx};\n"; 1527 const char *propToAttrFmt = R"decl( 1528 { 1529 const auto &propStorage = prop.{0}; 1530 auto attr = [&]() -> ::mlir::Attribute {{ 1531 {1} 1532 }(); 1533 attrs.push_back(odsBuilder.getNamedAttr("{0}", attr)); 1534 } 1535 )decl"; 1536 for (const auto &attrOrProp : attrOrProperties) { 1537 if (const auto *namedProperty = 1538 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { 1539 StringRef name = namedProperty->name; 1540 auto &prop = namedProperty->prop; 1541 FmtContext fctx; 1542 getPropMethod << formatv( 1543 propToAttrFmt, name, 1544 tgfmt(prop.getConvertToAttributeCall(), 1545 &fctx.addSubst("_ctxt", "ctx") 1546 .addSubst("_storage", propertyStorage))); 1547 continue; 1548 } 1549 const auto *namedAttr = 1550 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp); 1551 StringRef name = namedAttr->attrName; 1552 getPropMethod << formatv(R"decl( 1553 {{ 1554 const auto &propStorage = prop.{0}; 1555 if (propStorage) 1556 attrs.push_back(odsBuilder.getNamedAttr("{0}", 1557 propStorage)); 1558 } 1559 )decl", 1560 name); 1561 } 1562 getPropMethod << R"decl( 1563 if (!attrs.empty()) 1564 return odsBuilder.getDictionaryAttr(attrs); 1565 return {}; 1566 )decl"; 1567 1568 // Hashing for the property 1569 1570 const char *propHashFmt = R"decl( 1571 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code { 1572 return {1}; 1573 }; 1574 )decl"; 1575 for (const auto &attrOrProp : attrOrProperties) { 1576 if (const auto *namedProperty = 1577 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { 1578 StringRef name = namedProperty->name; 1579 auto &prop = namedProperty->prop; 1580 FmtContext fctx; 1581 if (!prop.getHashPropertyCall().empty()) { 1582 hashMethod << formatv( 1583 propHashFmt, name, 1584 tgfmt(prop.getHashPropertyCall(), 1585 &fctx.addSubst("_storage", propertyStorage))); 1586 } 1587 } 1588 
} 1589 hashMethod << " return llvm::hash_combine("; 1590 llvm::interleaveComma( 1591 attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) { 1592 if (const auto *namedProperty = 1593 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { 1594 if (!namedProperty->prop.getHashPropertyCall().empty()) { 1595 hashMethod << "\n hash_" << namedProperty->name << "(prop." 1596 << namedProperty->name << ")"; 1597 } else { 1598 hashMethod << "\n ::llvm::hash_value(prop." 1599 << namedProperty->name << ")"; 1600 } 1601 return; 1602 } 1603 const auto *namedAttr = 1604 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp); 1605 StringRef name = namedAttr->attrName; 1606 hashMethod << "\n llvm::hash_value(prop." << name 1607 << ".getAsOpaquePointer())"; 1608 }); 1609 hashMethod << ");\n"; 1610 1611 const char *getInherentAttrMethodFmt = R"decl( 1612 if (name == "{0}") 1613 return prop.{0}; 1614 )decl"; 1615 const char *setInherentAttrMethodFmt = R"decl( 1616 if (name == "{0}") {{ 1617 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value); 1618 return; 1619 } 1620 )decl"; 1621 const char *populateInherentAttrsMethodFmt = R"decl( 1622 if (prop.{0}) attrs.append("{0}", prop.{0}); 1623 )decl"; 1624 for (const auto &attrOrProp : attrOrProperties) { 1625 if (const auto *namedAttr = 1626 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) { 1627 StringRef name = namedAttr->attrName; 1628 getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name); 1629 setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name); 1630 populateInherentAttrsMethod 1631 << formatv(populateInherentAttrsMethodFmt, name); 1632 continue; 1633 } 1634 // The ODS segment size property is "special": we expose it as an attribute 1635 // even though it is a native property. 
1636 const auto *namedProperty = cast<const NamedProperty *>(attrOrProp); 1637 StringRef name = namedProperty->name; 1638 if (name != operandSegmentAttrName && name != resultSegmentAttrName) 1639 continue; 1640 auto &prop = namedProperty->prop; 1641 FmtContext fctx; 1642 fctx.addSubst("_ctxt", "ctx"); 1643 fctx.addSubst("_storage", Twine("prop.") + name); 1644 if (name == operandSegmentAttrName) { 1645 getInherentAttrMethod 1646 << formatv(" if (name == \"operand_segment_sizes\" || name == " 1647 "\"{0}\") return ", 1648 operandSegmentAttrName); 1649 } else { 1650 getInherentAttrMethod 1651 << formatv(" if (name == \"result_segment_sizes\" || name == " 1652 "\"{0}\") return ", 1653 resultSegmentAttrName); 1654 } 1655 getInherentAttrMethod << "[&]() -> ::mlir::Attribute { " 1656 << tgfmt(prop.getConvertToAttributeCall(), &fctx) 1657 << " }();\n"; 1658 1659 if (name == operandSegmentAttrName) { 1660 setInherentAttrMethod 1661 << formatv(" if (name == \"operand_segment_sizes\" || name == " 1662 "\"{0}\") {{", 1663 operandSegmentAttrName); 1664 } else { 1665 setInherentAttrMethod 1666 << formatv(" if (name == \"result_segment_sizes\" || name == " 1667 "\"{0}\") {{", 1668 resultSegmentAttrName); 1669 } 1670 setInherentAttrMethod << formatv(R"decl( 1671 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value); 1672 if (!arrAttr) return; 1673 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t)) 1674 return; 1675 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin()); 1676 return; 1677 } 1678 )decl", 1679 name); 1680 if (name == operandSegmentAttrName) { 1681 populateInherentAttrsMethod << formatv( 1682 " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", 1683 operandSegmentAttrName, 1684 tgfmt(prop.getConvertToAttributeCall(), &fctx)); 1685 } else { 1686 populateInherentAttrsMethod << formatv( 1687 " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", 1688 resultSegmentAttrName, 1689 tgfmt(prop.getConvertToAttributeCall(), 
&fctx)); 1690 } 1691 } 1692 getInherentAttrMethod << " return std::nullopt;\n"; 1693 1694 // Emit the verifiers method for backward compatibility with the generic 1695 // syntax. This method verifies the constraint on the properties attributes 1696 // before they are set, since dyn_cast<> will silently omit failures. 1697 for (const auto &attrOrProp : attrOrProperties) { 1698 const auto *namedAttr = 1699 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp); 1700 if (!namedAttr || !namedAttr->constraint) 1701 continue; 1702 Attribute attr = *namedAttr->constraint; 1703 std::optional<StringRef> constraintFn = 1704 staticVerifierEmitter.getAttrConstraintFn(attr); 1705 if (!constraintFn) 1706 continue; 1707 if (canEmitAttrVerifier(attr, 1708 /*isEmittingForOp=*/false)) { 1709 std::string name = op.getGetterName(namedAttr->attrName); 1710 verifyInherentAttrsMethod 1711 << formatv(R"( 1712 {{ 1713 ::mlir::Attribute attr = attrs.get({0}AttrName(opName)); 1714 if (attr && ::mlir::failed({1}(attr, "{2}", emitError))) 1715 return ::mlir::failure(); 1716 } 1717 )", 1718 name, constraintFn, namedAttr->attrName); 1719 } 1720 } 1721 verifyInherentAttrsMethod << " return ::mlir::success();"; 1722 1723 // Generate methods to interact with bytecode. 
1724 genPropertiesSupportForBytecode(attrOrProperties); 1725 } 1726 1727 void OpEmitter::genPropertiesSupportForBytecode( 1728 ArrayRef<ConstArgument> attrOrProperties) { 1729 if (op.useCustomPropertiesEncoding()) { 1730 opClass.declareStaticMethod( 1731 "::llvm::LogicalResult", "readProperties", 1732 MethodParameter("::mlir::DialectBytecodeReader &", "reader"), 1733 MethodParameter("::mlir::OperationState &", "state")); 1734 opClass.declareMethod( 1735 "void", "writeProperties", 1736 MethodParameter("::mlir::DialectBytecodeWriter &", "writer")); 1737 return; 1738 } 1739 1740 auto &readPropertiesMethod = 1741 opClass 1742 .addStaticMethod( 1743 "::llvm::LogicalResult", "readProperties", 1744 MethodParameter("::mlir::DialectBytecodeReader &", "reader"), 1745 MethodParameter("::mlir::OperationState &", "state")) 1746 ->body(); 1747 1748 auto &writePropertiesMethod = 1749 opClass 1750 .addMethod( 1751 "void", "writeProperties", 1752 MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) 1753 ->body(); 1754 1755 // Populate bytecode serialization logic. 1756 readPropertiesMethod 1757 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;"; 1758 writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n"; 1759 for (const auto &item : llvm::enumerate(attrOrProperties)) { 1760 auto &attrOrProp = item.value(); 1761 FmtContext fctx; 1762 fctx.addSubst("_reader", "reader") 1763 .addSubst("_writer", "writer") 1764 .addSubst("_storage", propertyStorage) 1765 .addSubst("_ctxt", "this->getContext()"); 1766 // If the op emits operand/result segment sizes as a property, emit the 1767 // legacy reader/writer in the appropriate order to allow backward 1768 // compatibility and back deployment. 
1769 if (emitHelper.getOperandSegmentsSize().has_value() && 1770 item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) { 1771 FmtContext fmtCtxt(fctx); 1772 fmtCtxt.addSubst("_propName", operandSegmentAttrName); 1773 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt); 1774 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt); 1775 } 1776 if (emitHelper.getResultSegmentsSize().has_value() && 1777 item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) { 1778 FmtContext fmtCtxt(fctx); 1779 fmtCtxt.addSubst("_propName", resultSegmentAttrName); 1780 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt); 1781 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt); 1782 } 1783 if (const auto *namedProperty = 1784 attrOrProp.dyn_cast<const NamedProperty *>()) { 1785 StringRef name = namedProperty->name; 1786 readPropertiesMethod << formatv( 1787 R"( 1788 {{ 1789 auto &propStorage = prop.{0}; 1790 auto readProp = [&]() { 1791 {1}; 1792 return ::mlir::success(); 1793 }; 1794 if (::mlir::failed(readProp())) 1795 return ::mlir::failure(); 1796 } 1797 )", 1798 name, 1799 tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx)); 1800 writePropertiesMethod << formatv( 1801 R"( 1802 {{ 1803 auto &propStorage = prop.{0}; 1804 {1}; 1805 } 1806 )", 1807 name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx)); 1808 continue; 1809 } 1810 const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>(); 1811 StringRef name = namedAttr->attrName; 1812 if (namedAttr->isRequired) { 1813 readPropertiesMethod << formatv(R"( 1814 if (::mlir::failed(reader.readAttribute(prop.{0}))) 1815 return ::mlir::failure(); 1816 )", 1817 name); 1818 writePropertiesMethod 1819 << formatv(" writer.writeAttribute(prop.{0});\n", name); 1820 } else { 1821 readPropertiesMethod << formatv(R"( 1822 if (::mlir::failed(reader.readOptionalAttribute(prop.{0}))) 1823 return ::mlir::failure(); 1824 
)", 1825 name); 1826 writePropertiesMethod << formatv(R"( 1827 writer.writeOptionalAttribute(prop.{0}); 1828 )", 1829 name); 1830 } 1831 } 1832 readPropertiesMethod << " return ::mlir::success();"; 1833 } 1834 1835 void OpEmitter::genPropGetters() { 1836 for (const NamedProperty &prop : op.getProperties()) { 1837 std::string name = op.getGetterName(prop.name); 1838 emitPropGetter(opClass, op, name, prop.prop); 1839 } 1840 } 1841 1842 void OpEmitter::genPropSetters() { 1843 for (const NamedProperty &prop : op.getProperties()) { 1844 std::string name = op.getSetterName(prop.name); 1845 std::string argName = "new" + convertToCamelFromSnakeCase( 1846 prop.name, /*capitalizeFirst=*/true); 1847 auto *method = opClass.addInlineMethod( 1848 "void", name, MethodParameter(prop.prop.getInterfaceType(), argName)); 1849 if (!method) 1850 return; 1851 method->body() << formatv(" getProperties().{0}({1});", name, argName); 1852 } 1853 } 1854 1855 void OpEmitter::genAttrGetters() { 1856 FmtContext fctx; 1857 fctx.withBuilder("::mlir::Builder((*this)->getContext())"); 1858 1859 // Emit the derived attribute body. 1860 auto emitDerivedAttr = [&](StringRef name, Attribute attr) { 1861 if (auto *method = opClass.addMethod(attr.getReturnType(), name)) 1862 method->body() << " " << attr.getDerivedCodeBody() << "\n"; 1863 }; 1864 1865 // Generate named accessor with Attribute return type. This is a wrapper 1866 // class that allows referring to the attributes via accessors instead of 1867 // having to use the string interface for better compile time verification. 1868 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName, 1869 Attribute attr) { 1870 // The method body for this getter is trivial. Emit it inline. 
1871 auto *method = 1872 opClass.addInlineMethod(attr.getStorageType(), name + "Attr"); 1873 if (!method) 1874 return; 1875 method->body() << formatv( 1876 " return ::llvm::{1}<{2}>({0});", emitHelper.getAttr(attrName), 1877 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null" 1878 : "cast", 1879 attr.getStorageType()); 1880 }; 1881 1882 for (const NamedAttribute &namedAttr : op.getAttributes()) { 1883 std::string name = op.getGetterName(namedAttr.name); 1884 if (namedAttr.attr.isDerivedAttr()) { 1885 emitDerivedAttr(name, namedAttr.attr); 1886 } else { 1887 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); 1888 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); 1889 } 1890 } 1891 1892 auto derivedAttrs = make_filter_range(op.getAttributes(), 1893 [](const NamedAttribute &namedAttr) { 1894 return namedAttr.attr.isDerivedAttr(); 1895 }); 1896 if (derivedAttrs.empty()) 1897 return; 1898 1899 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); 1900 // Generate helper method to query whether a named attribute is a derived 1901 // attribute. This enables, for example, avoiding adding an attribute that 1902 // overlaps with a derived attribute. 1903 { 1904 auto *method = 1905 opClass.addStaticMethod("bool", "isDerivedAttribute", 1906 MethodParameter("::llvm::StringRef", "name")); 1907 ERROR_IF_PRUNED(method, "isDerivedAttribute", op); 1908 auto &body = method->body(); 1909 for (auto namedAttr : derivedAttrs) 1910 body << " if (name == \"" << namedAttr.name << "\") return true;\n"; 1911 body << " return false;"; 1912 } 1913 // Generate method to materialize derived attributes as a DictionaryAttr. 
1914 { 1915 auto *method = opClass.addMethod("::mlir::DictionaryAttr", 1916 "materializeDerivedAttributes"); 1917 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); 1918 auto &body = method->body(); 1919 1920 auto nonMaterializable = 1921 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { 1922 return namedAttr.attr.getConvertFromStorageCall().empty(); 1923 }); 1924 if (!nonMaterializable.empty()) { 1925 std::string attrs; 1926 llvm::raw_string_ostream os(attrs); 1927 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { 1928 os << op.getGetterName(attr.name); 1929 }); 1930 PrintWarning( 1931 op.getLoc(), 1932 formatv( 1933 "op has non-materializable derived attributes '{0}', skipping", 1934 os.str())); 1935 body << formatv(" emitOpError(\"op has non-materializable derived " 1936 "attributes '{0}'\");\n", 1937 attrs); 1938 body << " return nullptr;"; 1939 return; 1940 } 1941 1942 body << " ::mlir::MLIRContext* ctx = getContext();\n"; 1943 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; 1944 body << " return ::mlir::DictionaryAttr::get("; 1945 body << " ctx, {\n"; 1946 interleave( 1947 derivedAttrs, body, 1948 [&](const NamedAttribute &namedAttr) { 1949 auto tmpl = namedAttr.attr.getConvertFromStorageCall(); 1950 std::string name = op.getGetterName(namedAttr.name); 1951 body << " {" << name << "AttrName(),\n" 1952 << tgfmt(tmpl, &fctx.withSelf(name + "()") 1953 .withBuilder("odsBuilder") 1954 .addSubst("_ctxt", "ctx") 1955 .addSubst("_storage", "ctx")) 1956 << "}"; 1957 }, 1958 ",\n"); 1959 body << "});"; 1960 } 1961 } 1962 1963 void OpEmitter::genAttrSetters() { 1964 bool useProperties = op.getDialect().usePropertiesForAttributes(); 1965 1966 // Generate the code to set an attribute. 
1967 auto emitSetAttr = [&](Method *method, StringRef getterName, 1968 StringRef attrName, StringRef attrVar) { 1969 if (useProperties) { 1970 method->body() << formatv(" getProperties().{0} = {1};", attrName, 1971 attrVar); 1972 } else { 1973 method->body() << formatv(" (*this)->setAttr({0}AttrName(), {1});", 1974 getterName, attrVar); 1975 } 1976 }; 1977 1978 // Generate raw named setter type. This is a wrapper class that allows setting 1979 // to the attributes via setters instead of having to use the string interface 1980 // for better compile time verification. 1981 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName, 1982 StringRef attrName, Attribute attr) { 1983 // This method body is trivial, so emit it inline. 1984 auto *method = 1985 opClass.addInlineMethod("void", setterName + "Attr", 1986 MethodParameter(attr.getStorageType(), "attr")); 1987 if (method) 1988 emitSetAttr(method, getterName, attrName, "attr"); 1989 }; 1990 1991 // Generate a setter that accepts the underlying C++ type as opposed to the 1992 // attribute type. 1993 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName, 1994 StringRef attrName, Attribute attr) { 1995 Attribute baseAttr = attr.getBaseAttr(); 1996 if (!canUseUnwrappedRawValue(baseAttr)) 1997 return; 1998 FmtContext fctx; 1999 fctx.withBuilder("::mlir::Builder((*this)->getContext())"); 2000 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr"; 2001 bool isOptional = attr.isOptional(); 2002 2003 auto createMethod = [&](const Twine ¶mType) { 2004 return opClass.addMethod("void", setterName, 2005 MethodParameter(paramType.str(), "attrValue")); 2006 }; 2007 2008 // Build the method using the correct parameter type depending on 2009 // optionality. 
2010 Method *method = nullptr; 2011 if (isUnitAttr) 2012 method = createMethod("bool"); 2013 else if (isOptional) 2014 method = 2015 createMethod("::std::optional<" + baseAttr.getReturnType() + ">"); 2016 else 2017 method = createMethod(attr.getReturnType()); 2018 if (!method) 2019 return; 2020 2021 // If the value isn't optional, just set it directly. 2022 if (!isOptional) { 2023 emitSetAttr(method, getterName, attrName, 2024 constBuildAttrFromParam(attr, fctx, "attrValue")); 2025 return; 2026 } 2027 2028 // Otherwise, we only set if the provided value is valid. If it isn't, we 2029 // remove the attribute. 2030 2031 // TODO: Handle unit attr parameters specially, given that it is treated as 2032 // optional but not in the same way as the others (i.e. it uses bool over 2033 // std::optional<>). 2034 StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue"; 2035 if (!useProperties) { 2036 const char *optionalCodeBody = R"( 2037 if (attrValue) 2038 return (*this)->setAttr({0}AttrName(), {1}); 2039 (*this)->removeAttr({0}AttrName());)"; 2040 method->body() << formatv( 2041 optionalCodeBody, getterName, 2042 constBuildAttrFromParam(baseAttr, fctx, paramStr)); 2043 } else { 2044 const char *optionalCodeBody = R"( 2045 auto &odsProp = getProperties().{0}; 2046 if (attrValue) 2047 odsProp = {1}; 2048 else 2049 odsProp = nullptr;)"; 2050 method->body() << formatv( 2051 optionalCodeBody, attrName, 2052 constBuildAttrFromParam(baseAttr, fctx, paramStr)); 2053 } 2054 }; 2055 2056 for (const NamedAttribute &namedAttr : op.getAttributes()) { 2057 if (namedAttr.attr.isDerivedAttr()) 2058 continue; 2059 std::string setterName = op.getSetterName(namedAttr.name); 2060 std::string getterName = op.getGetterName(namedAttr.name); 2061 emitAttrWithStorageType(setterName, getterName, namedAttr.name, 2062 namedAttr.attr); 2063 emitAttrWithReturnType(setterName, getterName, namedAttr.name, 2064 namedAttr.attr); 2065 } 2066 } 2067 2068 void OpEmitter::genOptionalAttrRemovers() { 2069 
// Generate methods for removing optional attributes, instead of having to 2070 // use the string interface. Enables better compile time verification. 2071 auto emitRemoveAttr = [&](StringRef name, bool useProperties) { 2072 auto upperInitial = name.take_front().upper(); 2073 auto *method = opClass.addInlineMethod("::mlir::Attribute", 2074 op.getRemoverName(name) + "Attr"); 2075 if (!method) 2076 return; 2077 if (useProperties) { 2078 method->body() << formatv(R"( 2079 auto &attr = getProperties().{0}; 2080 attr = {{}; 2081 return attr; 2082 )", 2083 name); 2084 return; 2085 } 2086 method->body() << formatv("return (*this)->removeAttr({0}AttrName());", 2087 op.getGetterName(name)); 2088 }; 2089 2090 for (const NamedAttribute &namedAttr : op.getAttributes()) 2091 if (namedAttr.attr.isOptional()) 2092 emitRemoveAttr(namedAttr.name, 2093 op.getDialect().usePropertiesForAttributes()); 2094 } 2095 2096 // Generates the code to compute the start and end index of an operand or result 2097 // range. 2098 template <typename RangeT> 2099 static void generateValueRangeStartAndEnd( 2100 Class &opClass, bool isGenericAdaptorBase, StringRef methodName, 2101 int numVariadic, int numNonVariadic, StringRef rangeSizeCall, 2102 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) { 2103 2104 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")}; 2105 if (isGenericAdaptorBase) { 2106 parameters.emplace_back("unsigned", "odsOperandsSize"); 2107 // The range size is passed per parameter for generic adaptor bases as 2108 // using the rangeSizeCall would require the operands, which are not 2109 // accessible in the base class. 2110 rangeSizeCall = "odsOperandsSize"; 2111 } 2112 2113 // The method is trivial if the operation does not have any variadic operands. 2114 // In that case, make sure to generate it in-line. 2115 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName, 2116 numVariadic == 0 ? 
Method::Properties::Inline 2117 : Method::Properties::None, 2118 parameters); 2119 if (!method) 2120 return; 2121 auto &body = method->body(); 2122 if (numVariadic == 0) { 2123 body << " return {index, 1};\n"; 2124 } else if (hasAttrSegmentSize) { 2125 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; 2126 } else { 2127 // Because the op can have arbitrarily interleaved variadic and non-variadic 2128 // operands, we need to embed a list in the "sink" getter method for 2129 // calculation at run-time. 2130 SmallVector<StringRef, 4> isVariadic; 2131 isVariadic.reserve(llvm::size(odsValues)); 2132 for (auto &it : odsValues) 2133 isVariadic.push_back(it.isVariableLength() ? "true" : "false"); 2134 std::string isVariadicList = llvm::join(isVariadic, ", "); 2135 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, 2136 numNonVariadic, numVariadic, rangeSizeCall, "operand"); 2137 } 2138 } 2139 2140 static std::string generateTypeForGetter(const NamedTypeConstraint &value) { 2141 return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType()) 2142 .str(); 2143 } 2144 2145 // Generates the named operand getter methods for the given Operator `op` and 2146 // puts them in `opClass`. Uses `rangeType` as the return type of getters that 2147 // return a range of operands (individual operands are `Value ` and each 2148 // element in the range must also be `Value `); use `rangeBeginCall` to get 2149 // an iterator to the beginning of the operand range; use `rangeSizeCall` to 2150 // obtain the number of operands. `getOperandCallPattern` contains the code 2151 // necessary to obtain a single operand whose position will be substituted 2152 // instead of 2153 // "{0}" marker in the pattern. Note that the pattern should work for any kind 2154 // of ops, in particular for one-operand ops that may not have the 2155 // `getOperand(unsigned)` method. 
2156 static void 2157 generateNamedOperandGetters(const Operator &op, Class &opClass, 2158 Class *genericAdaptorBase, StringRef sizeAttrInit, 2159 StringRef rangeType, StringRef rangeElementType, 2160 StringRef rangeBeginCall, StringRef rangeSizeCall, 2161 StringRef getOperandCallPattern) { 2162 const int numOperands = op.getNumOperands(); 2163 const int numVariadicOperands = op.getNumVariableLengthOperands(); 2164 const int numNormalOperands = numOperands - numVariadicOperands; 2165 2166 const auto *sameVariadicSize = 2167 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); 2168 const auto *attrSizedOperands = 2169 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 2170 2171 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { 2172 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " 2173 "specification over their sizes"); 2174 } 2175 2176 if (numVariadicOperands < 2 && attrSizedOperands) { 2177 PrintFatalError(op.getLoc(), "op must have at least two variadic operands " 2178 "to use 'AttrSizedOperandSegments' trait"); 2179 } 2180 2181 if (attrSizedOperands && sameVariadicSize) { 2182 PrintFatalError(op.getLoc(), 2183 "op cannot have both 'AttrSizedOperandSegments' and " 2184 "'SameVariadicOperandSize' traits"); 2185 } 2186 2187 // First emit a few "sink" getter methods upon which we layer all nicer named 2188 // getter methods. 2189 // If generating for an adaptor, the method is put into the non-templated 2190 // generic base class, to not require being defined in the header. 2191 // Since the operand size can't be determined from the base class however, 2192 // it has to be passed as an additional argument. The trampoline below 2193 // generates the function with the same signature as the Op in the generic 2194 // adaptor. 2195 bool isGenericAdaptorBase = genericAdaptorBase != nullptr; 2196 generateValueRangeStartAndEnd( 2197 /*opClass=*/isGenericAdaptorBase ? 
*genericAdaptorBase : opClass, 2198 isGenericAdaptorBase, 2199 /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands, 2200 numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit, 2201 const_cast<Operator &>(op).getOperands()); 2202 if (isGenericAdaptorBase) { 2203 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just 2204 // the index. This just calls the implementation in the base class but 2205 // passes the operand size as parameter. 2206 Method *method = opClass.addInlineMethod( 2207 "std::pair<unsigned, unsigned>", "getODSOperandIndexAndLength", 2208 MethodParameter("unsigned", "index")); 2209 ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op); 2210 MethodBody &body = method->body(); 2211 body.indent() << formatv( 2212 "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall); 2213 } 2214 2215 // The implementation of this method is trivial and it is very load-bearing. 2216 // Generate it inline. 2217 auto *m = opClass.addInlineMethod(rangeType, "getODSOperands", 2218 MethodParameter("unsigned", "index")); 2219 ERROR_IF_PRUNED(m, "getODSOperands", op); 2220 auto &body = m->body(); 2221 body << formatv(valueRangeReturnCode, rangeBeginCall, 2222 "getODSOperandIndexAndLength(index)"); 2223 2224 // Then we emit nicer named getter methods by redirecting to the "sink" getter 2225 // method. 2226 for (int i = 0; i != numOperands; ++i) { 2227 const auto &operand = op.getOperand(i); 2228 if (operand.name.empty()) 2229 continue; 2230 std::string name = op.getGetterName(operand.name); 2231 if (operand.isOptional()) { 2232 m = opClass.addInlineMethod(isGenericAdaptorBase 2233 ? rangeElementType 2234 : generateTypeForGetter(operand), 2235 name); 2236 ERROR_IF_PRUNED(m, name, op); 2237 m->body().indent() << formatv("auto operands = getODSOperands({0});\n" 2238 "return operands.empty() ? 
{1}{{} : ", 2239 i, m->getReturnType()); 2240 if (!isGenericAdaptorBase) 2241 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); 2242 m->body() << "(*operands.begin());"; 2243 } else if (operand.isVariadicOfVariadic()) { 2244 std::string segmentAttr = op.getGetterName( 2245 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 2246 if (genericAdaptorBase) { 2247 m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name); 2248 ERROR_IF_PRUNED(m, name, op); 2249 m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, 2250 segmentAttr, i, rangeType); 2251 continue; 2252 } 2253 2254 m = opClass.addInlineMethod("::mlir::OperandRangeRange", name); 2255 ERROR_IF_PRUNED(m, name, op); 2256 m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr 2257 << "Attr());"; 2258 } else if (operand.isVariadic()) { 2259 m = opClass.addInlineMethod(rangeType, name); 2260 ERROR_IF_PRUNED(m, name, op); 2261 m->body() << " return getODSOperands(" << i << ");"; 2262 } else { 2263 m = opClass.addInlineMethod(isGenericAdaptorBase 2264 ? rangeElementType 2265 : generateTypeForGetter(operand), 2266 name); 2267 ERROR_IF_PRUNED(m, name, op); 2268 m->body().indent() << "return "; 2269 if (!isGenericAdaptorBase) 2270 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); 2271 m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i); 2272 } 2273 } 2274 } 2275 2276 void OpEmitter::genNamedOperandGetters() { 2277 // Build the code snippet used for initializing the operand_segment_size)s 2278 // array. 
2279 std::string attrSizeInitCode; 2280 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 2281 if (op.getDialect().usePropertiesForAttributes()) 2282 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, 2283 "getProperties().operandSegmentSizes"); 2284 2285 else 2286 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, 2287 emitHelper.getAttr(operandSegmentAttrName)); 2288 } 2289 2290 generateNamedOperandGetters( 2291 op, opClass, 2292 /*genericAdaptorBase=*/nullptr, 2293 /*sizeAttrInit=*/attrSizeInitCode, 2294 /*rangeType=*/"::mlir::Operation::operand_range", 2295 /*rangeElementType=*/"::mlir::Value", 2296 /*rangeBeginCall=*/"getOperation()->operand_begin()", 2297 /*rangeSizeCall=*/"getOperation()->getNumOperands()", 2298 /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); 2299 } 2300 2301 void OpEmitter::genNamedOperandSetters() { 2302 auto *attrSizedOperands = 2303 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 2304 for (int i = 0, e = op.getNumOperands(); i != e; ++i) { 2305 const auto &operand = op.getOperand(i); 2306 if (operand.name.empty()) 2307 continue; 2308 std::string name = op.getGetterName(operand.name); 2309 2310 StringRef returnType; 2311 if (operand.isVariadicOfVariadic()) { 2312 returnType = "::mlir::MutableOperandRangeRange"; 2313 } else if (operand.isVariableLength()) { 2314 returnType = "::mlir::MutableOperandRange"; 2315 } else { 2316 returnType = "::mlir::OpOperand &"; 2317 } 2318 bool isVariadicOperand = 2319 operand.isVariadicOfVariadic() || operand.isVariableLength(); 2320 auto *m = opClass.addMethod(returnType, name + "Mutable", 2321 isVariadicOperand ? Method::Properties::None 2322 : Method::Properties::Inline); 2323 ERROR_IF_PRUNED(m, name, op); 2324 auto &body = m->body(); 2325 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"; 2326 2327 if (!isVariadicOperand) { 2328 // In case of a single operand, return a single OpOperand. 
2329 body << " return getOperation()->getOpOperand(range.first);\n"; 2330 continue; 2331 } 2332 2333 body << " auto mutableRange = " 2334 "::mlir::MutableOperandRange(getOperation(), " 2335 "range.first, range.second"; 2336 if (attrSizedOperands) { 2337 if (emitHelper.hasProperties()) 2338 body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, " 2339 "{{getOperandSegmentSizesAttrName(), " 2340 "::mlir::DenseI32ArrayAttr::get(getContext(), " 2341 "getProperties().operandSegmentSizes)})", 2342 i); 2343 else 2344 body << formatv( 2345 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, 2346 emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); 2347 } 2348 body << ");\n"; 2349 2350 // If this operand is a nested variadic, we split the range into a 2351 // MutableOperandRangeRange that provides a range over all of the 2352 // sub-ranges. 2353 if (operand.isVariadicOfVariadic()) { 2354 body << " return " 2355 "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" 2356 << op.getGetterName( 2357 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 2358 << "AttrName()));\n"; 2359 } else { 2360 // Otherwise, we use the full range directly. 2361 body << " return mutableRange;\n"; 2362 } 2363 } 2364 } 2365 2366 void OpEmitter::genNamedResultGetters() { 2367 const int numResults = op.getNumResults(); 2368 const int numVariadicResults = op.getNumVariableLengthResults(); 2369 const int numNormalResults = numResults - numVariadicResults; 2370 2371 // If we have more than one variadic results, we need more complicated logic 2372 // to calculate the value range for each result. 
2373 2374 const auto *sameVariadicSize = 2375 op.getTrait("::mlir::OpTrait::SameVariadicResultSize"); 2376 const auto *attrSizedResults = 2377 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"); 2378 2379 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { 2380 PrintFatalError(op.getLoc(), "op has multiple variadic results but no " 2381 "specification over their sizes"); 2382 } 2383 2384 if (numVariadicResults < 2 && attrSizedResults) { 2385 PrintFatalError(op.getLoc(), "op must have at least two variadic results " 2386 "to use 'AttrSizedResultSegments' trait"); 2387 } 2388 2389 if (attrSizedResults && sameVariadicSize) { 2390 PrintFatalError(op.getLoc(), 2391 "op cannot have both 'AttrSizedResultSegments' and " 2392 "'SameVariadicResultSize' traits"); 2393 } 2394 2395 // Build the initializer string for the result segment size attribute. 2396 std::string attrSizeInitCode; 2397 if (attrSizedResults) { 2398 if (op.getDialect().usePropertiesForAttributes()) 2399 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, 2400 "getProperties().resultSegmentSizes"); 2401 2402 else 2403 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, 2404 emitHelper.getAttr(resultSegmentAttrName)); 2405 } 2406 2407 generateValueRangeStartAndEnd( 2408 opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength", 2409 numVariadicResults, numNormalResults, "getOperation()->getNumResults()", 2410 attrSizedResults, attrSizeInitCode, op.getResults()); 2411 2412 // The implementation of this method is trivial and it is very load-bearing. 2413 // Generate it inline. 
2414 auto *m = opClass.addInlineMethod("::mlir::Operation::result_range", 2415 "getODSResults", 2416 MethodParameter("unsigned", "index")); 2417 ERROR_IF_PRUNED(m, "getODSResults", op); 2418 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", 2419 "getODSResultIndexAndLength(index)"); 2420 2421 for (int i = 0; i != numResults; ++i) { 2422 const auto &result = op.getResult(i); 2423 if (result.name.empty()) 2424 continue; 2425 std::string name = op.getGetterName(result.name); 2426 if (result.isOptional()) { 2427 m = opClass.addInlineMethod(generateTypeForGetter(result), name); 2428 ERROR_IF_PRUNED(m, name, op); 2429 m->body() << " auto results = getODSResults(" << i << ");\n" 2430 << llvm::formatv(" return results.empty()" 2431 " ? {0}()" 2432 " : ::llvm::cast<{0}>(*results.begin());", 2433 m->getReturnType()); 2434 } else if (result.isVariadic()) { 2435 m = opClass.addInlineMethod("::mlir::Operation::result_range", name); 2436 ERROR_IF_PRUNED(m, name, op); 2437 m->body() << " return getODSResults(" << i << ");"; 2438 } else { 2439 m = opClass.addInlineMethod(generateTypeForGetter(result), name); 2440 ERROR_IF_PRUNED(m, name, op); 2441 m->body() << llvm::formatv( 2442 " return ::llvm::cast<{0}>(*getODSResults({1}).begin());", 2443 m->getReturnType(), i); 2444 } 2445 } 2446 } 2447 2448 void OpEmitter::genNamedRegionGetters() { 2449 unsigned numRegions = op.getNumRegions(); 2450 for (unsigned i = 0; i < numRegions; ++i) { 2451 const auto ®ion = op.getRegion(i); 2452 if (region.name.empty()) 2453 continue; 2454 std::string name = op.getGetterName(region.name); 2455 2456 // Generate the accessors for a variadic region. 
2457 if (region.isVariadic()) { 2458 auto *m = opClass.addInlineMethod( 2459 "::mlir::MutableArrayRef<::mlir::Region>", name); 2460 ERROR_IF_PRUNED(m, name, op); 2461 m->body() << formatv(" return (*this)->getRegions().drop_front({0});", 2462 i); 2463 continue; 2464 } 2465 2466 auto *m = opClass.addInlineMethod("::mlir::Region &", name); 2467 ERROR_IF_PRUNED(m, name, op); 2468 m->body() << formatv(" return (*this)->getRegion({0});", i); 2469 } 2470 } 2471 2472 void OpEmitter::genNamedSuccessorGetters() { 2473 unsigned numSuccessors = op.getNumSuccessors(); 2474 for (unsigned i = 0; i < numSuccessors; ++i) { 2475 const NamedSuccessor &successor = op.getSuccessor(i); 2476 if (successor.name.empty()) 2477 continue; 2478 std::string name = op.getGetterName(successor.name); 2479 // Generate the accessors for a variadic successor list. 2480 if (successor.isVariadic()) { 2481 auto *m = opClass.addInlineMethod("::mlir::SuccessorRange", name); 2482 ERROR_IF_PRUNED(m, name, op); 2483 m->body() << formatv( 2484 " return {std::next((*this)->successor_begin(), {0}), " 2485 "(*this)->successor_end()};", 2486 i); 2487 continue; 2488 } 2489 2490 auto *m = opClass.addInlineMethod("::mlir::Block *", name); 2491 ERROR_IF_PRUNED(m, name, op); 2492 m->body() << formatv(" return (*this)->getSuccessor({0});", i); 2493 } 2494 } 2495 2496 static bool canGenerateUnwrappedBuilder(const Operator &op) { 2497 // If this op does not have native attributes at all, return directly to avoid 2498 // redefining builders. 2499 if (op.getNumNativeAttributes() == 0) 2500 return false; 2501 2502 bool canGenerate = false; 2503 // We are generating builders that take raw values for attributes. We need to 2504 // make sure the native attributes have a meaningful "unwrapped" value type 2505 // different from the wrapped mlir::Attribute type to avoid redefining 2506 // builders. This checks for the op has at least one such native attribute. 
2507 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { 2508 const NamedAttribute &namedAttr = op.getAttribute(i); 2509 if (canUseUnwrappedRawValue(namedAttr.attr)) { 2510 canGenerate = true; 2511 break; 2512 } 2513 } 2514 return canGenerate; 2515 } 2516 2517 static bool canInferType(const Operator &op) { 2518 return op.getTrait("::mlir::InferTypeOpInterface::Trait"); 2519 } 2520 2521 void OpEmitter::genSeparateArgParamBuilder() { 2522 SmallVector<AttrParamKind, 2> attrBuilderType; 2523 attrBuilderType.push_back(AttrParamKind::WrappedAttr); 2524 if (canGenerateUnwrappedBuilder(op)) 2525 attrBuilderType.push_back(AttrParamKind::UnwrappedValue); 2526 2527 // Emit with separate builders with or without unwrapped attributes and/or 2528 // inferring result type. 2529 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, 2530 bool inferType) { 2531 SmallVector<MethodParameter> paramList; 2532 SmallVector<std::string, 4> resultNames; 2533 llvm::StringSet<> inferredAttributes; 2534 buildParamList(paramList, inferredAttributes, resultNames, paramKind, 2535 attrType); 2536 2537 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 2538 // If the builder is redundant, skip generating the method. 2539 if (!m) 2540 return; 2541 auto &body = m->body(); 2542 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, 2543 /*isRawValueAttr=*/attrType == 2544 AttrParamKind::UnwrappedValue); 2545 2546 // Push all result types to the operation state 2547 2548 if (inferType) { 2549 // Generate builder that infers type too. 2550 // TODO: Subsume this with general checking if type can be 2551 // inferred automatically. 
2552 body << formatv(R"( 2553 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; 2554 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 2555 {1}.location, {1}.operands, 2556 {1}.attributes.getDictionary({1}.getContext()), 2557 {1}.getRawProperties(), 2558 {1}.regions, inferredReturnTypes))) 2559 {1}.addTypes(inferredReturnTypes); 2560 else 2561 ::mlir::detail::reportFatalInferReturnTypesError({1}); 2562 )", 2563 opClass.getClassName(), builderOpState); 2564 return; 2565 } 2566 2567 switch (paramKind) { 2568 case TypeParamKind::None: 2569 return; 2570 case TypeParamKind::Separate: 2571 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 2572 if (op.getResult(i).isOptional()) 2573 body << " if (" << resultNames[i] << ")\n "; 2574 body << " " << builderOpState << ".addTypes(" << resultNames[i] 2575 << ");\n"; 2576 } 2577 2578 // Automatically create the 'resultSegmentSizes' attribute using 2579 // the length of the type ranges. 2580 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 2581 if (op.getDialect().usePropertiesForAttributes()) { 2582 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({"; 2583 } else { 2584 std::string getterName = op.getGetterName(resultSegmentAttrName); 2585 body << " " << builderOpState << ".addAttribute(" << getterName 2586 << "AttrName(" << builderOpState << ".name), " 2587 << "odsBuilder.getDenseI32ArrayAttr({"; 2588 } 2589 interleaveComma( 2590 llvm::seq<int>(0, op.getNumResults()), body, [&](int i) { 2591 const NamedTypeConstraint &result = op.getResult(i); 2592 if (!result.isVariableLength()) { 2593 body << "1"; 2594 } else if (result.isOptional()) { 2595 body << "(" << resultNames[i] << " ? 1 : 0)"; 2596 } else { 2597 // VariadicOfVariadic of results are currently unsupported in 2598 // MLIR, hence it can only be a simple variadic. 2599 // TODO: Add implementation for VariadicOfVariadic results here 2600 // once supported. 
2601 assert(result.isVariadic()); 2602 body << "static_cast<int32_t>(" << resultNames[i] << ".size())"; 2603 } 2604 }); 2605 if (op.getDialect().usePropertiesForAttributes()) { 2606 body << "}), " << builderOpState 2607 << ".getOrAddProperties<Properties>()." 2608 "resultSegmentSizes.begin());\n"; 2609 } else { 2610 body << "}));\n"; 2611 } 2612 } 2613 2614 return; 2615 case TypeParamKind::Collective: { 2616 int numResults = op.getNumResults(); 2617 int numVariadicResults = op.getNumVariableLengthResults(); 2618 int numNonVariadicResults = numResults - numVariadicResults; 2619 bool hasVariadicResult = numVariadicResults != 0; 2620 2621 // Avoid emitting "resultTypes.size() >= 0u" which is always true. 2622 if (!hasVariadicResult || numNonVariadicResults != 0) 2623 body << " " 2624 << "assert(resultTypes.size() " 2625 << (hasVariadicResult ? ">=" : "==") << " " 2626 << numNonVariadicResults 2627 << "u && \"mismatched number of results\");\n"; 2628 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 2629 } 2630 return; 2631 } 2632 llvm_unreachable("unhandled TypeParamKind"); 2633 }; 2634 2635 // Some of the build methods generated here may be ambiguous, but TableGen's 2636 // ambiguous function detection will elide those ones. 
2637 for (auto attrType : attrBuilderType) { 2638 emit(attrType, TypeParamKind::Separate, /*inferType=*/false); 2639 if (canInferType(op)) 2640 emit(attrType, TypeParamKind::None, /*inferType=*/true); 2641 emit(attrType, TypeParamKind::Collective, /*inferType=*/false); 2642 } 2643 } 2644 2645 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { 2646 int numResults = op.getNumResults(); 2647 2648 // Signature 2649 SmallVector<MethodParameter> paramList; 2650 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 2651 paramList.emplace_back("::mlir::OperationState &", builderOpState); 2652 paramList.emplace_back("::mlir::ValueRange", "operands"); 2653 // Provide default value for `attributes` when its the last parameter 2654 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; 2655 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 2656 "attributes", attributesDefaultValue); 2657 if (op.getNumVariadicRegions()) 2658 paramList.emplace_back("unsigned", "numRegions"); 2659 2660 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 2661 // If the builder is redundant, skip generating the method 2662 if (!m) 2663 return; 2664 auto &body = m->body(); 2665 2666 // Operands 2667 body << " " << builderOpState << ".addOperands(operands);\n"; 2668 2669 // Attributes 2670 body << " " << builderOpState << ".addAttributes(attributes);\n"; 2671 2672 // Create the correct number of regions 2673 if (int numRegions = op.getNumRegions()) { 2674 body << llvm::formatv( 2675 " for (unsigned i = 0; i != {0}; ++i)\n", 2676 (op.getNumVariadicRegions() ? 
"numRegions" : Twine(numRegions))); 2677 body << " (void)" << builderOpState << ".addRegion();\n"; 2678 } 2679 2680 // Result types 2681 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()"); 2682 body << " " << builderOpState << ".addTypes({" 2683 << llvm::join(resultTypes, ", ") << "});\n\n"; 2684 } 2685 2686 void OpEmitter::genPopulateDefaultAttributes() { 2687 // All done if no attributes, except optional ones, have default values. 2688 if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { 2689 return !named.attr.hasDefaultValue() || named.attr.isOptional(); 2690 })) 2691 return; 2692 2693 if (emitHelper.hasProperties()) { 2694 SmallVector<MethodParameter> paramList; 2695 paramList.emplace_back("::mlir::OperationName", "opName"); 2696 paramList.emplace_back("Properties &", "properties"); 2697 auto *m = 2698 opClass.addStaticMethod("void", "populateDefaultProperties", paramList); 2699 ERROR_IF_PRUNED(m, "populateDefaultProperties", op); 2700 auto &body = m->body(); 2701 body.indent(); 2702 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n"; 2703 for (const NamedAttribute &namedAttr : op.getAttributes()) { 2704 auto &attr = namedAttr.attr; 2705 if (!attr.hasDefaultValue() || attr.isOptional()) 2706 continue; 2707 StringRef name = namedAttr.name; 2708 FmtContext fctx; 2709 fctx.withBuilder(odsBuilder); 2710 body << "if (!properties." << name << ")\n" 2711 << " properties." 
<< name << " = " 2712 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, 2713 tgfmt(attr.getDefaultValue(), &fctx))) 2714 << ";\n"; 2715 } 2716 return; 2717 } 2718 2719 SmallVector<MethodParameter> paramList; 2720 paramList.emplace_back("const ::mlir::OperationName &", "opName"); 2721 paramList.emplace_back("::mlir::NamedAttrList &", "attributes"); 2722 auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList); 2723 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op); 2724 auto &body = m->body(); 2725 body.indent(); 2726 2727 // Set default attributes that are unset. 2728 body << "auto attrNames = opName.getAttributeNames();\n"; 2729 body << "::mlir::Builder " << odsBuilder 2730 << "(attrNames.front().getContext());\n"; 2731 StringMap<int> attrIndex; 2732 for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) { 2733 attrIndex[it.value().first] = it.index(); 2734 } 2735 for (const NamedAttribute &namedAttr : op.getAttributes()) { 2736 auto &attr = namedAttr.attr; 2737 if (!attr.hasDefaultValue() || attr.isOptional()) 2738 continue; 2739 auto index = attrIndex[namedAttr.name]; 2740 body << "if (!attributes.get(attrNames[" << index << "])) {\n"; 2741 FmtContext fctx; 2742 fctx.withBuilder(odsBuilder); 2743 2744 std::string defaultValue = 2745 std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, 2746 tgfmt(attr.getDefaultValue(), &fctx))); 2747 body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index, 2748 defaultValue); 2749 body.unindent() << "}\n"; 2750 } 2751 } 2752 2753 void OpEmitter::genInferredTypeCollectiveParamBuilder() { 2754 SmallVector<MethodParameter> paramList; 2755 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 2756 paramList.emplace_back("::mlir::OperationState &", builderOpState); 2757 paramList.emplace_back("::mlir::ValueRange", "operands"); 2758 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; 2759 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 2760 "attributes", attributesDefaultValue); 2761 if (op.getNumVariadicRegions()) 2762 paramList.emplace_back("unsigned", "numRegions"); 2763 2764 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 2765 // If the builder is redundant, skip generating the method 2766 if (!m) 2767 return; 2768 auto &body = m->body(); 2769 2770 int numResults = op.getNumResults(); 2771 int numVariadicResults = op.getNumVariableLengthResults(); 2772 int numNonVariadicResults = numResults - numVariadicResults; 2773 2774 int numOperands = op.getNumOperands(); 2775 int numVariadicOperands = op.getNumVariableLengthOperands(); 2776 int numNonVariadicOperands = numOperands - numVariadicOperands; 2777 2778 // Operands 2779 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 2780 body << " assert(operands.size()" 2781 << (numVariadicOperands != 0 ? " >= " : " == ") 2782 << numNonVariadicOperands 2783 << "u && \"mismatched number of parameters\");\n"; 2784 body << " " << builderOpState << ".addOperands(operands);\n"; 2785 body << " " << builderOpState << ".addAttributes(attributes);\n"; 2786 2787 // Create the correct number of regions 2788 if (int numRegions = op.getNumRegions()) { 2789 body << llvm::formatv( 2790 " for (unsigned i = 0; i != {0}; ++i)\n", 2791 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); 2792 body << " (void)" << builderOpState << ".addRegion();\n"; 2793 } 2794 2795 // Result types 2796 if (emitHelper.hasProperties()) { 2797 // Initialize the properties from Attributes before invoking the infer 2798 // function. 
2799 body << formatv(R"( 2800 if (!attributes.empty()) { 2801 ::mlir::OpaqueProperties properties = 2802 &{1}.getOrAddProperties<{0}::Properties>(); 2803 std::optional<::mlir::RegisteredOperationName> info = 2804 {1}.name.getRegisteredInfo(); 2805 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties, 2806 {1}.attributes.getDictionary({1}.getContext()), nullptr))) 2807 ::llvm::report_fatal_error("Property conversion failed."); 2808 })", 2809 opClass.getClassName(), builderOpState); 2810 } 2811 body << formatv(R"( 2812 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; 2813 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 2814 {1}.location, operands, 2815 {1}.attributes.getDictionary({1}.getContext()), 2816 {1}.getRawProperties(), 2817 {1}.regions, inferredReturnTypes))) {{)", 2818 opClass.getClassName(), builderOpState); 2819 if (numVariadicResults == 0 || numNonVariadicResults != 0) 2820 body << "\n assert(inferredReturnTypes.size()" 2821 << (numVariadicResults != 0 ? 
" >= " : " == ") << numNonVariadicResults 2822 << "u && \"mismatched number of return types\");"; 2823 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);"; 2824 2825 body << R"( 2826 } else { 2827 ::llvm::report_fatal_error("Failed to infer result type(s)."); 2828 })"; 2829 } 2830 2831 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { 2832 auto emit = [&](AttrParamKind attrType) { 2833 SmallVector<MethodParameter> paramList; 2834 SmallVector<std::string, 4> resultNames; 2835 llvm::StringSet<> inferredAttributes; 2836 buildParamList(paramList, inferredAttributes, resultNames, 2837 TypeParamKind::None, attrType); 2838 2839 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 2840 // If the builder is redundant, skip generating the method 2841 if (!m) 2842 return; 2843 auto &body = m->body(); 2844 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, 2845 /*isRawValueAttr=*/attrType == 2846 AttrParamKind::UnwrappedValue); 2847 2848 auto numResults = op.getNumResults(); 2849 if (numResults == 0) 2850 return; 2851 2852 // Push all result types to the operation state 2853 const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; 2854 std::string resultType = 2855 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); 2856 body << " " << builderOpState << ".addTypes({" << resultType; 2857 for (int i = 1; i != numResults; ++i) 2858 body << ", " << resultType; 2859 body << "});\n\n"; 2860 }; 2861 2862 emit(AttrParamKind::WrappedAttr); 2863 // Generate additional builder(s) if any attributes can be "unwrapped" 2864 if (canGenerateUnwrappedBuilder(op)) 2865 emit(AttrParamKind::UnwrappedValue); 2866 } 2867 2868 void OpEmitter::genUseAttrAsResultTypeBuilder() { 2869 SmallVector<MethodParameter> paramList; 2870 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 2871 paramList.emplace_back("::mlir::OperationState &", builderOpState); 2872 paramList.emplace_back("::mlir::ValueRange", "operands"); 2873 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 2874 "attributes", "{}"); 2875 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 2876 // If the builder is redundant, skip generating the method 2877 if (!m) 2878 return; 2879 2880 auto &body = m->body(); 2881 2882 // Push all result types to the operation state 2883 std::string resultType; 2884 const auto &namedAttr = op.getAttribute(0); 2885 2886 body << " auto attrName = " << op.getGetterName(namedAttr.name) 2887 << "AttrName(" << builderOpState 2888 << ".name);\n" 2889 " for (auto attr : attributes) {\n" 2890 " if (attr.getName() != attrName) continue;\n"; 2891 if (namedAttr.attr.isTypeAttr()) { 2892 resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()"; 2893 } else { 2894 resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()"; 2895 } 2896 2897 // Operands 2898 body << " " << builderOpState << ".addOperands(operands);\n"; 2899 2900 // Attributes 2901 body << " " << builderOpState << ".addAttributes(attributes);\n"; 2902 2903 // Result types 2904 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); 2905 body << " " 
<< builderOpState << ".addTypes({" 2906 << llvm::join(resultTypes, ", ") << "});\n"; 2907 body << " }\n"; 2908 } 2909 2910 /// Returns a signature of the builder. Updates the context `fctx` to enable 2911 /// replacement of $_builder and $_state in the body. 2912 static SmallVector<MethodParameter> 2913 getBuilderSignature(const Builder &builder) { 2914 ArrayRef<Builder::Parameter> params(builder.getParameters()); 2915 2916 // Inject builder and state arguments. 2917 SmallVector<MethodParameter> arguments; 2918 arguments.reserve(params.size() + 2); 2919 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); 2920 arguments.emplace_back("::mlir::OperationState &", builderOpState); 2921 2922 for (unsigned i = 0, e = params.size(); i < e; ++i) { 2923 // If no name is provided, generate one. 2924 std::optional<StringRef> paramName = params[i].getName(); 2925 std::string name = 2926 paramName ? paramName->str() : "odsArg" + std::to_string(i); 2927 2928 StringRef defaultValue; 2929 if (std::optional<StringRef> defaultParamValue = 2930 params[i].getDefaultValue()) 2931 defaultValue = *defaultParamValue; 2932 2933 arguments.emplace_back(params[i].getCppType(), std::move(name), 2934 defaultValue); 2935 } 2936 2937 return arguments; 2938 } 2939 2940 void OpEmitter::genBuilder() { 2941 // Handle custom builders if provided. 2942 for (const Builder &builder : op.getBuilders()) { 2943 SmallVector<MethodParameter> arguments = getBuilderSignature(builder); 2944 2945 std::optional<StringRef> body = builder.getBody(); 2946 auto properties = body ? 
Method::Static : Method::StaticDeclaration; 2947 auto *method = 2948 opClass.addMethod("void", "build", properties, std::move(arguments)); 2949 if (body) 2950 ERROR_IF_PRUNED(method, "build", op); 2951 2952 if (method) 2953 method->setDeprecated(builder.getDeprecatedMessage()); 2954 2955 FmtContext fctx; 2956 fctx.withBuilder(odsBuilder); 2957 fctx.addSubst("_state", builderOpState); 2958 if (body) 2959 method->body() << tgfmt(*body, &fctx); 2960 } 2961 2962 // Generate default builders that requires all result type, operands, and 2963 // attributes as parameters. 2964 if (op.skipDefaultBuilders()) 2965 return; 2966 2967 // We generate three classes of builders here: 2968 // 1. one having a stand-alone parameter for each operand / attribute, and 2969 genSeparateArgParamBuilder(); 2970 // 2. one having an aggregated parameter for all result types / operands / 2971 // attributes, and 2972 genCollectiveParamBuilder(); 2973 // 3. one having a stand-alone parameter for each operand and attribute, 2974 // use the first operand or attribute's type as all result types 2975 // to facilitate different call patterns. 
2976 if (op.getNumVariableLengthResults() == 0) { 2977 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 2978 genUseOperandAsResultTypeSeparateParamBuilder(); 2979 genUseOperandAsResultTypeCollectiveParamBuilder(); 2980 } 2981 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) 2982 genUseAttrAsResultTypeBuilder(); 2983 } 2984 } 2985 2986 void OpEmitter::genCollectiveParamBuilder() { 2987 int numResults = op.getNumResults(); 2988 int numVariadicResults = op.getNumVariableLengthResults(); 2989 int numNonVariadicResults = numResults - numVariadicResults; 2990 2991 int numOperands = op.getNumOperands(); 2992 int numVariadicOperands = op.getNumVariableLengthOperands(); 2993 int numNonVariadicOperands = numOperands - numVariadicOperands; 2994 2995 SmallVector<MethodParameter> paramList; 2996 paramList.emplace_back("::mlir::OpBuilder &", ""); 2997 paramList.emplace_back("::mlir::OperationState &", builderOpState); 2998 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 2999 paramList.emplace_back("::mlir::ValueRange", "operands"); 3000 // Provide default value for `attributes` when its the last parameter 3001 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; 3002 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 3003 "attributes", attributesDefaultValue); 3004 if (op.getNumVariadicRegions()) 3005 paramList.emplace_back("unsigned", "numRegions"); 3006 3007 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 3008 // If the builder is redundant, skip generating the method 3009 if (!m) 3010 return; 3011 auto &body = m->body(); 3012 3013 // Operands 3014 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 3015 body << " assert(operands.size()" 3016 << (numVariadicOperands != 0 ? 
" >= " : " == ") 3017 << numNonVariadicOperands 3018 << "u && \"mismatched number of parameters\");\n"; 3019 body << " " << builderOpState << ".addOperands(operands);\n"; 3020 3021 // Attributes 3022 body << " " << builderOpState << ".addAttributes(attributes);\n"; 3023 3024 // Create the correct number of regions 3025 if (int numRegions = op.getNumRegions()) { 3026 body << llvm::formatv( 3027 " for (unsigned i = 0; i != {0}; ++i)\n", 3028 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); 3029 body << " (void)" << builderOpState << ".addRegion();\n"; 3030 } 3031 3032 // Result types 3033 if (numVariadicResults == 0 || numNonVariadicResults != 0) 3034 body << " assert(resultTypes.size()" 3035 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults 3036 << "u && \"mismatched number of return types\");\n"; 3037 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 3038 3039 if (emitHelper.hasProperties()) { 3040 // Initialize the properties from Attributes before invoking the infer 3041 // function. 3042 body << formatv(R"( 3043 if (!attributes.empty()) { 3044 ::mlir::OpaqueProperties properties = 3045 &{1}.getOrAddProperties<{0}::Properties>(); 3046 std::optional<::mlir::RegisteredOperationName> info = 3047 {1}.name.getRegisteredInfo(); 3048 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties, 3049 {1}.attributes.getDictionary({1}.getContext()), nullptr))) 3050 ::llvm::report_fatal_error("Property conversion failed."); 3051 })", 3052 opClass.getClassName(), builderOpState); 3053 } 3054 3055 // Generate builder that infers type too. 3056 // TODO: Expand to handle successors. 
3057 if (canInferType(op) && op.getNumSuccessors() == 0) 3058 genInferredTypeCollectiveParamBuilder(); 3059 } 3060 3061 void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList, 3062 llvm::StringSet<> &inferredAttributes, 3063 SmallVectorImpl<std::string> &resultTypeNames, 3064 TypeParamKind typeParamKind, 3065 AttrParamKind attrParamKind) { 3066 resultTypeNames.clear(); 3067 auto numResults = op.getNumResults(); 3068 resultTypeNames.reserve(numResults); 3069 3070 paramList.emplace_back("::mlir::OpBuilder &", odsBuilder); 3071 paramList.emplace_back("::mlir::OperationState &", builderOpState); 3072 3073 switch (typeParamKind) { 3074 case TypeParamKind::None: 3075 break; 3076 case TypeParamKind::Separate: { 3077 // Add parameters for all return types 3078 for (int i = 0; i < numResults; ++i) { 3079 const auto &result = op.getResult(i); 3080 std::string resultName = std::string(result.name); 3081 if (resultName.empty()) 3082 resultName = std::string(formatv("resultType{0}", i)); 3083 3084 StringRef type = 3085 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; 3086 3087 paramList.emplace_back(type, resultName, result.isOptional()); 3088 resultTypeNames.emplace_back(std::move(resultName)); 3089 } 3090 } break; 3091 case TypeParamKind::Collective: { 3092 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 3093 resultTypeNames.push_back("resultTypes"); 3094 } break; 3095 } 3096 3097 // Add parameters for all arguments (operands and attributes). 3098 // Track "attr-like" (property and attribute) optional values separate from 3099 // attributes themselves so that the disambiguation code can look at the first 3100 // attribute specifically when determining where to trim the optional-value 3101 // list to avoid ambiguity while preserving the ability of all-property ops to 3102 // use default parameters. 
3103 int defaultValuedAttrLikeStartIndex = op.getNumArgs(); 3104 int defaultValuedAttrStartIndex = op.getNumArgs(); 3105 // Successors and variadic regions go at the end of the parameter list, so no 3106 // default arguments are possible. 3107 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions(); 3108 if (!hasTrailingParams) { 3109 // Calculate the start index from which we can attach default values in the 3110 // builder declaration. 3111 for (int i = op.getNumArgs() - 1; i >= 0; --i) { 3112 auto *namedAttr = 3113 llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i)); 3114 auto *namedProperty = 3115 llvm::dyn_cast_if_present<tblgen::NamedProperty *>(op.getArg(i)); 3116 if (namedProperty) { 3117 Property prop = namedProperty->prop; 3118 if (!prop.hasDefaultValue()) 3119 break; 3120 defaultValuedAttrLikeStartIndex = i; 3121 continue; 3122 } 3123 if (!namedAttr) 3124 break; 3125 3126 Attribute attr = namedAttr->attr; 3127 // TODO: Currently we can't differentiate between optional meaning do not 3128 // verify/not always error if missing or optional meaning need not be 3129 // specified in builder. Expand isOptional once we can differentiate. 3130 if (!attr.hasDefaultValue() && !attr.isDerivedAttr()) 3131 break; 3132 3133 // Creating an APInt requires us to provide bitwidth, value, and 3134 // signedness, which is complicated compared to others. Similarly 3135 // for APFloat. 3136 // TODO: Adjust the 'returnType' field of such attributes 3137 // to support them. 3138 StringRef retType = namedAttr->attr.getReturnType(); 3139 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") 3140 break; 3141 3142 defaultValuedAttrLikeStartIndex = i; 3143 defaultValuedAttrStartIndex = i; 3144 } 3145 } 3146 3147 // Check if parameters besides default valued one are enough to distinguish 3148 // between builders with wrapped and unwrapped arguments. 
3149 bool hasBuilderAmbiguity = true; 3150 for (const auto &arg : op.getArgs()) { 3151 auto *namedAttr = dyn_cast<NamedAttribute *>(arg); 3152 if (!namedAttr) 3153 continue; 3154 Attribute attr = namedAttr->attr; 3155 if (attr.hasDefaultValue() || attr.isDerivedAttr()) 3156 continue; 3157 3158 if (attrParamKind != AttrParamKind::WrappedAttr || 3159 !canUseUnwrappedRawValue(attr)) 3160 continue; 3161 3162 hasBuilderAmbiguity = false; 3163 break; 3164 } 3165 3166 // Avoid generating build methods that are ambiguous due to default values by 3167 // requiring at least one attribute. 3168 if (defaultValuedAttrStartIndex < op.getNumArgs()) { 3169 // TODO: This should have been possible as a cast<NamedAttribute> but 3170 // required template instantiations is not yet defined for the tblgen helper 3171 // classes. 3172 auto *namedAttr = 3173 cast<NamedAttribute *>(op.getArg(defaultValuedAttrStartIndex)); 3174 Attribute attr = namedAttr->attr; 3175 if ((attrParamKind == AttrParamKind::WrappedAttr && 3176 canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity) || 3177 (attrParamKind == AttrParamKind::UnwrappedValue && 3178 !canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity)) { 3179 ++defaultValuedAttrStartIndex; 3180 defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex; 3181 } 3182 } 3183 3184 /// Collect any inferred attributes. 
3185 for (const NamedTypeConstraint &operand : op.getOperands()) { 3186 if (operand.isVariadicOfVariadic()) { 3187 inferredAttributes.insert( 3188 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 3189 } 3190 } 3191 3192 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { 3193 Argument arg = op.getArg(i); 3194 if (const auto *operand = 3195 llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) { 3196 StringRef type; 3197 if (operand->isVariadicOfVariadic()) 3198 type = "::llvm::ArrayRef<::mlir::ValueRange>"; 3199 else if (operand->isVariadic()) 3200 type = "::mlir::ValueRange"; 3201 else 3202 type = "::mlir::Value"; 3203 3204 paramList.emplace_back(type, getArgumentName(op, numOperands++), 3205 operand->isOptional()); 3206 continue; 3207 } 3208 if (auto *propArg = llvm::dyn_cast_if_present<NamedProperty *>(arg)) { 3209 const Property &prop = propArg->prop; 3210 StringRef type = prop.getInterfaceType(); 3211 std::string defaultValue; 3212 if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) { 3213 defaultValue = prop.getDefaultValue(); 3214 } 3215 bool isOptional = prop.hasDefaultValue(); 3216 paramList.emplace_back(type, propArg->name, StringRef(defaultValue), 3217 isOptional); 3218 continue; 3219 } 3220 const NamedAttribute &namedAttr = *cast<NamedAttribute *>(arg); 3221 const Attribute &attr = namedAttr.attr; 3222 3223 // Inferred attributes don't need to be added to the param list. 3224 if (inferredAttributes.contains(namedAttr.name)) 3225 continue; 3226 3227 StringRef type; 3228 switch (attrParamKind) { 3229 case AttrParamKind::WrappedAttr: 3230 type = attr.getStorageType(); 3231 break; 3232 case AttrParamKind::UnwrappedValue: 3233 if (canUseUnwrappedRawValue(attr)) 3234 type = attr.getReturnType(); 3235 else 3236 type = attr.getStorageType(); 3237 break; 3238 } 3239 3240 // Attach default value if requested and possible. 
3241 std::string defaultValue; 3242 if (i >= defaultValuedAttrStartIndex) { 3243 if (attrParamKind == AttrParamKind::UnwrappedValue && 3244 canUseUnwrappedRawValue(attr)) 3245 defaultValue += attr.getDefaultValue(); 3246 else 3247 defaultValue += "nullptr"; 3248 } 3249 paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue), 3250 attr.isOptional()); 3251 } 3252 3253 /// Insert parameters for each successor. 3254 for (const NamedSuccessor &succ : op.getSuccessors()) { 3255 StringRef type = 3256 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *"; 3257 paramList.emplace_back(type, succ.name); 3258 } 3259 3260 /// Insert parameters for variadic regions. 3261 for (const NamedRegion ®ion : op.getRegions()) 3262 if (region.isVariadic()) 3263 paramList.emplace_back("unsigned", 3264 llvm::formatv("{0}Count", region.name).str()); 3265 } 3266 3267 void OpEmitter::genCodeForAddingArgAndRegionForBuilder( 3268 MethodBody &body, llvm::StringSet<> &inferredAttributes, 3269 bool isRawValueAttr) { 3270 // Push all operands to the result. 3271 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 3272 std::string argName = getArgumentName(op, i); 3273 const NamedTypeConstraint &operand = op.getOperand(i); 3274 if (operand.constraint.isVariadicOfVariadic()) { 3275 body << " for (::mlir::ValueRange range : " << argName << ")\n " 3276 << builderOpState << ".addOperands(range);\n"; 3277 3278 // Add the segment attribute. 3279 body << " {\n" 3280 << " ::llvm::SmallVector<int32_t> rangeSegments;\n" 3281 << " for (::mlir::ValueRange range : " << argName << ")\n" 3282 << " rangeSegments.push_back(range.size());\n" 3283 << " auto rangeAttr = " << odsBuilder 3284 << ".getDenseI32ArrayAttr(rangeSegments);\n"; 3285 if (op.getDialect().usePropertiesForAttributes()) { 3286 body << " " << builderOpState << ".getOrAddProperties<Properties>()." 
3287 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() 3288 << " = rangeAttr;"; 3289 } else { 3290 body << " " << builderOpState << ".addAttribute(" 3291 << op.getGetterName( 3292 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 3293 << "AttrName(" << builderOpState << ".name), rangeAttr);"; 3294 } 3295 body << " }\n"; 3296 continue; 3297 } 3298 3299 if (operand.isOptional()) 3300 body << " if (" << argName << ")\n "; 3301 body << " " << builderOpState << ".addOperands(" << argName << ");\n"; 3302 } 3303 3304 // If the operation has the operand segment size attribute, add it here. 3305 auto emitSegment = [&]() { 3306 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) { 3307 const NamedTypeConstraint &operand = op.getOperand(i); 3308 if (!operand.isVariableLength()) { 3309 body << "1"; 3310 return; 3311 } 3312 3313 std::string operandName = getArgumentName(op, i); 3314 if (operand.isOptional()) { 3315 body << "(" << operandName << " ? 1 : 0)"; 3316 } else if (operand.isVariadicOfVariadic()) { 3317 body << llvm::formatv( 3318 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, " 3319 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " 3320 "static_cast<int32_t>(range.size()); }))", 3321 operandName); 3322 } else { 3323 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())"; 3324 } 3325 }); 3326 }; 3327 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 3328 std::string sizes = op.getGetterName(operandSegmentAttrName); 3329 if (op.getDialect().usePropertiesForAttributes()) { 3330 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({"; 3331 emitSegment(); 3332 body << "}), " << builderOpState 3333 << ".getOrAddProperties<Properties>()." 
3334 "operandSegmentSizes.begin());\n"; 3335 } else { 3336 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" 3337 << builderOpState << ".name), " 3338 << "odsBuilder.getDenseI32ArrayAttr({"; 3339 emitSegment(); 3340 body << "}));\n"; 3341 } 3342 } 3343 3344 // Push all properties to the result. 3345 for (const auto &namedProp : op.getProperties()) { 3346 // Use the setter from the Properties struct since the conversion from the 3347 // interface type (used in the builder argument) to the storage type (used 3348 // in the state) is not necessarily trivial. 3349 std::string setterName = op.getSetterName(namedProp.name); 3350 body << formatv(" {0}.getOrAddProperties<Properties>().{1}({2});\n", 3351 builderOpState, setterName, namedProp.name); 3352 } 3353 // Push all attributes to the result. 3354 for (const auto &namedAttr : op.getAttributes()) { 3355 auto &attr = namedAttr.attr; 3356 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) 3357 continue; 3358 3359 // TODO: The wrapping of optional is different for default or not, so don't 3360 // unwrap for default ones that would fail below. 3361 bool emitNotNullCheck = 3362 (attr.isOptional() && !attr.hasDefaultValue()) || 3363 (attr.hasDefaultValue() && !isRawValueAttr) || 3364 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as 3365 // the constant materialization is only for true case. 3366 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr"); 3367 if (emitNotNullCheck) 3368 body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n"; 3369 3370 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { 3371 // If this is a raw value, then we need to wrap it in an Attribute 3372 // instance. 
3373 FmtContext fctx; 3374 fctx.withBuilder("odsBuilder"); 3375 if (op.getDialect().usePropertiesForAttributes()) { 3376 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n", 3377 builderOpState, namedAttr.name, 3378 constBuildAttrFromParam(attr, fctx, namedAttr.name)); 3379 } else { 3380 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 3381 builderOpState, op.getGetterName(namedAttr.name), 3382 constBuildAttrFromParam(attr, fctx, namedAttr.name)); 3383 } 3384 } else { 3385 if (op.getDialect().usePropertiesForAttributes()) { 3386 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n", 3387 builderOpState, namedAttr.name); 3388 } else { 3389 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 3390 builderOpState, op.getGetterName(namedAttr.name), 3391 namedAttr.name); 3392 } 3393 } 3394 if (emitNotNullCheck) 3395 body.unindent() << " }\n"; 3396 } 3397 3398 // Create the correct number of regions. 3399 for (const NamedRegion ®ion : op.getRegions()) { 3400 if (region.isVariadic()) 3401 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", 3402 region.name); 3403 3404 body << " (void)" << builderOpState << ".addRegion();\n"; 3405 } 3406 3407 // Push all successors to the result. 
3408 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { 3409 body << formatv(" {0}.addSuccessors({1});\n", builderOpState, 3410 namedSuccessor.name); 3411 } 3412 } 3413 3414 void OpEmitter::genCanonicalizerDecls() { 3415 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod"); 3416 if (hasCanonicalizeMethod) { 3417 // static LogicResult FooOp:: 3418 // canonicalize(FooOp op, PatternRewriter &rewriter); 3419 SmallVector<MethodParameter> paramList; 3420 paramList.emplace_back(op.getCppClassName(), "op"); 3421 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); 3422 auto *m = opClass.declareStaticMethod("::llvm::LogicalResult", 3423 "canonicalize", std::move(paramList)); 3424 ERROR_IF_PRUNED(m, "canonicalize", op); 3425 } 3426 3427 // We get a prototype for 'getCanonicalizationPatterns' if requested directly 3428 // or if using a 'canonicalize' method. 3429 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer"); 3430 if (!hasCanonicalizeMethod && !hasCanonicalizer) 3431 return; 3432 3433 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize' 3434 // method, but not implementing 'getCanonicalizationPatterns' manually. 3435 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer; 3436 3437 // Add a signature for getCanonicalizationPatterns if implemented by the 3438 // dialect or if synthesized to call 'canonicalize'. 3439 SmallVector<MethodParameter> paramList; 3440 paramList.emplace_back("::mlir::RewritePatternSet &", "results"); 3441 paramList.emplace_back("::mlir::MLIRContext *", "context"); 3442 auto kind = hasBody ? Method::Static : Method::StaticDeclaration; 3443 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, 3444 std::move(paramList)); 3445 3446 // If synthesizing the method, fill it. 
3447 if (hasBody) { 3448 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op); 3449 method->body() << " results.add(canonicalize);\n"; 3450 } 3451 } 3452 3453 void OpEmitter::genFolderDecls() { 3454 if (!op.hasFolder()) 3455 return; 3456 3457 SmallVector<MethodParameter> paramList; 3458 paramList.emplace_back("FoldAdaptor", "adaptor"); 3459 3460 StringRef retType; 3461 bool hasSingleResult = 3462 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; 3463 if (hasSingleResult) { 3464 retType = "::mlir::OpFoldResult"; 3465 } else { 3466 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", 3467 "results"); 3468 retType = "::llvm::LogicalResult"; 3469 } 3470 3471 auto *m = opClass.declareMethod(retType, "fold", std::move(paramList)); 3472 ERROR_IF_PRUNED(m, "fold", op); 3473 } 3474 3475 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { 3476 Interface interface = opTrait->getInterface(); 3477 3478 // Get the set of methods that should always be declared. 3479 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); 3480 llvm::StringSet<> alwaysDeclaredMethods; 3481 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), 3482 alwaysDeclaredMethodsVec.end()); 3483 3484 for (const InterfaceMethod &method : interface.getMethods()) { 3485 // Don't declare if the method has a body. 3486 if (method.getBody()) 3487 continue; 3488 // Don't declare if the method has a default implementation and the op 3489 // didn't request that it always be declared. 3490 if (method.getDefaultImplementation() && 3491 !alwaysDeclaredMethods.count(method.getName())) 3492 continue; 3493 // Interface methods are allowed to overlap with existing methods, so don't 3494 // check if pruned. 
3495 (void)genOpInterfaceMethod(method); 3496 } 3497 } 3498 3499 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, 3500 bool declaration) { 3501 SmallVector<MethodParameter> paramList; 3502 for (const InterfaceMethod::Argument &arg : method.getArguments()) 3503 paramList.emplace_back(arg.type, arg.name); 3504 3505 auto props = (method.isStatic() ? Method::Static : Method::None) | 3506 (declaration ? Method::Declaration : Method::None); 3507 return opClass.addMethod(method.getReturnType(), method.getName(), props, 3508 std::move(paramList)); 3509 } 3510 3511 void OpEmitter::genOpInterfaceMethods() { 3512 for (const auto &trait : op.getTraits()) { 3513 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) 3514 if (opTrait->shouldDeclareMethods()) 3515 genOpInterfaceMethods(opTrait); 3516 } 3517 } 3518 3519 void OpEmitter::genSideEffectInterfaceMethods() { 3520 enum EffectKind { Operand, Result, Symbol, Static }; 3521 struct EffectLocation { 3522 /// The effect applied. 3523 SideEffect effect; 3524 3525 /// The index if the kind is not static. 3526 unsigned index; 3527 3528 /// The kind of the location. 3529 unsigned kind; 3530 }; 3531 3532 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects; 3533 auto resolveDecorators = [&](Operator::var_decorator_range decorators, 3534 unsigned index, unsigned kind) { 3535 for (auto decorator : decorators) 3536 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) { 3537 opClass.addTrait(effect->getInterfaceTrait()); 3538 interfaceEffects[effect->getBaseEffectName()].push_back( 3539 EffectLocation{*effect, index, kind}); 3540 } 3541 }; 3542 3543 // Collect effects that were specified via: 3544 /// Traits. 
3545 for (const auto &trait : op.getTraits()) { 3546 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait); 3547 if (!opTrait) 3548 continue; 3549 auto &effects = interfaceEffects[opTrait->getBaseEffectName()]; 3550 for (auto decorator : opTrait->getEffects()) 3551 effects.push_back(EffectLocation{cast<SideEffect>(decorator), 3552 /*index=*/0, EffectKind::Static}); 3553 } 3554 /// Attributes and Operands. 3555 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) { 3556 Argument arg = op.getArg(i); 3557 if (isa<NamedTypeConstraint *>(arg)) { 3558 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand); 3559 ++operandIt; 3560 continue; 3561 } 3562 if (isa<NamedProperty *>(arg)) 3563 continue; 3564 const NamedAttribute *attr = cast<NamedAttribute *>(arg); 3565 if (attr->attr.getBaseAttr().isSymbolRefAttr()) 3566 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol); 3567 } 3568 /// Results. 3569 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) 3570 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result); 3571 3572 // The code used to add an effect instance. 3573 // {0}: The effect class. 3574 // {1}: Optional value or symbol reference. 3575 // {2}: The side effect stage. 3576 // {3}: Does this side effect act on every single value of resource. 3577 // {4}: The resource class. 3578 const char *addEffectCode = 3579 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n"; 3580 3581 for (auto &it : interfaceEffects) { 3582 // Generate the 'getEffects' method. 3583 std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::" 3584 "SideEffects::EffectInstance<{0}>> &", 3585 it.first()) 3586 .str(); 3587 auto *getEffects = opClass.addMethod("void", "getEffects", 3588 MethodParameter(type, "effects")); 3589 ERROR_IF_PRUNED(getEffects, "getEffects", op); 3590 auto &body = getEffects->body(); 3591 3592 // Add effect instances for each of the locations marked on the operation. 
3593 for (auto &location : it.second) { 3594 StringRef effect = location.effect.getName(); 3595 StringRef resource = location.effect.getResource(); 3596 int stage = (int)location.effect.getStage(); 3597 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion(); 3598 if (location.kind == EffectKind::Static) { 3599 // A static instance has no attached value. 3600 body << llvm::formatv(addEffectCode, effect, "", stage, 3601 effectOnFullRegion, resource) 3602 .str(); 3603 } else if (location.kind == EffectKind::Symbol) { 3604 // A symbol reference requires adding the proper attribute. 3605 const auto *attr = cast<NamedAttribute *>(op.getArg(location.index)); 3606 std::string argName = op.getGetterName(attr->name); 3607 if (attr->attr.isOptional()) { 3608 body << " if (auto symbolRef = " << argName << "Attr())\n " 3609 << llvm::formatv(addEffectCode, effect, "symbolRef, ", stage, 3610 effectOnFullRegion, resource) 3611 .str(); 3612 } else { 3613 body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ", 3614 stage, effectOnFullRegion, resource) 3615 .str(); 3616 } 3617 } else { 3618 // Otherwise this is an operand/result, so we need to attach the Value. 3619 body << " {\n auto valueRange = getODS" 3620 << (location.kind == EffectKind::Operand ? "Operand" : "Result") 3621 << "IndexAndLength(" << location.index << ");\n" 3622 << " for (unsigned idx = valueRange.first; idx < " 3623 "valueRange.first" 3624 << " + valueRange.second; idx++) {\n " 3625 << llvm::formatv(addEffectCode, effect, 3626 (location.kind == EffectKind::Operand 3627 ? "&getOperation()->getOpOperand(idx), " 3628 : "getOperation()->getOpResult(idx), "), 3629 stage, effectOnFullRegion, resource) 3630 << " }\n }\n"; 3631 } 3632 } 3633 } 3634 } 3635 3636 void OpEmitter::genTypeInterfaceMethods() { 3637 if (!op.allResultTypesKnown()) 3638 return; 3639 // Generate 'inferReturnTypes' method declaration using the interface method 3640 // declared in 'InferTypeOpInterface' op interface. 
3641 const auto *trait = 3642 cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait")); 3643 Interface interface = trait->getInterface(); 3644 Method *method = [&]() -> Method * { 3645 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { 3646 if (interfaceMethod.getName() == "inferReturnTypes") { 3647 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false); 3648 } 3649 } 3650 assert(0 && "unable to find inferReturnTypes interface method"); 3651 return nullptr; 3652 }(); 3653 ERROR_IF_PRUNED(method, "inferReturnTypes", op); 3654 auto &body = method->body(); 3655 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; 3656 3657 FmtContext fctx; 3658 fctx.withBuilder("odsBuilder"); 3659 fctx.addSubst("_ctxt", "context"); 3660 body << " ::mlir::Builder odsBuilder(context);\n"; 3661 3662 // Preprocessing stage to verify all accesses to operands are valid. 3663 int maxAccessedIndex = -1; 3664 for (int i = 0, e = op.getNumResults(); i != e; ++i) { 3665 const InferredResultType &infer = op.getInferredResultType(i); 3666 if (!infer.isArg()) 3667 continue; 3668 Operator::OperandOrAttribute arg = 3669 op.getArgToOperandOrAttribute(infer.getIndex()); 3670 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { 3671 maxAccessedIndex = 3672 std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); 3673 } 3674 } 3675 if (maxAccessedIndex != -1) { 3676 body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; 3677 body << " return ::mlir::failure();\n"; 3678 } 3679 3680 // Process the type inference graph in topological order, starting from types 3681 // that are always fully-inferred: operands and results with constructible 3682 // types. The type inference graph here will always be a DAG, so this gives 3683 // us the correct order for generating the types. -1 is a placeholder to 3684 // indicate the type for a result has not been generated. 
3685 SmallVector<int> constructedIndices(op.getNumResults(), -1); 3686 int inferredTypeIdx = 0; 3687 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { 3688 for (int i = 0, e = op.getNumResults(); i != e; ++i) { 3689 if (constructedIndices[i] >= 0) 3690 continue; 3691 const InferredResultType &infer = op.getInferredResultType(i); 3692 std::string typeStr; 3693 if (infer.isArg()) { 3694 // If this is an operand, just index into operand list to access the 3695 // type. 3696 Operator::OperandOrAttribute arg = 3697 op.getArgToOperandOrAttribute(infer.getIndex()); 3698 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { 3699 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + 3700 "].getType()") 3701 .str(); 3702 3703 // If this is an attribute, index into the attribute dictionary. 3704 } else { 3705 auto *attr = 3706 cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex())); 3707 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx 3708 << " = "; 3709 if (op.getDialect().usePropertiesForAttributes()) { 3710 body << "(properties ? properties.as<Properties *>()->" 3711 << attr->name 3712 << " : " 3713 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes." 3714 "get(\"" + 3715 attr->name + "\")));\n"; 3716 } else { 3717 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes." 
3718 "get(\"" + 3719 attr->name + "\"));\n"; 3720 } 3721 body << " if (!odsInferredTypeAttr" << inferredTypeIdx 3722 << ") return ::mlir::failure();\n"; 3723 typeStr = 3724 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()") 3725 .str(); 3726 } 3727 } else if (std::optional<StringRef> builder = 3728 op.getResult(infer.getResultIndex()) 3729 .constraint.getBuilderCall()) { 3730 typeStr = tgfmt(*builder, &fctx).str(); 3731 } else if (int index = constructedIndices[infer.getResultIndex()]; 3732 index >= 0) { 3733 typeStr = ("odsInferredType" + Twine(index)).str(); 3734 } else { 3735 continue; 3736 } 3737 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = " 3738 << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n"; 3739 constructedIndices[i] = inferredTypeIdx - 1; 3740 } 3741 } 3742 for (auto [i, index] : llvm::enumerate(constructedIndices)) 3743 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index 3744 << ";\n"; 3745 body << " return ::mlir::success();"; 3746 } 3747 3748 void OpEmitter::genParser() { 3749 if (hasStringAttribute(def, "assemblyFormat")) 3750 return; 3751 3752 if (!def.getValueAsBit("hasCustomAssemblyFormat")) 3753 return; 3754 3755 SmallVector<MethodParameter> paramList; 3756 paramList.emplace_back("::mlir::OpAsmParser &", "parser"); 3757 paramList.emplace_back("::mlir::OperationState &", "result"); 3758 3759 auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse", 3760 std::move(paramList)); 3761 ERROR_IF_PRUNED(method, "parse", op); 3762 } 3763 3764 void OpEmitter::genPrinter() { 3765 if (hasStringAttribute(def, "assemblyFormat")) 3766 return; 3767 3768 // Check to see if this op uses a c++ format. 
3769 if (!def.getValueAsBit("hasCustomAssemblyFormat")) 3770 return; 3771 auto *method = opClass.declareMethod( 3772 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p")); 3773 ERROR_IF_PRUNED(method, "print", op); 3774 } 3775 3776 void OpEmitter::genVerifier() { 3777 auto *implMethod = 3778 opClass.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl"); 3779 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op); 3780 auto &implBody = implMethod->body(); 3781 bool useProperties = emitHelper.hasProperties(); 3782 3783 populateSubstitutions(emitHelper, verifyCtx); 3784 genPropertyVerifier(emitHelper, verifyCtx, implBody); 3785 genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter, 3786 useProperties); 3787 genOperandResultVerifier(implBody, op.getOperands(), "operand"); 3788 genOperandResultVerifier(implBody, op.getResults(), "result"); 3789 3790 for (auto &trait : op.getTraits()) { 3791 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) { 3792 implBody << tgfmt(" if (!($0))\n " 3793 "return emitOpError(\"failed to verify that $1\");\n", 3794 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), 3795 t->getSummary()); 3796 } 3797 } 3798 3799 genRegionVerifier(implBody); 3800 genSuccessorVerifier(implBody); 3801 3802 implBody << " return ::mlir::success();\n"; 3803 3804 // TODO: Some places use the `verifyInvariants` to do operation verification. 3805 // This may not act as their expectation because this doesn't call any 3806 // verifiers of native/interface traits. Needs to review those use cases and 3807 // see if we should use the mlir::verify() instead. 
3808 auto *method = opClass.addMethod("::llvm::LogicalResult", "verifyInvariants"); 3809 ERROR_IF_PRUNED(method, "verifyInvariants", op); 3810 auto &body = method->body(); 3811 if (def.getValueAsBit("hasVerifier")) { 3812 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && " 3813 "::mlir::succeeded(verify()))\n"; 3814 body << " return ::mlir::success();\n"; 3815 body << " return ::mlir::failure();"; 3816 } else { 3817 body << " return verifyInvariantsImpl();"; 3818 } 3819 } 3820 3821 void OpEmitter::genCustomVerifier() { 3822 if (def.getValueAsBit("hasVerifier")) { 3823 auto *method = opClass.declareMethod("::llvm::LogicalResult", "verify"); 3824 ERROR_IF_PRUNED(method, "verify", op); 3825 } 3826 3827 if (def.getValueAsBit("hasRegionVerifier")) { 3828 auto *method = 3829 opClass.declareMethod("::llvm::LogicalResult", "verifyRegions"); 3830 ERROR_IF_PRUNED(method, "verifyRegions", op); 3831 } 3832 } 3833 3834 void OpEmitter::genOperandResultVerifier(MethodBody &body, 3835 Operator::const_value_range values, 3836 StringRef valueKind) { 3837 // Check that an optional value is at most 1 element. 3838 // 3839 // {0}: Value index. 3840 // {1}: "operand" or "result" 3841 const char *const verifyOptional = R"( 3842 if (valueGroup{0}.size() > 1) { 3843 return emitOpError("{1} group starting at #") << index 3844 << " requires 0 or 1 element, but found " << valueGroup{0}.size(); 3845 } 3846 )"; 3847 // Check the types of a range of values. 3848 // 3849 // {0}: Value index. 3850 // {1}: Type constraint function. 
3851 // {2}: "operand" or "result" 3852 const char *const verifyValues = R"( 3853 for (auto v : valueGroup{0}) { 3854 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++))) 3855 return ::mlir::failure(); 3856 } 3857 )"; 3858 3859 const auto canSkip = [](const NamedTypeConstraint &value) { 3860 return !value.hasPredicate() && !value.isOptional() && 3861 !value.isVariadicOfVariadic(); 3862 }; 3863 if (values.empty() || llvm::all_of(values, canSkip)) 3864 return; 3865 3866 FmtContext fctx; 3867 3868 body << " {\n unsigned index = 0; (void)index;\n"; 3869 3870 for (const auto &staticValue : llvm::enumerate(values)) { 3871 const NamedTypeConstraint &value = staticValue.value(); 3872 3873 bool hasPredicate = value.hasPredicate(); 3874 bool isOptional = value.isOptional(); 3875 bool isVariadicOfVariadic = value.isVariadicOfVariadic(); 3876 if (!hasPredicate && !isOptional && !isVariadicOfVariadic) 3877 continue; 3878 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n", 3879 // Capitalize the first letter to match the function name 3880 valueKind.substr(0, 1).upper(), valueKind.substr(1), 3881 staticValue.index()); 3882 3883 // If the constraint is optional check that the value group has at most 1 3884 // value. 3885 if (isOptional) { 3886 body << formatv(verifyOptional, staticValue.index(), valueKind); 3887 } else if (isVariadicOfVariadic) { 3888 body << formatv( 3889 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" 3890 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n" 3891 " return ::mlir::failure();\n", 3892 value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name, 3893 staticValue.index()); 3894 } 3895 3896 // Otherwise, if there is no predicate there is nothing left to do. 3897 if (!hasPredicate) 3898 continue; 3899 // Emit a loop to check all the dynamic values in the pack. 
3900 StringRef constraintFn = 3901 staticVerifierEmitter.getTypeConstraintFn(value.constraint); 3902 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); 3903 } 3904 3905 body << " }\n"; 3906 } 3907 3908 void OpEmitter::genRegionVerifier(MethodBody &body) { 3909 /// Code to verify a region. 3910 /// 3911 /// {0}: Getter for the regions. 3912 /// {1}: The region constraint. 3913 /// {2}: The region's name. 3914 /// {3}: The region description. 3915 const char *const verifyRegion = R"( 3916 for (auto ®ion : {0}) 3917 if (::mlir::failed({1}(*this, region, "{2}", index++))) 3918 return ::mlir::failure(); 3919 )"; 3920 /// Get a single region. 3921 /// 3922 /// {0}: The region's index. 3923 const char *const getSingleRegion = 3924 "::llvm::MutableArrayRef((*this)->getRegion({0}))"; 3925 3926 // If we have no regions, there is nothing more to do. 3927 const auto canSkip = [](const NamedRegion ®ion) { 3928 return region.constraint.getPredicate().isNull(); 3929 }; 3930 auto regions = op.getRegions(); 3931 if (regions.empty() && llvm::all_of(regions, canSkip)) 3932 return; 3933 3934 body << " {\n unsigned index = 0; (void)index;\n"; 3935 for (const auto &it : llvm::enumerate(regions)) { 3936 const auto ®ion = it.value(); 3937 if (canSkip(region)) 3938 continue; 3939 3940 auto getRegion = region.isVariadic() 3941 ? formatv("{0}()", op.getGetterName(region.name)).str() 3942 : formatv(getSingleRegion, it.index()).str(); 3943 auto constraintFn = 3944 staticVerifierEmitter.getRegionConstraintFn(region.constraint); 3945 body << formatv(verifyRegion, getRegion, constraintFn, region.name); 3946 } 3947 body << " }\n"; 3948 } 3949 3950 void OpEmitter::genSuccessorVerifier(MethodBody &body) { 3951 const char *const verifySuccessor = R"( 3952 for (auto *successor : {0}) 3953 if (::mlir::failed({1}(*this, successor, "{2}", index++))) 3954 return ::mlir::failure(); 3955 )"; 3956 /// Get a single successor. 3957 /// 3958 /// {0}: The successor's name. 
3959 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())"; 3960 3961 // If we have no successors, there is nothing more to do. 3962 const auto canSkip = [](const NamedSuccessor &successor) { 3963 return successor.constraint.getPredicate().isNull(); 3964 }; 3965 auto successors = op.getSuccessors(); 3966 if (successors.empty() && llvm::all_of(successors, canSkip)) 3967 return; 3968 3969 body << " {\n unsigned index = 0; (void)index;\n"; 3970 3971 for (auto it : llvm::enumerate(successors)) { 3972 const auto &successor = it.value(); 3973 if (canSkip(successor)) 3974 continue; 3975 3976 auto getSuccessor = 3977 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor, 3978 successor.name) 3979 .str(); 3980 auto constraintFn = 3981 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); 3982 body << formatv(verifySuccessor, getSuccessor, constraintFn, 3983 successor.name); 3984 } 3985 body << " }\n"; 3986 } 3987 3988 /// Add a size count trait to the given operation class. 3989 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, 3990 int numTotal, int numVariadic) { 3991 if (numVariadic != 0) { 3992 if (numTotal == numVariadic) 3993 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s"); 3994 else 3995 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" + 3996 Twine(numTotal - numVariadic) + ">::Impl"); 3997 return; 3998 } 3999 switch (numTotal) { 4000 case 0: 4001 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s"); 4002 break; 4003 case 1: 4004 opClass.addTrait("::mlir::OpTrait::One" + traitKind); 4005 break; 4006 default: 4007 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) + 4008 ">::Impl"); 4009 break; 4010 } 4011 } 4012 4013 void OpEmitter::genTraits() { 4014 // Add region size trait. 
4015 unsigned numRegions = op.getNumRegions(); 4016 unsigned numVariadicRegions = op.getNumVariadicRegions(); 4017 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); 4018 4019 // Add result size traits. 4020 int numResults = op.getNumResults(); 4021 int numVariadicResults = op.getNumVariableLengthResults(); 4022 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); 4023 4024 // For single result ops with a known specific type, generate a OneTypedResult 4025 // trait. 4026 if (numResults == 1 && numVariadicResults == 0) { 4027 auto cppName = op.getResults().begin()->constraint.getCppType(); 4028 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); 4029 } 4030 4031 // Add successor size trait. 4032 unsigned numSuccessors = op.getNumSuccessors(); 4033 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); 4034 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); 4035 4036 // Add variadic size trait and normal op traits. 4037 int numOperands = op.getNumOperands(); 4038 int numVariadicOperands = op.getNumVariableLengthOperands(); 4039 4040 // Add operand size trait. 4041 addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands); 4042 4043 // The op traits defined internal are ensured that they can be verified 4044 // earlier. 4045 for (const auto &trait : op.getTraits()) { 4046 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 4047 if (opTrait->isStructuralOpTrait()) 4048 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 4049 } 4050 } 4051 4052 // OpInvariants wrapps the verifyInvariants which needs to be run before 4053 // native/interface traits and after all the traits with `StructuralOpTrait`. 4054 opClass.addTrait("::mlir::OpTrait::OpInvariants"); 4055 4056 if (emitHelper.hasProperties()) 4057 opClass.addTrait("::mlir::BytecodeOpInterface::Trait"); 4058 4059 // Add the native and interface traits. 
4060 for (const auto &trait : op.getTraits()) { 4061 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 4062 if (!opTrait->isStructuralOpTrait()) 4063 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 4064 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) { 4065 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 4066 } 4067 } 4068 } 4069 4070 void OpEmitter::genOpNameGetter() { 4071 auto *method = opClass.addStaticMethod<Method::Constexpr>( 4072 "::llvm::StringLiteral", "getOperationName"); 4073 ERROR_IF_PRUNED(method, "getOperationName", op); 4074 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() 4075 << "\");"; 4076 } 4077 4078 void OpEmitter::genOpAsmInterface() { 4079 // If the user only has one results or specifically added the Asm trait, 4080 // then don't generate it for them. We specifically only handle multi result 4081 // operations, because the name of a single result in the common case is not 4082 // interesting(generally 'result'/'output'/etc.). 4083 // TODO: We could also add a flag to allow operations to opt in to this 4084 // generation, even if they only have a single operation. 4085 int numResults = op.getNumResults(); 4086 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait")) 4087 return; 4088 4089 SmallVector<StringRef, 4> resultNames(numResults); 4090 for (int i = 0; i != numResults; ++i) 4091 resultNames[i] = op.getResultName(i); 4092 4093 // Don't add the trait if none of the results have a valid name. 4094 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) 4095 return; 4096 opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); 4097 4098 // Generate the right accessor for the number of results. 
4099 auto *method = opClass.addMethod( 4100 "void", "getAsmResultNames", 4101 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn")); 4102 ERROR_IF_PRUNED(method, "getAsmResultNames", op); 4103 auto &body = method->body(); 4104 for (int i = 0; i != numResults; ++i) { 4105 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" 4106 << " if (!resultGroup" << i << ".empty())\n" 4107 << " setNameFn(*resultGroup" << i << ".begin(), \"" 4108 << resultNames[i] << "\");\n"; 4109 } 4110 } 4111 4112 //===----------------------------------------------------------------------===// 4113 // OpOperandAdaptor emitter 4114 //===----------------------------------------------------------------------===// 4115 4116 namespace { 4117 // Helper class to emit Op operand adaptors to an output stream. Operand 4118 // adaptors are wrappers around random access ranges that provide named operand 4119 // getters identical to those defined in the Op. 4120 // This currently generates 3 classes per Op: 4121 // * A Base class within the 'detail' namespace, which contains all logic and 4122 // members independent of the random access range that is indexed into. 4123 // In other words, it contains all the attribute and region getters. 4124 // * A templated class named '{OpName}GenericAdaptor' with a template parameter 4125 // 'RangeT' that is indexed into by the getters to access the operands. 4126 // It contains all getters to access operands and inherits from the previous 4127 // class. 4128 // * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor' 4129 // with 'mlir::ValueRange' as template parameter. It adds a constructor from 4130 // an instance of the op type and a verify function. 
4131 class OpOperandAdaptorEmitter { 4132 public: 4133 static void 4134 emitDecl(const Operator &op, 4135 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 4136 raw_ostream &os); 4137 static void 4138 emitDef(const Operator &op, 4139 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 4140 raw_ostream &os); 4141 4142 private: 4143 explicit OpOperandAdaptorEmitter( 4144 const Operator &op, 4145 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 4146 4147 // Add verification function. This generates a verify method for the adaptor 4148 // which verifies all the op-independent attribute constraints. 4149 void addVerification(); 4150 4151 // The operation for which to emit an adaptor. 4152 const Operator &op; 4153 4154 // The generated adaptor classes. 4155 Class genericAdaptorBase; 4156 Class genericAdaptor; 4157 Class adaptor; 4158 4159 // The emitter containing all of the locally emitted verification functions. 4160 const StaticVerifierFunctionEmitter &staticVerifierEmitter; 4161 4162 // Helper for emitting adaptor code. 4163 OpOrAdaptorHelper emitHelper; 4164 }; 4165 } // namespace 4166 4167 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( 4168 const Operator &op, 4169 const StaticVerifierFunctionEmitter &staticVerifierEmitter) 4170 : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"), 4171 genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()), 4172 staticVerifierEmitter(staticVerifierEmitter), 4173 emitHelper(op, /*emitForOp=*/false) { 4174 4175 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public); 4176 bool useProperties = emitHelper.hasProperties(); 4177 if (useProperties) { 4178 // Define the properties struct with multiple members. 
4179 using ConstArgument = 4180 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>; 4181 SmallVector<ConstArgument> attrOrProperties; 4182 for (const std::pair<StringRef, AttributeMetadata> &it : 4183 emitHelper.getAttrMetadata()) { 4184 if (!it.second.constraint || !it.second.constraint->isDerivedAttr()) 4185 attrOrProperties.push_back(&it.second); 4186 } 4187 for (const NamedProperty &prop : op.getProperties()) 4188 attrOrProperties.push_back(&prop); 4189 if (emitHelper.getOperandSegmentsSize()) 4190 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); 4191 if (emitHelper.getResultSegmentsSize()) 4192 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); 4193 assert(!attrOrProperties.empty()); 4194 std::string declarations = " struct Properties {\n"; 4195 llvm::raw_string_ostream os(declarations); 4196 std::string comparator = 4197 " bool operator==(const Properties &rhs) const {\n" 4198 " return \n"; 4199 llvm::raw_string_ostream comparatorOs(comparator); 4200 for (const auto &attrOrProp : attrOrProperties) { 4201 if (const auto *namedProperty = 4202 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { 4203 StringRef name = namedProperty->name; 4204 if (name.empty()) 4205 report_fatal_error("missing name for property"); 4206 std::string camelName = 4207 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 4208 auto &prop = namedProperty->prop; 4209 // Generate the data member using the storage type. 4210 os << " using " << name << "Ty = " << prop.getStorageType() << ";\n" 4211 << " " << name << "Ty " << name; 4212 if (prop.hasStorageTypeValueOverride()) 4213 os << " = " << prop.getStorageTypeValueOverride(); 4214 else if (prop.hasDefaultValue()) 4215 os << " = " << prop.getDefaultValue(); 4216 comparatorOs << " rhs." << name << " == this->" << name 4217 << " &&\n"; 4218 // Emit accessors using the interface type. 
4219 const char *accessorFmt = R"decl(; 4220 {0} get{1}() const { 4221 auto &propStorage = this->{2}; 4222 return {3}; 4223 } 4224 void set{1}({0} propValue) { 4225 auto &propStorage = this->{2}; 4226 {4}; 4227 } 4228 )decl"; 4229 FmtContext fctx; 4230 os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name, 4231 tgfmt(prop.getConvertFromStorageCall(), 4232 &fctx.addSubst("_storage", propertyStorage)), 4233 tgfmt(prop.getAssignToStorageCall(), 4234 &fctx.addSubst("_value", propertyValue) 4235 .addSubst("_storage", propertyStorage))); 4236 continue; 4237 } 4238 const auto *namedAttr = 4239 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp); 4240 const Attribute *attr = nullptr; 4241 if (namedAttr->constraint) 4242 attr = &*namedAttr->constraint; 4243 StringRef name = namedAttr->attrName; 4244 if (name.empty()) 4245 report_fatal_error("missing name for property attr"); 4246 std::string camelName = 4247 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 4248 // Generate the data member using the storage type. 4249 StringRef storageType; 4250 if (attr) { 4251 storageType = attr->getStorageType(); 4252 } else { 4253 if (name != operandSegmentAttrName && name != resultSegmentAttrName) { 4254 report_fatal_error("unexpected AttributeMetadata"); 4255 } 4256 // TODO: update to use native integers. 4257 storageType = "::mlir::DenseI32ArrayAttr"; 4258 } 4259 os << " using " << name << "Ty = " << storageType << ";\n" 4260 << " " << name << "Ty " << name << ";\n"; 4261 comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; 4262 4263 // Emit accessors using the interface type. 4264 if (attr) { 4265 const char *accessorFmt = R"decl( 4266 auto get{0}() { 4267 auto &propStorage = this->{1}; 4268 return ::llvm::{2}<{3}>(propStorage); 4269 } 4270 void set{0}(const {3} &propValue) { 4271 this->{1} = propValue; 4272 } 4273 )decl"; 4274 os << formatv(accessorFmt, camelName, name, 4275 attr->isOptional() || attr->hasDefaultValue() 4276 ? 
"dyn_cast_or_null" 4277 : "cast", 4278 storageType); 4279 } 4280 } 4281 comparatorOs << " true;\n }\n" 4282 " bool operator!=(const Properties &rhs) const {\n" 4283 " return !(*this == rhs);\n" 4284 " }\n"; 4285 os << comparator; 4286 os << " };\n"; 4287 4288 genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations)); 4289 } 4290 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected); 4291 genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs"); 4292 genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>", 4293 "odsOpName"); 4294 if (useProperties) 4295 genericAdaptorBase.declare<Field>("Properties", "properties"); 4296 genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions"); 4297 4298 genericAdaptor.addTemplateParam("RangeT"); 4299 genericAdaptor.addField("RangeT", "odsOperands"); 4300 genericAdaptor.addParent( 4301 ParentClass("detail::" + genericAdaptorBase.getClassName())); 4302 genericAdaptor.declare<UsingDeclaration>( 4303 "ValueT", "::llvm::detail::ValueOfRange<RangeT>"); 4304 genericAdaptor.declare<UsingDeclaration>( 4305 "Base", "detail::" + genericAdaptorBase.getClassName()); 4306 4307 const auto *attrSizedOperands = 4308 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 4309 { 4310 SmallVector<MethodParameter> paramList; 4311 if (useProperties) { 4312 // Properties can't be given a default constructor here due to Properties 4313 // struct being defined in the enclosing class which isn't complete by 4314 // here. 
4315 paramList.emplace_back("::mlir::DictionaryAttr", "attrs"); 4316 paramList.emplace_back("const Properties &", "properties"); 4317 } else { 4318 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", "{}"); 4319 paramList.emplace_back("const ::mlir::EmptyProperties &", "properties", 4320 "{}"); 4321 } 4322 paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); 4323 auto *baseConstructor = 4324 genericAdaptorBase.addConstructor<Method::Inline>(paramList); 4325 baseConstructor->addMemberInitializer("odsAttrs", "attrs"); 4326 if (useProperties) 4327 baseConstructor->addMemberInitializer("properties", "properties"); 4328 baseConstructor->addMemberInitializer("odsRegions", "regions"); 4329 4330 MethodBody &body = baseConstructor->body(); 4331 body.indent() << "if (odsAttrs)\n"; 4332 body.indent() << formatv( 4333 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n", 4334 op.getOperationName()); 4335 4336 paramList.insert(paramList.begin(), MethodParameter("RangeT", "values")); 4337 auto *constructor = genericAdaptor.addConstructor(paramList); 4338 constructor->addMemberInitializer("Base", "attrs, properties, regions"); 4339 constructor->addMemberInitializer("odsOperands", "values"); 4340 4341 // Add a forwarding constructor to the previous one that accepts 4342 // OpaqueProperties instead and check for null and perform the cast to the 4343 // actual properties type. 4344 paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs"); 4345 paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties"); 4346 auto *opaquePropertiesConstructor = 4347 genericAdaptor.addConstructor(std::move(paramList)); 4348 if (useProperties) { 4349 opaquePropertiesConstructor->addMemberInitializer( 4350 genericAdaptor.getClassName(), 4351 "values, " 4352 "attrs, " 4353 "(properties ? 
*properties.as<Properties *>() : Properties{}), " 4354 "regions"); 4355 } else { 4356 opaquePropertiesConstructor->addMemberInitializer( 4357 genericAdaptor.getClassName(), 4358 "values, " 4359 "attrs, " 4360 "(properties ? *properties.as<::mlir::EmptyProperties *>() : " 4361 "::mlir::EmptyProperties{}), " 4362 "regions"); 4363 } 4364 4365 // Add forwarding constructor that constructs Properties. 4366 if (useProperties) { 4367 SmallVector<MethodParameter> paramList; 4368 paramList.emplace_back("RangeT", "values"); 4369 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", 4370 attrSizedOperands ? "" : "nullptr"); 4371 auto *noPropertiesConstructor = 4372 genericAdaptor.addConstructor(std::move(paramList)); 4373 noPropertiesConstructor->addMemberInitializer( 4374 genericAdaptor.getClassName(), "values, " 4375 "attrs, " 4376 "Properties{}, " 4377 "{}"); 4378 } 4379 } 4380 4381 // Create a constructor that creates a new generic adaptor by copying 4382 // everything from another adaptor, except for the values. 4383 { 4384 SmallVector<MethodParameter> paramList; 4385 paramList.emplace_back("RangeT", "values"); 4386 paramList.emplace_back("const " + op.getGenericAdaptorName() + "Base &", 4387 "base"); 4388 auto *constructor = 4389 genericAdaptor.addConstructor<Method::Inline>(paramList); 4390 constructor->addMemberInitializer("Base", "base"); 4391 constructor->addMemberInitializer("odsOperands", "values"); 4392 } 4393 4394 // Create constructors constructing the adaptor from an instance of the op. 4395 // This takes the attributes, properties and regions from the op instance 4396 // and the value range from the parameter. 4397 { 4398 // Base class is in the cpp file and can simply access the members of the op 4399 // class to initialize the template independent fields. If the op doesn't 4400 // have properties, we can emit a generic constructor inline. Otherwise, 4401 // emit it out-of-line because we need the op to be defined. 
4402 Constructor *constructor; 4403 if (useProperties) { 4404 constructor = genericAdaptorBase.addConstructor( 4405 MethodParameter(op.getCppClassName(), "op")); 4406 } else { 4407 constructor = genericAdaptorBase.addConstructor<Method::Inline>( 4408 MethodParameter("::mlir::Operation *", "op")); 4409 } 4410 constructor->addMemberInitializer("odsAttrs", 4411 "op->getRawDictionaryAttrs()"); 4412 // Retrieve the operation name from the op directly. 4413 constructor->addMemberInitializer("odsOpName", "op->getName()"); 4414 if (useProperties) 4415 constructor->addMemberInitializer("properties", "op.getProperties()"); 4416 constructor->addMemberInitializer("odsRegions", "op->getRegions()"); 4417 4418 // Generic adaptor is templated and therefore defined inline in the header. 4419 // We cannot use the Op class here as it is an incomplete type (we have a 4420 // circular reference between the two). 4421 // Use a template trick to make the constructor be instantiated at call site 4422 // when the op class is complete. 
4423 constructor = genericAdaptor.addConstructor( 4424 MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op")); 4425 constructor->addTemplateParam("LateInst = " + op.getCppClassName()); 4426 constructor->addTemplateParam( 4427 "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() + 4428 ">>"); 4429 constructor->addMemberInitializer("Base", "op"); 4430 constructor->addMemberInitializer("odsOperands", "values"); 4431 } 4432 4433 std::string sizeAttrInit; 4434 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 4435 if (op.getDialect().usePropertiesForAttributes()) 4436 sizeAttrInit = 4437 formatv(adapterSegmentSizeAttrInitCodeProperties, 4438 llvm::formatv("getProperties().operandSegmentSizes")); 4439 else 4440 sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, 4441 emitHelper.getAttr(operandSegmentAttrName)); 4442 } 4443 generateNamedOperandGetters(op, genericAdaptor, 4444 /*genericAdaptorBase=*/&genericAdaptorBase, 4445 /*sizeAttrInit=*/sizeAttrInit, 4446 /*rangeType=*/"RangeT", 4447 /*rangeElementType=*/"ValueT", 4448 /*rangeBeginCall=*/"odsOperands.begin()", 4449 /*rangeSizeCall=*/"odsOperands.size()", 4450 /*getOperandCallPattern=*/"odsOperands[{0}]"); 4451 4452 // Any invalid overlap for `getOperands` will have been diagnosed before 4453 // here already. 4454 if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands")) 4455 m->body() << " return odsOperands;"; 4456 4457 FmtContext fctx; 4458 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); 4459 4460 // Generate named accessor with Attribute return type. 4461 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName, 4462 Attribute attr) { 4463 // The method body is trivial if the attribute does not have a default 4464 // value, in which case the default value may be arbitrary code. 4465 auto *method = genericAdaptorBase.addMethod( 4466 attr.getStorageType(), emitName + "Attr", 4467 attr.hasDefaultValue() || !useProperties ? 
Method::Properties::None 4468 : Method::Properties::Inline); 4469 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); 4470 auto &body = method->body().indent(); 4471 if (!useProperties) 4472 body << "assert(odsAttrs && \"no attributes when constructing " 4473 "adapter\");\n"; 4474 body << formatv( 4475 "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper.getAttr(name), 4476 attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null" 4477 : "cast", 4478 attr.getStorageType()); 4479 4480 if (attr.hasDefaultValue() && attr.isOptional()) { 4481 // Use the default value if attribute is not set. 4482 // TODO: this is inefficient, we are recreating the attribute for every 4483 // call. This should be set instead. 4484 std::string defaultValue = std::string( 4485 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 4486 body << "if (!attr)\n attr = " << defaultValue << ";\n"; 4487 } 4488 body << "return attr;\n"; 4489 }; 4490 4491 if (useProperties) { 4492 auto *m = genericAdaptorBase.addInlineMethod("const Properties &", 4493 "getProperties"); 4494 ERROR_IF_PRUNED(m, "Adaptor::getProperties", op); 4495 m->body() << " return properties;"; 4496 } 4497 { 4498 auto *m = genericAdaptorBase.addInlineMethod("::mlir::DictionaryAttr", 4499 "getAttributes"); 4500 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); 4501 m->body() << " return odsAttrs;"; 4502 } 4503 for (auto &namedProp : op.getProperties()) { 4504 std::string name = op.getGetterName(namedProp.name); 4505 emitPropGetter(genericAdaptorBase, op, name, namedProp.prop); 4506 } 4507 4508 for (auto &namedAttr : op.getAttributes()) { 4509 const auto &name = namedAttr.name; 4510 const auto &attr = namedAttr.attr; 4511 if (attr.isDerivedAttr()) 4512 continue; 4513 std::string emitName = op.getGetterName(name); 4514 emitAttrWithStorageType(name, emitName, attr); 4515 emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr); 4516 } 4517 4518 unsigned numRegions = 
op.getNumRegions(); 4519 for (unsigned i = 0; i < numRegions; ++i) { 4520 const auto ®ion = op.getRegion(i); 4521 if (region.name.empty()) 4522 continue; 4523 4524 // Generate the accessors for a variadic region. 4525 std::string name = op.getGetterName(region.name); 4526 if (region.isVariadic()) { 4527 auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange", name); 4528 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 4529 m->body() << formatv(" return odsRegions.drop_front({0});", i); 4530 continue; 4531 } 4532 4533 auto *m = genericAdaptorBase.addInlineMethod("::mlir::Region &", name); 4534 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 4535 m->body() << formatv(" return *odsRegions[{0}];", i); 4536 } 4537 if (numRegions > 0) { 4538 // Any invalid overlap for `getRegions` will have been diagnosed before 4539 // here already. 4540 if (auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange", 4541 "getRegions")) 4542 m->body() << " return odsRegions;"; 4543 } 4544 4545 StringRef genericAdaptorClassName = genericAdaptor.getClassName(); 4546 adaptor.addParent(ParentClass(genericAdaptorClassName)) 4547 .addTemplateParam("::mlir::ValueRange"); 4548 adaptor.declare<VisibilityDeclaration>(Visibility::Public); 4549 adaptor.declare<UsingDeclaration>(genericAdaptorClassName + 4550 "::" + genericAdaptorClassName); 4551 { 4552 // Constructor taking the Op as single parameter. 4553 auto *constructor = 4554 adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); 4555 constructor->addMemberInitializer(genericAdaptorClassName, 4556 "op->getOperands(), op"); 4557 } 4558 4559 // Add verification function. 
4560 addVerification(); 4561 4562 genericAdaptorBase.finalize(); 4563 genericAdaptor.finalize(); 4564 adaptor.finalize(); 4565 } 4566 4567 void OpOperandAdaptorEmitter::addVerification() { 4568 auto *method = adaptor.addMethod("::llvm::LogicalResult", "verify", 4569 MethodParameter("::mlir::Location", "loc")); 4570 ERROR_IF_PRUNED(method, "verify", op); 4571 auto &body = method->body(); 4572 bool useProperties = emitHelper.hasProperties(); 4573 4574 FmtContext verifyCtx; 4575 populateSubstitutions(emitHelper, verifyCtx); 4576 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter, 4577 useProperties); 4578 4579 body << " return ::mlir::success();"; 4580 } 4581 4582 void OpOperandAdaptorEmitter::emitDecl( 4583 const Operator &op, 4584 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 4585 raw_ostream &os) { 4586 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter); 4587 { 4588 NamespaceEmitter ns(os, "detail"); 4589 emitter.genericAdaptorBase.writeDeclTo(os); 4590 } 4591 emitter.genericAdaptor.writeDeclTo(os); 4592 emitter.adaptor.writeDeclTo(os); 4593 } 4594 4595 void OpOperandAdaptorEmitter::emitDef( 4596 const Operator &op, 4597 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 4598 raw_ostream &os) { 4599 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter); 4600 { 4601 NamespaceEmitter ns(os, "detail"); 4602 emitter.genericAdaptorBase.writeDefTo(os); 4603 } 4604 emitter.genericAdaptor.writeDefTo(os); 4605 emitter.adaptor.writeDefTo(os); 4606 } 4607 4608 /// Emit the class declarations or definitions for the given op defs. 
4609 static void 4610 emitOpClasses(const RecordKeeper &records, 4611 const std::vector<const Record *> &defs, raw_ostream &os, 4612 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 4613 bool emitDecl) { 4614 if (defs.empty()) 4615 return; 4616 4617 for (auto *def : defs) { 4618 Operator op(*def); 4619 if (emitDecl) { 4620 { 4621 NamespaceEmitter emitter(os, op.getCppNamespace()); 4622 os << formatv(opCommentHeader, op.getQualCppClassName(), 4623 "declarations"); 4624 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); 4625 OpEmitter::emitDecl(op, os, staticVerifierEmitter); 4626 } 4627 // Emit the TypeID explicit specialization to have a single definition. 4628 if (!op.getCppNamespace().empty()) 4629 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 4630 << "::" << op.getCppClassName() << ")\n\n"; 4631 } else { 4632 { 4633 NamespaceEmitter emitter(os, op.getCppNamespace()); 4634 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); 4635 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os); 4636 OpEmitter::emitDef(op, os, staticVerifierEmitter); 4637 } 4638 // Emit the TypeID explicit specialization to have a single definition. 4639 if (!op.getCppNamespace().empty()) 4640 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 4641 << "::" << op.getCppClassName() << ")\n\n"; 4642 } 4643 } 4644 } 4645 4646 /// Emit the declarations for the provided op classes. 4647 static void emitOpClassDecls(const RecordKeeper &records, 4648 const std::vector<const Record *> &defs, 4649 raw_ostream &os) { 4650 // First emit forward declaration for each class, this allows them to refer 4651 // to each others in traits for example. 4652 for (auto *def : defs) { 4653 Operator op(*def); 4654 NamespaceEmitter emitter(os, op.getCppNamespace()); 4655 os << "class " << op.getCppClassName() << ";\n"; 4656 } 4657 4658 // Emit the op class declarations. 
4659 IfDefScope scope("GET_OP_CLASSES", os); 4660 if (defs.empty()) 4661 return; 4662 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records); 4663 staticVerifierEmitter.collectOpConstraints(defs); 4664 emitOpClasses(records, defs, os, staticVerifierEmitter, 4665 /*emitDecl=*/true); 4666 } 4667 4668 /// Emit the definitions for the provided op classes. 4669 static void emitOpClassDefs(const RecordKeeper &records, 4670 ArrayRef<const Record *> defs, raw_ostream &os, 4671 StringRef constraintPrefix = "") { 4672 if (defs.empty()) 4673 return; 4674 4675 // Generate all of the locally instantiated methods first. 4676 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records, 4677 constraintPrefix); 4678 os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); 4679 staticVerifierEmitter.collectOpConstraints(defs); 4680 staticVerifierEmitter.emitOpConstraints(defs); 4681 4682 // Emit the classes. 4683 emitOpClasses(records, defs, os, staticVerifierEmitter, 4684 /*emitDecl=*/false); 4685 } 4686 4687 /// Emit op declarations for all op records. 4688 static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) { 4689 emitSourceFileHeader("Op Declarations", os, records); 4690 4691 std::vector<const Record *> defs = getRequestedOpDefinitions(records); 4692 emitOpClassDecls(records, defs, os); 4693 4694 // If we are generating sharded op definitions, emit the sharded op 4695 // registration hooks. 
4696 SmallVector<ArrayRef<const Record *>, 4> shardedDefs; 4697 shardOpDefinitions(defs, shardedDefs); 4698 if (defs.empty() || shardedDefs.size() <= 1) 4699 return false; 4700 4701 Dialect dialect = Operator(defs.front()).getDialect(); 4702 NamespaceEmitter ns(os, dialect); 4703 4704 const char *const opRegistrationHook = 4705 "void register{0}Operations{1}({2}::{0} *dialect);\n"; 4706 os << formatv(opRegistrationHook, dialect.getCppClassName(), "", 4707 dialect.getCppNamespace()); 4708 for (unsigned i = 0; i < shardedDefs.size(); ++i) { 4709 os << formatv(opRegistrationHook, dialect.getCppClassName(), i, 4710 dialect.getCppNamespace()); 4711 } 4712 4713 return false; 4714 } 4715 4716 /// Generate the dialect op registration hook and the op class definitions for a 4717 /// shard of ops. 4718 static void emitOpDefShard(const RecordKeeper &records, 4719 ArrayRef<const Record *> defs, 4720 const Dialect &dialect, unsigned shardIndex, 4721 unsigned shardCount, raw_ostream &os) { 4722 std::string shardGuard = "GET_OP_DEFS_"; 4723 std::string indexStr = std::to_string(shardIndex); 4724 shardGuard += indexStr; 4725 IfDefScope scope(shardGuard, os); 4726 4727 // Emit the op registration hook in the first shard. 4728 const char *const opRegistrationHook = 4729 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n"; 4730 if (shardIndex == 0) { 4731 os << formatv(opRegistrationHook, dialect.getCppNamespace(), 4732 dialect.getCppClassName(), ""); 4733 for (unsigned i = 0; i < shardCount; ++i) { 4734 os << formatv(" {0}::register{1}Operations{2}(dialect);\n", 4735 dialect.getCppNamespace(), dialect.getCppClassName(), i); 4736 } 4737 os << "}\n"; 4738 } 4739 4740 // Generate the per-shard op registration hook. 
4741 os << formatv(opCommentHeader, dialect.getCppClassName(), 4742 "Op Registration Hook") 4743 << formatv(opRegistrationHook, dialect.getCppNamespace(), 4744 dialect.getCppClassName(), shardIndex); 4745 for (const Record *def : defs) { 4746 os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n", 4747 Operator(def).getQualCppClassName()); 4748 } 4749 os << "}\n"; 4750 4751 // Generate the per-shard op definitions. 4752 emitOpClassDefs(records, defs, os, indexStr); 4753 } 4754 4755 /// Emit op definitions for all op records. 4756 static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) { 4757 emitSourceFileHeader("Op Definitions", os, records); 4758 4759 std::vector<const Record *> defs = getRequestedOpDefinitions(records); 4760 SmallVector<ArrayRef<const Record *>, 4> shardedDefs; 4761 shardOpDefinitions(defs, shardedDefs); 4762 4763 // If no shard was requested, emit the regular op list and class definitions. 4764 if (shardedDefs.size() == 1) { 4765 { 4766 IfDefScope scope("GET_OP_LIST", os); 4767 interleave( 4768 defs, os, 4769 [&](const Record *def) { os << Operator(def).getQualCppClassName(); }, 4770 ",\n"); 4771 } 4772 { 4773 IfDefScope scope("GET_OP_CLASSES", os); 4774 emitOpClassDefs(records, defs, os); 4775 } 4776 return false; 4777 } 4778 4779 if (defs.empty()) 4780 return false; 4781 Dialect dialect = Operator(defs.front()).getDialect(); 4782 for (auto [idx, value] : llvm::enumerate(shardedDefs)) { 4783 emitOpDefShard(records, value, dialect, idx, shardedDefs.size(), os); 4784 } 4785 return false; 4786 } 4787 4788 static mlir::GenRegistration 4789 genOpDecls("gen-op-decls", "Generate op declarations", 4790 [](const RecordKeeper &records, raw_ostream &os) { 4791 return emitOpDecls(records, os); 4792 }); 4793 4794 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", 4795 [](const RecordKeeper &records, 4796 raw_ostream &os) { 4797 return emitOpDefs(records, os); 4798 }); 4799