1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Operator.h" 14 #include "mlir/TableGen/Argument.h" 15 #include "mlir/TableGen/Predicate.h" 16 #include "mlir/TableGen/Trait.h" 17 #include "mlir/TableGen/Type.h" 18 #include "llvm/ADT/EquivalenceClasses.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Record.h" 29 #include <list> 30 31 #define DEBUG_TYPE "mlir-tblgen-operator" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::DagInit; 37 using llvm::DefInit; 38 using llvm::Init; 39 using llvm::ListInit; 40 using llvm::Record; 41 using llvm::StringInit; 42 43 Operator::Operator(const Record &def) 44 : dialect(def.getValueAsDef("opDialect")), def(def) { 45 // The first `_` in the op's TableGen def name is treated as separating the 46 // dialect prefix and the op class name. The dialect prefix will be ignored if 47 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 48 // as part of the class name. 49 StringRef prefix; 50 std::tie(prefix, cppClassName) = def.getName().split('_'); 51 if (prefix.empty()) { 52 // Class name with a leading underscore and without dialect prefix 53 cppClassName = def.getName(); 54 } else if (cppClassName.empty()) { 55 // Class name without dialect prefix 56 cppClassName = prefix; 57 } 58 59 cppNamespace = def.getValueAsString("cppNamespace"); 60 61 populateOpStructure(); 62 assertInvariants(); 63 } 64 65 std::string Operator::getOperationName() const { 66 auto prefix = dialect.getName(); 67 auto opName = def.getValueAsString("opName"); 68 if (prefix.empty()) 69 return std::string(opName); 70 return std::string(llvm::formatv("{0}.{1}", prefix, opName)); 71 } 72 73 std::string Operator::getAdaptorName() const { 74 return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); 75 } 76 77 std::string Operator::getGenericAdaptorName() const { 78 return std::string(llvm::formatv("{0}GenericAdaptor", getCppClassName())); 79 } 80 81 /// Assert the invariants of accessors generated for the given name. 82 static void assertAccessorInvariants(const Operator &op, StringRef name) { 83 std::string accessorName = 84 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 85 86 // Functor used to detect when an accessor will cause an overlap with an 87 // operation API. 88 // 89 // There are a little bit more invasive checks possible for cases where not 90 // all ops have the trait that would cause overlap. For many cases here, 91 // renaming would be better (e.g., we can only guard in limited manner 92 // against methods from traits and interfaces here, so avoiding these in op 93 // definition is safer). 94 auto nameOverlapsWithOpAPI = [&](StringRef newName) { 95 if (newName == "AttributeNames" || newName == "Attributes" || 96 newName == "Operation") 97 return true; 98 if (newName == "Operands") 99 return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1; 100 if (newName == "Regions") 101 return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1; 102 if (newName == "Type") 103 return op.getNumResults() != 1; 104 return false; 105 }; 106 if (nameOverlapsWithOpAPI(accessorName)) { 107 // This error could be avoided in situations where the final function is 108 // identical, but preferably the op definition should avoid using generic 109 // names. 110 PrintFatalError(op.getLoc(), "generated accessor for `" + name + 111 "` overlaps with a default one; please " 112 "rename to avoid overlap"); 113 } 114 } 115 116 void Operator::assertInvariants() const { 117 // Check that the name of arguments/results/regions/successors don't overlap. 118 DenseMap<StringRef, StringRef> existingNames; 119 auto checkName = [&](StringRef name, StringRef entity) { 120 if (name.empty()) 121 return; 122 auto insertion = existingNames.insert({name, entity}); 123 if (insertion.second) { 124 // Assert invariants for accessors generated for this name. 125 assertAccessorInvariants(*this, name); 126 return; 127 } 128 if (entity == insertion.first->second) 129 PrintFatalError(getLoc(), "op has a conflict with two " + entity + 130 " having the same name '" + name + "'"); 131 PrintFatalError(getLoc(), "op has a conflict with " + 132 insertion.first->second + " and " + entity + 133 " both having an entry with the name '" + 134 name + "'"); 135 }; 136 // Check operands amongst themselves. 137 for (int i : llvm::seq<int>(0, getNumOperands())) 138 checkName(getOperand(i).name, "operands"); 139 140 // Check results amongst themselves and against operands. 141 for (int i : llvm::seq<int>(0, getNumResults())) 142 checkName(getResult(i).name, "results"); 143 144 // Check regions amongst themselves and against operands and results. 145 for (int i : llvm::seq<int>(0, getNumRegions())) 146 checkName(getRegion(i).name, "regions"); 147 148 // Check successors amongst themselves and against operands, results, and 149 // regions. 150 for (int i : llvm::seq<int>(0, getNumSuccessors())) 151 checkName(getSuccessor(i).name, "successors"); 152 } 153 154 StringRef Operator::getDialectName() const { return dialect.getName(); } 155 156 StringRef Operator::getCppClassName() const { return cppClassName; } 157 158 std::string Operator::getQualCppClassName() const { 159 if (cppNamespace.empty()) 160 return std::string(cppClassName); 161 return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName)); 162 } 163 164 StringRef Operator::getCppNamespace() const { return cppNamespace; } 165 166 int Operator::getNumResults() const { 167 const DagInit *results = def.getValueAsDag("results"); 168 return results->getNumArgs(); 169 } 170 171 StringRef Operator::getExtraClassDeclaration() const { 172 constexpr auto attr = "extraClassDeclaration"; 173 if (def.isValueUnset(attr)) 174 return {}; 175 return def.getValueAsString(attr); 176 } 177 178 StringRef Operator::getExtraClassDefinition() const { 179 constexpr auto attr = "extraClassDefinition"; 180 if (def.isValueUnset(attr)) 181 return {}; 182 return def.getValueAsString(attr); 183 } 184 185 const Record &Operator::getDef() const { return def; } 186 187 bool Operator::skipDefaultBuilders() const { 188 return def.getValueAsBit("skipDefaultBuilders"); 189 } 190 191 auto Operator::result_begin() const -> const_value_iterator { 192 return results.begin(); 193 } 194 195 auto Operator::result_end() const -> const_value_iterator { 196 return results.end(); 197 } 198 199 auto Operator::getResults() const -> const_value_range { 200 return {result_begin(), result_end()}; 201 } 202 203 TypeConstraint Operator::getResultTypeConstraint(int index) const { 204 const DagInit *results = def.getValueAsDag("results"); 205 return TypeConstraint(cast<DefInit>(results->getArg(index))); 206 } 207 208 StringRef Operator::getResultName(int index) const { 209 const DagInit *results = def.getValueAsDag("results"); 210 return results->getArgNameStr(index); 211 } 212 213 auto Operator::getResultDecorators(int index) const -> var_decorator_range { 214 const Record *result = 215 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); 216 if (!result->isSubClassOf("OpVariable")) 217 return var_decorator_range(nullptr, nullptr); 218 return *result->getValueAsListInit("decorators"); 219 } 220 221 unsigned Operator::getNumVariableLengthResults() const { 222 return llvm::count_if(results, [](const NamedTypeConstraint &c) { 223 return c.constraint.isVariableLength(); 224 }); 225 } 226 227 unsigned Operator::getNumVariableLengthOperands() const { 228 return llvm::count_if(operands, [](const NamedTypeConstraint &c) { 229 return c.constraint.isVariableLength(); 230 }); 231 } 232 233 bool Operator::hasSingleVariadicArg() const { 234 return getNumArgs() == 1 && isa<NamedTypeConstraint *>(getArg(0)) && 235 getOperand(0).isVariadic(); 236 } 237 238 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } 239 240 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } 241 242 Operator::arg_range Operator::getArgs() const { 243 return {arg_begin(), arg_end()}; 244 } 245 246 StringRef Operator::getArgName(int index) const { 247 const DagInit *argumentValues = def.getValueAsDag("arguments"); 248 return argumentValues->getArgNameStr(index); 249 } 250 251 auto Operator::getArgDecorators(int index) const -> var_decorator_range { 252 const Record *arg = 253 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); 254 if (!arg->isSubClassOf("OpVariable")) 255 return var_decorator_range(nullptr, nullptr); 256 return *arg->getValueAsListInit("decorators"); 257 } 258 259 const Trait *Operator::getTrait(StringRef trait) const { 260 for (const auto &t : traits) { 261 if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) { 262 if (traitDef->getFullyQualifiedTraitName() == trait) 263 return traitDef; 264 } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) { 265 if (traitDef->getFullyQualifiedTraitName() == trait) 266 return traitDef; 267 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) { 268 if (traitDef->getFullyQualifiedTraitName() == trait) 269 return traitDef; 270 } 271 } 272 return nullptr; 273 } 274 275 auto Operator::region_begin() const -> const_region_iterator { 276 return regions.begin(); 277 } 278 auto Operator::region_end() const -> const_region_iterator { 279 return regions.end(); 280 } 281 auto Operator::getRegions() const 282 -> llvm::iterator_range<const_region_iterator> { 283 return {region_begin(), region_end()}; 284 } 285 286 unsigned Operator::getNumRegions() const { return regions.size(); } 287 288 const NamedRegion &Operator::getRegion(unsigned index) const { 289 return regions[index]; 290 } 291 292 unsigned Operator::getNumVariadicRegions() const { 293 return llvm::count_if(regions, 294 [](const NamedRegion &c) { return c.isVariadic(); }); 295 } 296 297 auto Operator::successor_begin() const -> const_successor_iterator { 298 return successors.begin(); 299 } 300 auto Operator::successor_end() const -> const_successor_iterator { 301 return successors.end(); 302 } 303 auto Operator::getSuccessors() const 304 -> llvm::iterator_range<const_successor_iterator> { 305 return {successor_begin(), successor_end()}; 306 } 307 308 unsigned Operator::getNumSuccessors() const { return successors.size(); } 309 310 const NamedSuccessor &Operator::getSuccessor(unsigned index) const { 311 return successors[index]; 312 } 313 314 unsigned Operator::getNumVariadicSuccessors() const { 315 return llvm::count_if(successors, 316 [](const NamedSuccessor &c) { return c.isVariadic(); }); 317 } 318 319 auto Operator::trait_begin() const -> const_trait_iterator { 320 return traits.begin(); 321 } 322 auto Operator::trait_end() const -> const_trait_iterator { 323 return traits.end(); 324 } 325 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { 326 return {trait_begin(), trait_end()}; 327 } 328 329 auto Operator::attribute_begin() const -> const_attribute_iterator { 330 return attributes.begin(); 331 } 332 auto Operator::attribute_end() const -> const_attribute_iterator { 333 return attributes.end(); 334 } 335 auto Operator::getAttributes() const 336 -> llvm::iterator_range<const_attribute_iterator> { 337 return {attribute_begin(), attribute_end()}; 338 } 339 auto Operator::attribute_begin() -> attribute_iterator { 340 return attributes.begin(); 341 } 342 auto Operator::attribute_end() -> attribute_iterator { 343 return attributes.end(); 344 } 345 auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> { 346 return {attribute_begin(), attribute_end()}; 347 } 348 349 auto Operator::operand_begin() const -> const_value_iterator { 350 return operands.begin(); 351 } 352 auto Operator::operand_end() const -> const_value_iterator { 353 return operands.end(); 354 } 355 auto Operator::getOperands() const -> const_value_range { 356 return {operand_begin(), operand_end()}; 357 } 358 359 auto Operator::getArg(int index) const -> Argument { return arguments[index]; } 360 361 bool Operator::isVariadic() const { 362 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), 363 [](const NamedTypeConstraint &op) { return op.isVariadic(); }); 364 } 365 366 void Operator::populateTypeInferenceInfo( 367 const llvm::StringMap<int> &argumentsAndResultsIndex) { 368 // If the type inference op interface is not registered, then do not attempt 369 // to determine if the result types an be inferred. 370 auto &recordKeeper = def.getRecords(); 371 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); 372 allResultsHaveKnownTypes = false; 373 if (!inferTrait) 374 return; 375 376 // If there are no results, the skip this else the build method generated 377 // overlaps with another autogenerated builder. 378 if (getNumResults() == 0) 379 return; 380 381 // Skip ops with variadic or optional results. 382 if (getNumVariableLengthResults() > 0) 383 return; 384 385 // Skip cases currently being custom generated. 386 // TODO: Remove special cases. 387 if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 388 // Check for a non-variable length operand to use as the type anchor. 389 auto *operandI = llvm::find_if(arguments, [](const Argument &arg) { 390 NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg); 391 return operand && !operand->isVariableLength(); 392 }); 393 if (operandI == arguments.end()) 394 return; 395 396 // All result types are inferred from the operand type. 397 int operandIdx = operandI - arguments.begin(); 398 for (int i = 0; i < getNumResults(); ++i) 399 resultTypeMapping.emplace_back(operandIdx, "$_self"); 400 401 allResultsHaveKnownTypes = true; 402 traits.push_back(Trait::create(inferTrait->getDefInit())); 403 return; 404 } 405 406 /// This struct represents a node in this operation's result type inferenece 407 /// graph. Each node has a list of incoming type inference edges `sources`. 408 /// Each edge represents a "source" from which the result type can be 409 /// inferred, either an operand (leaf) or another result (node). When a node 410 /// is known to have a fully-inferred type, `inferred` is set to true. 411 struct ResultTypeInference { 412 /// The list of incoming type inference edges. 413 SmallVector<InferredResultType> sources; 414 /// This flag is set to true when the result type is known to be inferrable. 415 bool inferred = false; 416 }; 417 418 // This vector represents the type inference graph, with one node for each 419 // operation result. The nth element is the node for the nth result. 420 SmallVector<ResultTypeInference> inference(getNumResults(), {}); 421 422 // For all results whose types are buildable, initialize their type inference 423 // nodes with an edge to themselves. Mark those nodes are fully-inferred. 424 for (auto [idx, infer] : llvm::enumerate(inference)) { 425 if (getResult(idx).constraint.getBuilderCall()) { 426 infer.sources.emplace_back(InferredResultType::mapResultIndex(idx), 427 "$_self"); 428 infer.inferred = true; 429 } 430 } 431 432 // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the 433 // result type inference graph. 434 for (const Trait &trait : traits) { 435 const Record &def = trait.getDef(); 436 437 // If the infer type op interface was manually added, then treat it as 438 // intention that the op needs special handling. 439 // TODO: Reconsider whether to always generate, this is more conservative 440 // and keeps existing behavior so starting that way for now. 441 if (def.isSubClassOf( 442 llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) 443 return; 444 if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) 445 if (&traitDef->getDef() == inferTrait) 446 return; 447 448 // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a 449 // type transformer. 450 if (def.isSubClassOf("TypesMatchWith")) { 451 int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs")); 452 // Ignore operand type inference. 453 if (InferredResultType::isArgIndex(target)) 454 continue; 455 int resultIndex = InferredResultType::unmapResultIndex(target); 456 ResultTypeInference &infer = inference[resultIndex]; 457 // If the type of the result has already been inferred, do nothing. 458 if (infer.inferred) 459 continue; 460 int sourceIndex = 461 argumentsAndResultsIndex.lookup(def.getValueAsString("lhs")); 462 infer.sources.emplace_back(sourceIndex, 463 def.getValueAsString("transformer").str()); 464 // Locally propagate inferredness. 465 infer.inferred = 466 InferredResultType::isArgIndex(sourceIndex) || 467 inference[InferredResultType::unmapResultIndex(sourceIndex)].inferred; 468 continue; 469 } 470 471 if (!def.isSubClassOf("AllTypesMatch")) 472 continue; 473 474 auto values = def.getValueAsListOfStrings("values"); 475 // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That 476 // is, every result type has an edge from every other type. However, if any 477 // one of the values refers to an operand or a result with a fully-inferred 478 // type, we can infer all other types from that value. Try to find a 479 // fully-inferred type in the list. 480 std::optional<int> fullyInferredIndex; 481 SmallVector<int> resultIndices; 482 for (StringRef name : values) { 483 int index = argumentsAndResultsIndex.lookup(name); 484 if (InferredResultType::isResultIndex(index)) 485 resultIndices.push_back(InferredResultType::unmapResultIndex(index)); 486 if (InferredResultType::isArgIndex(index) || 487 inference[InferredResultType::unmapResultIndex(index)].inferred) 488 fullyInferredIndex = index; 489 } 490 if (fullyInferredIndex) { 491 // Make the fully-inferred type the only source for all results that 492 // aren't already inferred -- a 1 -> N fanout. 493 for (int resultIndex : resultIndices) { 494 ResultTypeInference &infer = inference[resultIndex]; 495 if (!infer.inferred) { 496 infer.sources.assign(1, {*fullyInferredIndex, "$_self"}); 497 infer.inferred = true; 498 } 499 } 500 } else { 501 // Add an edge between every result and every other type; N <-> N. 502 for (int resultIndex : resultIndices) { 503 for (int otherResultIndex : resultIndices) { 504 if (resultIndex == otherResultIndex) 505 continue; 506 inference[resultIndex].sources.emplace_back( 507 InferredResultType::unmapResultIndex(otherResultIndex), "$_self"); 508 } 509 } 510 } 511 } 512 513 // Propagate inferredness until a fixed point. 514 std::vector<ResultTypeInference *> worklist; 515 for (ResultTypeInference &infer : inference) 516 if (!infer.inferred) 517 worklist.push_back(&infer); 518 bool changed; 519 do { 520 changed = false; 521 for (auto cur = worklist.begin(); cur != worklist.end();) { 522 ResultTypeInference &infer = **cur; 523 524 InferredResultType *iter = 525 llvm::find_if(infer.sources, [&](const InferredResultType &source) { 526 assert(InferredResultType::isResultIndex(source.getIndex())); 527 return inference[InferredResultType::unmapResultIndex( 528 source.getIndex())] 529 .inferred; 530 }); 531 if (iter == infer.sources.end()) { 532 ++cur; 533 continue; 534 } 535 536 changed = true; 537 infer.inferred = true; 538 // Make this the only source for the result. This breaks any cycles. 539 infer.sources.assign(1, *iter); 540 cur = worklist.erase(cur); 541 } 542 } while (changed); 543 544 allResultsHaveKnownTypes = worklist.empty(); 545 546 // If the types could be computed, then add type inference trait. 547 if (allResultsHaveKnownTypes) { 548 traits.push_back(Trait::create(inferTrait->getDefInit())); 549 for (const ResultTypeInference &infer : inference) 550 resultTypeMapping.push_back(infer.sources.front()); 551 } 552 } 553 554 void Operator::populateOpStructure() { 555 auto &recordKeeper = def.getRecords(); 556 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 557 auto *attrClass = recordKeeper.getClass("Attr"); 558 auto *propertyClass = recordKeeper.getClass("Property"); 559 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 560 auto *opVarClass = recordKeeper.getClass("OpVariable"); 561 numNativeAttributes = 0; 562 563 const DagInit *argumentValues = def.getValueAsDag("arguments"); 564 unsigned numArgs = argumentValues->getNumArgs(); 565 566 // Mapping from name of to argument or result index. Arguments are indexed 567 // to match getArg index, while the results are negatively indexed. 568 llvm::StringMap<int> argumentsAndResultsIndex; 569 570 // Handle operands and native attributes. 571 for (unsigned i = 0; i != numArgs; ++i) { 572 auto *arg = argumentValues->getArg(i); 573 auto givenName = argumentValues->getArgNameStr(i); 574 auto *argDefInit = dyn_cast<DefInit>(arg); 575 if (!argDefInit) 576 PrintFatalError(def.getLoc(), 577 Twine("undefined type for argument #") + Twine(i)); 578 const Record *argDef = argDefInit->getDef(); 579 if (argDef->isSubClassOf(opVarClass)) 580 argDef = argDef->getValueAsDef("constraint"); 581 582 if (argDef->isSubClassOf(typeConstraintClass)) { 583 operands.push_back( 584 NamedTypeConstraint{givenName, TypeConstraint(argDef)}); 585 } else if (argDef->isSubClassOf(attrClass)) { 586 if (givenName.empty()) 587 PrintFatalError(argDef->getLoc(), "attributes must be named"); 588 if (argDef->isSubClassOf(derivedAttrClass)) 589 PrintFatalError(argDef->getLoc(), 590 "derived attributes not allowed in argument list"); 591 attributes.push_back({givenName, Attribute(argDef)}); 592 ++numNativeAttributes; 593 } else if (argDef->isSubClassOf(propertyClass)) { 594 if (givenName.empty()) 595 PrintFatalError(argDef->getLoc(), "properties must be named"); 596 properties.push_back({givenName, Property(argDef)}); 597 } else { 598 PrintFatalError(def.getLoc(), 599 "unexpected def type; only defs deriving " 600 "from TypeConstraint or Attr or Property are allowed"); 601 } 602 if (!givenName.empty()) 603 argumentsAndResultsIndex[givenName] = i; 604 } 605 606 // Handle derived attributes. 607 for (const auto &val : def.getValues()) { 608 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 609 if (!record->isSubClassOf(attrClass)) 610 continue; 611 if (!record->isSubClassOf(derivedAttrClass)) 612 PrintFatalError(def.getLoc(), 613 "unexpected Attr where only DerivedAttr is allowed"); 614 615 if (record->getClasses().size() != 1) { 616 PrintFatalError( 617 def.getLoc(), 618 "unsupported attribute modelling, only single class expected"); 619 } 620 attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(), 621 Attribute(cast<DefInit>(val.getValue()))}); 622 } 623 } 624 625 // Populate `arguments`. This must happen after we've finalized `operands` and 626 // `attributes` because we will put their elements' pointers in `arguments`. 627 // SmallVector may perform re-allocation under the hood when adding new 628 // elements. 629 int operandIndex = 0, attrIndex = 0, propIndex = 0; 630 for (unsigned i = 0; i != numArgs; ++i) { 631 const Record *argDef = 632 dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 633 if (argDef->isSubClassOf(opVarClass)) 634 argDef = argDef->getValueAsDef("constraint"); 635 636 if (argDef->isSubClassOf(typeConstraintClass)) { 637 attrOrOperandMapping.push_back( 638 {OperandOrAttribute::Kind::Operand, operandIndex}); 639 arguments.emplace_back(&operands[operandIndex++]); 640 } else if (argDef->isSubClassOf(attrClass)) { 641 attrOrOperandMapping.push_back( 642 {OperandOrAttribute::Kind::Attribute, attrIndex}); 643 arguments.emplace_back(&attributes[attrIndex++]); 644 } else { 645 assert(argDef->isSubClassOf(propertyClass)); 646 arguments.emplace_back(&properties[propIndex++]); 647 } 648 } 649 650 auto *resultsDag = def.getValueAsDag("results"); 651 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 652 if (!outsOp || outsOp->getDef()->getName() != "outs") { 653 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 654 } 655 656 // Handle results. 657 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 658 auto name = resultsDag->getArgNameStr(i); 659 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i)); 660 if (!resultInit) { 661 PrintFatalError(def.getLoc(), 662 Twine("undefined type for result #") + Twine(i)); 663 } 664 auto *resultDef = resultInit->getDef(); 665 if (resultDef->isSubClassOf(opVarClass)) 666 resultDef = resultDef->getValueAsDef("constraint"); 667 results.push_back({name, TypeConstraint(resultDef)}); 668 if (!name.empty()) 669 argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i); 670 671 // We currently only support VariadicOfVariadic operands. 672 if (results.back().constraint.isVariadicOfVariadic()) { 673 PrintFatalError( 674 def.getLoc(), 675 "'VariadicOfVariadic' results are currently not supported"); 676 } 677 } 678 679 // Handle successors 680 auto *successorsDag = def.getValueAsDag("successors"); 681 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator()); 682 if (!successorsOp || successorsOp->getDef()->getName() != "successor") { 683 PrintFatalError(def.getLoc(), 684 "'successors' must have 'successor' directive"); 685 } 686 687 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { 688 auto name = successorsDag->getArgNameStr(i); 689 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i)); 690 if (!successorInit) { 691 PrintFatalError(def.getLoc(), 692 Twine("undefined kind for successor #") + Twine(i)); 693 } 694 Successor successor(successorInit->getDef()); 695 696 // Only support variadic successors if it is the last one for now. 697 if (i != e - 1 && successor.isVariadic()) 698 PrintFatalError(def.getLoc(), "only the last successor can be variadic"); 699 successors.push_back({name, successor}); 700 } 701 702 // Create list of traits, skipping over duplicates: appending to lists in 703 // tablegen is easy, making them unique less so, so dedupe here. 704 if (auto *traitList = def.getValueAsListInit("traits")) { 705 // This is uniquing based on pointers of the trait. 706 SmallPtrSet<const Init *, 32> traitSet; 707 traits.reserve(traitSet.size()); 708 709 // The declaration order of traits imply the verification order of traits. 710 // Some traits may require other traits to be verified first then they can 711 // do further verification based on those verified facts. If you see this 712 // error, fix the traits declaration order by checking the `dependentTraits` 713 // field. 714 auto verifyTraitValidity = [&](const Record *trait) { 715 auto *dependentTraits = trait->getValueAsListInit("dependentTraits"); 716 for (auto *traitInit : *dependentTraits) 717 if (!traitSet.contains(traitInit)) 718 PrintFatalError( 719 def.getLoc(), 720 trait->getValueAsString("trait") + " requires " + 721 cast<DefInit>(traitInit)->getDef()->getValueAsString( 722 "trait") + 723 " to precede it in traits list"); 724 }; 725 726 std::function<void(const ListInit *)> insert; 727 insert = [&](const ListInit *traitList) { 728 for (auto *traitInit : *traitList) { 729 auto *def = cast<DefInit>(traitInit)->getDef(); 730 if (def->isSubClassOf("TraitList")) { 731 insert(def->getValueAsListInit("traits")); 732 continue; 733 } 734 735 // Ignore duplicates. 736 if (!traitSet.insert(traitInit).second) 737 continue; 738 739 // If this is an interface with base classes, add the bases to the 740 // trait list. 741 if (def->isSubClassOf("Interface")) 742 insert(def->getValueAsListInit("baseInterfaces")); 743 744 // Verify if the trait has all the dependent traits declared before 745 // itself. 746 verifyTraitValidity(def); 747 traits.push_back(Trait::create(traitInit)); 748 } 749 }; 750 insert(traitList); 751 } 752 753 populateTypeInferenceInfo(argumentsAndResultsIndex); 754 755 // Handle regions 756 auto *regionsDag = def.getValueAsDag("regions"); 757 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 758 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 759 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 760 } 761 762 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 763 auto name = regionsDag->getArgNameStr(i); 764 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 765 if (!regionInit) { 766 PrintFatalError(def.getLoc(), 767 Twine("undefined kind for region #") + Twine(i)); 768 } 769 Region region(regionInit->getDef()); 770 if (region.isVariadic()) { 771 // Only support variadic regions if it is the last one for now. 772 if (i != e - 1) 773 PrintFatalError(def.getLoc(), "only the last region can be variadic"); 774 if (name.empty()) 775 PrintFatalError(def.getLoc(), "variadic regions must be named"); 776 } 777 778 regions.push_back({name, region}); 779 } 780 781 // Populate the builders. 782 auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit("builders")); 783 if (builderList && !builderList->empty()) { 784 for (const Init *init : builderList->getValues()) 785 builders.emplace_back(cast<DefInit>(init)->getDef(), def.getLoc()); 786 } else if (skipDefaultBuilders()) { 787 PrintFatalError( 788 def.getLoc(), 789 "default builders are skipped and no custom builders provided"); 790 } 791 792 LLVM_DEBUG(print(llvm::dbgs())); 793 } 794 795 const InferredResultType &Operator::getInferredResultType(int index) const { 796 assert(allResultTypesKnown()); 797 return resultTypeMapping[index]; 798 } 799 800 ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); } 801 802 bool Operator::hasDescription() const { 803 return !getDescription().trim().empty(); 804 } 805 806 StringRef Operator::getDescription() const { 807 return def.getValueAsString("description"); 808 } 809 810 bool Operator::hasSummary() const { return !getSummary().trim().empty(); } 811 812 StringRef Operator::getSummary() const { 813 return def.getValueAsString("summary"); 814 } 815 816 bool Operator::hasAssemblyFormat() const { 817 auto *valueInit = def.getValueInit("assemblyFormat"); 818 return isa<StringInit>(valueInit); 819 } 820 821 StringRef Operator::getAssemblyFormat() const { 822 return TypeSwitch<const Init *, StringRef>(def.getValueInit("assemblyFormat")) 823 .Case<StringInit>([&](auto *init) { return init->getValue(); }); 824 } 825 826 void Operator::print(llvm::raw_ostream &os) const { 827 os << "op '" << getOperationName() << "'\n"; 828 for (Argument arg : arguments) { 829 if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg)) 830 os << "[attribute] " << attr->name << '\n'; 831 else 832 os << "[operand] " << cast<NamedTypeConstraint *>(arg)->name << '\n'; 833 } 834 } 835 836 auto Operator::VariableDecoratorIterator::unwrap(const Init *init) 837 -> VariableDecorator { 838 return VariableDecorator(cast<DefInit>(init)->getDef()); 839 } 840 841 auto Operator::getArgToOperandOrAttribute(int index) const 842 -> OperandOrAttribute { 843 return attrOrOperandMapping[index]; 844 } 845 846 std::string Operator::getGetterName(StringRef name) const { 847 return "get" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 848 } 849 850 std::string Operator::getSetterName(StringRef name) const { 851 return "set" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 852 } 853 854 std::string Operator::getRemoverName(StringRef name) const { 855 return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); 856 } 857 858 bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); } 859 860 bool Operator::useCustomPropertiesEncoding() const { 861 return def.getValueAsBit("useCustomPropertiesEncoding"); 862 } 863