1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "mlir/TableGen/Pattern.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/FormatVariadic.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 24 #define DEBUG_TYPE "mlir-tblgen-pattern" 25 26 using namespace mlir; 27 using namespace tblgen; 28 29 using llvm::DagInit; 30 using llvm::dbgs; 31 using llvm::DefInit; 32 using llvm::formatv; 33 using llvm::IntInit; 34 using llvm::Record; 35 36 //===----------------------------------------------------------------------===// 37 // DagLeaf 38 //===----------------------------------------------------------------------===// 39 40 bool DagLeaf::isUnspecified() const { 41 return isa_and_nonnull<llvm::UnsetInit>(def); 42 } 43 44 bool DagLeaf::isOperandMatcher() const { 45 // Operand matchers specify a type constraint. 46 return isSubClassOf("TypeConstraint"); 47 } 48 49 bool DagLeaf::isAttrMatcher() const { 50 // Attribute matchers specify an attribute constraint. 51 return isSubClassOf("AttrConstraint"); 52 } 53 54 bool DagLeaf::isNativeCodeCall() const { 55 return isSubClassOf("NativeCodeCall"); 56 } 57 58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 59 60 bool DagLeaf::isEnumAttrCase() const { 61 return isSubClassOf("EnumAttrCaseInfo"); 62 } 63 64 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); } 65 66 Constraint DagLeaf::getAsConstraint() const { 67 assert((isOperandMatcher() || isAttrMatcher()) && 68 "the DAG leaf must be operand or attribute"); 69 return Constraint(cast<DefInit>(def)->getDef()); 70 } 71 72 ConstantAttr DagLeaf::getAsConstantAttr() const { 73 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 74 return ConstantAttr(cast<DefInit>(def)); 75 } 76 77 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 78 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 79 return EnumAttrCase(cast<DefInit>(def)); 80 } 81 82 std::string DagLeaf::getConditionTemplate() const { 83 return getAsConstraint().getConditionTemplate(); 84 } 85 86 StringRef DagLeaf::getNativeCodeTemplate() const { 87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 88 return cast<DefInit>(def)->getDef()->getValueAsString("expression"); 89 } 90 91 int DagLeaf::getNumReturnsOfNativeCode() const { 92 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 93 return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns"); 94 } 95 96 std::string DagLeaf::getStringAttr() const { 97 assert(isStringAttr() && "the DAG leaf must be string attribute"); 98 return def->getAsUnquotedString(); 99 } 100 bool DagLeaf::isSubClassOf(StringRef superclass) const { 101 if (auto *defInit = dyn_cast_or_null<DefInit>(def)) 102 return defInit->getDef()->isSubClassOf(superclass); 103 return false; 104 } 105 106 void DagLeaf::print(raw_ostream &os) const { 107 if (def) 108 def->print(os); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // DagNode 113 //===----------------------------------------------------------------------===// 114 115 bool DagNode::isNativeCodeCall() const { 116 if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator())) 117 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 118 return false; 119 } 120 121 bool DagNode::isOperation() const { 122 return !isNativeCodeCall() && !isReplaceWithValue() && 123 !isLocationDirective() && !isReturnTypeDirective() && !isEither() && 124 !isVariadic(); 125 } 126 127 StringRef DagNode::getNativeCodeTemplate() const { 128 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 129 return cast<DefInit>(node->getOperator()) 130 ->getDef() 131 ->getValueAsString("expression"); 132 } 133 134 int DagNode::getNumReturnsOfNativeCode() const { 135 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 136 return cast<DefInit>(node->getOperator()) 137 ->getDef() 138 ->getValueAsInt("numReturns"); 139 } 140 141 StringRef DagNode::getSymbol() const { return node->getNameStr(); } 142 143 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 144 const Record *opDef = cast<DefInit>(node->getOperator())->getDef(); 145 auto [it, inserted] = mapper->try_emplace(opDef); 146 if (inserted) 147 it->second = std::make_unique<Operator>(opDef); 148 return *it->second; 149 } 150 151 int DagNode::getNumOps() const { 152 // We want to get number of operations recursively involved in the DAG tree. 153 // All other directives should be excluded. 154 int count = isOperation() ? 1 : 0; 155 for (int i = 0, e = getNumArgs(); i != e; ++i) { 156 if (auto child = getArgAsNestedDag(i)) 157 count += child.getNumOps(); 158 } 159 return count; 160 } 161 162 int DagNode::getNumArgs() const { return node->getNumArgs(); } 163 164 bool DagNode::isNestedDagArg(unsigned index) const { 165 return isa<DagInit>(node->getArg(index)); 166 } 167 168 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 169 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index))); 170 } 171 172 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 173 assert(!isNestedDagArg(index)); 174 return DagLeaf(node->getArg(index)); 175 } 176 177 StringRef DagNode::getArgName(unsigned index) const { 178 return node->getArgNameStr(index); 179 } 180 181 bool DagNode::isReplaceWithValue() const { 182 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef(); 183 return dagOpDef->getName() == "replaceWithValue"; 184 } 185 186 bool DagNode::isLocationDirective() const { 187 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef(); 188 return dagOpDef->getName() == "location"; 189 } 190 191 bool DagNode::isReturnTypeDirective() const { 192 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef(); 193 return dagOpDef->getName() == "returnType"; 194 } 195 196 bool DagNode::isEither() const { 197 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef(); 198 return dagOpDef->getName() == "either"; 199 } 200 201 bool DagNode::isVariadic() const { 202 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef(); 203 return dagOpDef->getName() == "variadic"; 204 } 205 206 void DagNode::print(raw_ostream &os) const { 207 if (node) 208 node->print(os); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // SymbolInfoMap 213 //===----------------------------------------------------------------------===// 214 215 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 216 int idx = -1; 217 auto [name, indexStr] = symbol.rsplit("__"); 218 219 if (indexStr.consumeInteger(10, idx)) { 220 // The second part is not an index; we return the whole symbol as-is. 221 return symbol; 222 } 223 if (index) { 224 *index = idx; 225 } 226 return name; 227 } 228 229 SymbolInfoMap::SymbolInfo::SymbolInfo( 230 const Operator *op, SymbolInfo::Kind kind, 231 std::optional<DagAndConstant> dagAndConstant) 232 : op(op), kind(kind), dagAndConstant(dagAndConstant) {} 233 234 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 235 switch (kind) { 236 case Kind::Attr: 237 case Kind::Operand: 238 case Kind::Value: 239 return 1; 240 case Kind::Result: 241 return op->getNumResults(); 242 case Kind::MultipleValues: 243 return getSize(); 244 } 245 llvm_unreachable("unknown kind"); 246 } 247 248 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 249 return alternativeName ? *alternativeName : name.str(); 250 } 251 252 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { 253 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': "); 254 switch (kind) { 255 case Kind::Attr: { 256 if (op) 257 return cast<NamedAttribute *>(op->getArg(getArgIndex())) 258 ->attr.getStorageType() 259 .str(); 260 // TODO(suderman): Use a more exact type when available. 261 return "::mlir::Attribute"; 262 } 263 case Kind::Operand: { 264 // Use operand range for captured operands (to support potential variadic 265 // operands). 266 return "::mlir::Operation::operand_range"; 267 } 268 case Kind::Value: { 269 return "::mlir::Value"; 270 } 271 case Kind::MultipleValues: { 272 return "::mlir::ValueRange"; 273 } 274 case Kind::Result: { 275 // Use the op itself for captured results. 276 return op->getQualCppClassName(); 277 } 278 } 279 llvm_unreachable("unknown kind"); 280 } 281 282 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 283 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': "); 284 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; 285 return std::string( 286 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); 287 } 288 289 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { 290 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': "); 291 return std::string( 292 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); 293 } 294 295 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 296 StringRef name, int index, const char *fmt, const char *separator) const { 297 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': "); 298 switch (kind) { 299 case Kind::Attr: { 300 assert(index < 0); 301 auto repl = formatv(fmt, name); 302 LLVM_DEBUG(dbgs() << repl << " (Attr)\n"); 303 return std::string(repl); 304 } 305 case Kind::Operand: { 306 assert(index < 0); 307 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex())); 308 // If this operand is variadic and this SymbolInfo doesn't have a range 309 // index, then return the full variadic operand_range. Otherwise, return 310 // the value itself. 311 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { 312 auto repl = formatv(fmt, name); 313 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n"); 314 return std::string(repl); 315 } 316 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 317 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n"); 318 return std::string(repl); 319 } 320 case Kind::Result: { 321 // If `index` is greater than zero, then we are referencing a specific 322 // result of a multi-result op. The result can still be variadic. 323 if (index >= 0) { 324 std::string v = 325 std::string(formatv("{0}.getODSResults({1})", name, index)); 326 if (!op->getResult(index).isVariadic()) 327 v = std::string(formatv("(*{0}.begin())", v)); 328 auto repl = formatv(fmt, v); 329 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); 330 return std::string(repl); 331 } 332 333 // If this op has no result at all but still we bind a symbol to it, it 334 // means we want to capture the op itself. 335 if (op->getNumResults() == 0) { 336 LLVM_DEBUG(dbgs() << name << " (Op)\n"); 337 return formatv(fmt, name); 338 } 339 340 // We are referencing all results of the multi-result op. A specific result 341 // can either be a value or a range. Then join them with `separator`. 342 SmallVector<std::string, 4> values; 343 values.reserve(op->getNumResults()); 344 345 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 346 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 347 if (!op->getResult(i).isVariadic()) { 348 v = std::string(formatv("(*{0}.begin())", v)); 349 } 350 values.push_back(std::string(formatv(fmt, v))); 351 } 352 auto repl = llvm::join(values, separator); 353 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); 354 return repl; 355 } 356 case Kind::Value: { 357 assert(index < 0); 358 assert(op == nullptr); 359 auto repl = formatv(fmt, name); 360 LLVM_DEBUG(dbgs() << repl << " (Value)\n"); 361 return std::string(repl); 362 } 363 case Kind::MultipleValues: { 364 assert(op == nullptr); 365 assert(index < getSize()); 366 if (index >= 0) { 367 std::string repl = 368 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 369 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); 370 return repl; 371 } 372 // If it doesn't specify certain element, unpack them all. 373 auto repl = 374 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 375 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); 376 return std::string(repl); 377 } 378 } 379 llvm_unreachable("unknown kind"); 380 } 381 382 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 383 StringRef name, int index, const char *fmt, const char *separator) const { 384 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': "); 385 switch (kind) { 386 case Kind::Attr: 387 case Kind::Operand: { 388 assert(index < 0 && "only allowed for symbol bound to result"); 389 auto repl = formatv(fmt, name); 390 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n"); 391 return std::string(repl); 392 } 393 case Kind::Result: { 394 if (index >= 0) { 395 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 396 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); 397 return std::string(repl); 398 } 399 400 // We are referencing all results of the multi-result op. Each result should 401 // have a value range, and then join them with `separator`. 402 SmallVector<std::string, 4> values; 403 values.reserve(op->getNumResults()); 404 405 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 406 values.push_back(std::string( 407 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 408 } 409 auto repl = llvm::join(values, separator); 410 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); 411 return repl; 412 } 413 case Kind::Value: { 414 assert(index < 0 && "only allowed for symbol bound to result"); 415 assert(op == nullptr); 416 auto repl = formatv(fmt, formatv("{{{0}}", name)); 417 LLVM_DEBUG(dbgs() << repl << " (Value)\n"); 418 return std::string(repl); 419 } 420 case Kind::MultipleValues: { 421 assert(op == nullptr); 422 assert(index < getSize()); 423 if (index >= 0) { 424 std::string repl = 425 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 426 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); 427 return repl; 428 } 429 auto repl = 430 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 431 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); 432 return std::string(repl); 433 } 434 } 435 llvm_unreachable("unknown kind"); 436 } 437 438 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 439 const Operator &op, int argIndex, 440 std::optional<int> variadicSubIndex) { 441 StringRef name = getValuePackName(symbol); 442 if (name != symbol) { 443 auto error = formatv( 444 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 445 PrintFatalError(loc, error); 446 } 447 448 auto symInfo = 449 isa<NamedAttribute *>(op.getArg(argIndex)) 450 ? SymbolInfo::getAttr(&op, argIndex) 451 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex); 452 453 std::string key = symbol.str(); 454 if (symbolInfoMap.count(key)) { 455 // Only non unique name for the operand is supported. 456 if (symInfo.kind != SymbolInfo::Kind::Operand) { 457 return false; 458 } 459 460 // Cannot add new operand if there is already non operand with the same 461 // name. 462 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 463 return false; 464 } 465 } 466 467 symbolInfoMap.emplace(key, symInfo); 468 return true; 469 } 470 471 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 472 std::string name = getValuePackName(symbol).str(); 473 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 474 475 return symbolInfoMap.count(inserted->first) == 1; 476 } 477 478 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 479 std::string name = getValuePackName(symbol).str(); 480 if (numValues > 1) 481 return bindMultipleValues(name, numValues); 482 return bindValue(name); 483 } 484 485 bool SymbolInfoMap::bindValue(StringRef symbol) { 486 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 487 return symbolInfoMap.count(inserted->first) == 1; 488 } 489 490 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 491 std::string name = getValuePackName(symbol).str(); 492 auto inserted = 493 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 494 return symbolInfoMap.count(inserted->first) == 1; 495 } 496 497 bool SymbolInfoMap::bindAttr(StringRef symbol) { 498 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 499 return symbolInfoMap.count(inserted->first) == 1; 500 } 501 502 bool SymbolInfoMap::contains(StringRef symbol) const { 503 return find(symbol) != symbolInfoMap.end(); 504 } 505 506 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 507 std::string name = getValuePackName(key).str(); 508 509 return symbolInfoMap.find(name); 510 } 511 512 SymbolInfoMap::const_iterator 513 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 514 int argIndex, 515 std::optional<int> variadicSubIndex) const { 516 return findBoundSymbol( 517 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex)); 518 } 519 520 SymbolInfoMap::const_iterator 521 SymbolInfoMap::findBoundSymbol(StringRef key, 522 const SymbolInfo &symbolInfo) const { 523 std::string name = getValuePackName(key).str(); 524 auto range = symbolInfoMap.equal_range(name); 525 526 for (auto it = range.first; it != range.second; ++it) 527 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 528 return it; 529 530 return symbolInfoMap.end(); 531 } 532 533 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 534 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 535 std::string name = getValuePackName(key).str(); 536 537 return symbolInfoMap.equal_range(name); 538 } 539 540 int SymbolInfoMap::count(StringRef key) const { 541 std::string name = getValuePackName(key).str(); 542 return symbolInfoMap.count(name); 543 } 544 545 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 546 StringRef name = getValuePackName(symbol); 547 if (name != symbol) { 548 // If there is a trailing index inside symbol, it references just one 549 // static value. 550 return 1; 551 } 552 // Otherwise, find how many it represents by querying the symbol's info. 553 return find(name)->second.getStaticValueCount(); 554 } 555 556 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 557 const char *fmt, 558 const char *separator) const { 559 int index = -1; 560 StringRef name = getValuePackName(symbol, &index); 561 562 auto it = symbolInfoMap.find(name.str()); 563 if (it == symbolInfoMap.end()) { 564 auto error = formatv("referencing unbound symbol '{0}'", symbol); 565 PrintFatalError(loc, error); 566 } 567 568 return it->second.getValueAndRangeUse(name, index, fmt, separator); 569 } 570 571 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 572 const char *separator) const { 573 int index = -1; 574 StringRef name = getValuePackName(symbol, &index); 575 576 auto it = symbolInfoMap.find(name.str()); 577 if (it == symbolInfoMap.end()) { 578 auto error = formatv("referencing unbound symbol '{0}'", symbol); 579 PrintFatalError(loc, error); 580 } 581 582 return it->second.getAllRangeUse(name, index, fmt, separator); 583 } 584 585 void SymbolInfoMap::assignUniqueAlternativeNames() { 586 llvm::StringSet<> usedNames; 587 588 for (auto symbolInfoIt = symbolInfoMap.begin(); 589 symbolInfoIt != symbolInfoMap.end();) { 590 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 591 auto startRange = range.first; 592 auto endRange = range.second; 593 594 auto operandName = symbolInfoIt->first; 595 int startSearchIndex = 0; 596 for (++startRange; startRange != endRange; ++startRange) { 597 // Current operand name is not unique, find a unique one 598 // and set the alternative name. 599 for (int i = startSearchIndex;; ++i) { 600 std::string alternativeName = operandName + std::to_string(i); 601 if (!usedNames.contains(alternativeName) && 602 symbolInfoMap.count(alternativeName) == 0) { 603 usedNames.insert(alternativeName); 604 startRange->second.alternativeName = alternativeName; 605 startSearchIndex = i + 1; 606 607 break; 608 } 609 } 610 } 611 612 symbolInfoIt = endRange; 613 } 614 } 615 616 //===----------------------------------------------------------------------===// 617 // Pattern 618 //==----------------------------------------------------------------------===// 619 620 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper) 621 : def(*def), recordOpMap(mapper) {} 622 623 DagNode Pattern::getSourcePattern() const { 624 return DagNode(def.getValueAsDag("sourcePattern")); 625 } 626 627 int Pattern::getNumResultPatterns() const { 628 auto *results = def.getValueAsListInit("resultPatterns"); 629 return results->size(); 630 } 631 632 DagNode Pattern::getResultPattern(unsigned index) const { 633 auto *results = def.getValueAsListInit("resultPatterns"); 634 return DagNode(cast<DagInit>(results->getElement(index))); 635 } 636 637 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 638 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n"); 639 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 640 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n"); 641 642 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n"); 643 infoMap.assignUniqueAlternativeNames(); 644 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n"); 645 } 646 647 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 648 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n"); 649 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 650 auto pattern = getResultPattern(i); 651 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 652 } 653 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n"); 654 } 655 656 const Operator &Pattern::getSourceRootOp() { 657 return getSourcePattern().getDialectOp(recordOpMap); 658 } 659 660 Operator &Pattern::getDialectOp(DagNode node) { 661 return node.getDialectOp(recordOpMap); 662 } 663 664 std::vector<AppliedConstraint> Pattern::getConstraints() const { 665 auto *listInit = def.getValueAsListInit("constraints"); 666 std::vector<AppliedConstraint> ret; 667 ret.reserve(listInit->size()); 668 669 for (auto *it : *listInit) { 670 auto *dagInit = dyn_cast<DagInit>(it); 671 if (!dagInit) 672 PrintFatalError(&def, "all elements in Pattern multi-entity " 673 "constraints should be DAG nodes"); 674 675 std::vector<std::string> entities; 676 entities.reserve(dagInit->arg_size()); 677 for (auto *argName : dagInit->getArgNames()) { 678 if (!argName) { 679 PrintFatalError( 680 &def, 681 "operands to additional constraints can only be symbol references"); 682 } 683 entities.emplace_back(argName->getValue()); 684 } 685 686 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(), 687 dagInit->getNameStr(), std::move(entities)); 688 } 689 return ret; 690 } 691 692 int Pattern::getNumSupplementalPatterns() const { 693 auto *results = def.getValueAsListInit("supplementalPatterns"); 694 return results->size(); 695 } 696 697 DagNode Pattern::getSupplementalPattern(unsigned index) const { 698 auto *results = def.getValueAsListInit("supplementalPatterns"); 699 return DagNode(cast<DagInit>(results->getElement(index))); 700 } 701 702 int Pattern::getBenefit() const { 703 // The initial benefit value is a heuristic with number of ops in the source 704 // pattern. 705 int initBenefit = getSourcePattern().getNumOps(); 706 const DagInit *delta = def.getValueAsDag("benefitDelta"); 707 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) { 708 PrintFatalError(&def, 709 "The 'addBenefit' takes and only takes one integer value"); 710 } 711 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue(); 712 } 713 714 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 715 std::vector<std::pair<StringRef, unsigned>> result; 716 result.reserve(def.getLoc().size()); 717 for (auto loc : def.getLoc()) { 718 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 719 assert(buf && "invalid source location"); 720 result.emplace_back( 721 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 722 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 723 } 724 return result; 725 } 726 727 void Pattern::verifyBind(bool result, StringRef symbolName) { 728 if (!result) { 729 auto err = formatv("symbol '{0}' bound more than once", symbolName); 730 PrintFatalError(&def, err); 731 } 732 } 733 734 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 735 bool isSrcPattern) { 736 auto treeName = tree.getSymbol(); 737 auto numTreeArgs = tree.getNumArgs(); 738 739 if (tree.isNativeCodeCall()) { 740 if (!treeName.empty()) { 741 if (!isSrcPattern) { 742 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: " 743 << treeName << '\n'); 744 verifyBind( 745 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 746 treeName); 747 } else { 748 PrintFatalError(&def, 749 formatv("binding symbol '{0}' to NativecodeCall in " 750 "MatchPattern is not supported", 751 treeName)); 752 } 753 } 754 755 for (int i = 0; i != numTreeArgs; ++i) { 756 if (auto treeArg = tree.getArgAsNestedDag(i)) { 757 // This DAG node argument is a DAG node itself. Go inside recursively. 758 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 759 continue; 760 } 761 762 if (!isSrcPattern) 763 continue; 764 765 // We can only bind symbols to arguments in source pattern. Those 766 // symbols are referenced in result patterns. 767 auto treeArgName = tree.getArgName(i); 768 769 // `$_` is a special symbol meaning ignore the current argument. 770 if (!treeArgName.empty() && treeArgName != "_") { 771 DagLeaf leaf = tree.getArgAsLeaf(i); 772 773 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 774 if (leaf.isUnspecified()) { 775 // This is case of $c, a Value without any constraints. 776 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 777 } else { 778 auto constraint = leaf.getAsConstraint(); 779 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 780 leaf.isConstantAttr() || 781 constraint.getKind() == Constraint::Kind::CK_Attr; 782 783 if (isAttr) { 784 // This is case of $a, a binding to a certain attribute. 785 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 786 continue; 787 } 788 789 // This is case of $b, a binding to a certain type. 790 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 791 } 792 } 793 } 794 795 return; 796 } 797 798 if (tree.isOperation()) { 799 auto &op = getDialectOp(tree); 800 auto numOpArgs = op.getNumArgs(); 801 int numEither = 0; 802 803 // We need to exclude the trailing directives and `either` directive groups 804 // two operands of the operation. 805 int numDirectives = 0; 806 for (int i = numTreeArgs - 1; i >= 0; --i) { 807 if (auto dagArg = tree.getArgAsNestedDag(i)) { 808 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) 809 ++numDirectives; 810 else if (dagArg.isEither()) 811 ++numEither; 812 } 813 } 814 815 if (numOpArgs != numTreeArgs - numDirectives + numEither) { 816 auto err = 817 formatv("op '{0}' argument number mismatch: " 818 "{1} in pattern vs. {2} in definition", 819 op.getOperationName(), numTreeArgs + numEither, numOpArgs); 820 PrintFatalError(&def, err); 821 } 822 823 // The name attached to the DAG node's operator is for representing the 824 // results generated from this op. It should be remembered as bound results. 825 if (!treeName.empty()) { 826 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName 827 << '\n'); 828 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 829 } 830 831 // The operand in `either` DAG should be bound to the operation in the 832 // parent DagNode. 833 auto collectSymbolInEither = [&](DagNode parent, DagNode tree, 834 int opArgIdx) { 835 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { 836 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 837 collectBoundSymbols(subTree, infoMap, isSrcPattern); 838 } else { 839 auto argName = tree.getArgName(i); 840 if (!argName.empty() && argName != "_") { 841 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx), 842 argName); 843 } 844 } 845 } 846 }; 847 848 // The operand in `variadic` DAG should be bound to the operation in the 849 // parent DagNode. The range index must be included as well to distinguish 850 // (potentially) repeating argName within the `variadic` DAG. 851 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, 852 int opArgIdx) { 853 auto treeName = tree.getSymbol(); 854 if (!treeName.empty()) { 855 // If treeName is specified, bind to the full variadic operand_range. 856 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx, 857 std::nullopt), 858 treeName); 859 } 860 861 for (int i = 0; i < tree.getNumArgs(); ++i) { 862 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 863 collectBoundSymbols(subTree, infoMap, isSrcPattern); 864 } else { 865 auto argName = tree.getArgName(i); 866 if (!argName.empty() && argName != "_") { 867 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx, 868 /*variadicSubIndex=*/i), 869 argName); 870 } 871 } 872 } 873 }; 874 875 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { 876 if (auto treeArg = tree.getArgAsNestedDag(i)) { 877 if (treeArg.isEither()) { 878 collectSymbolInEither(tree, treeArg, opArgIdx); 879 // `either` DAG is *flattened*. For example, 880 // 881 // (FooOp (either arg0, arg1), arg2) 882 // 883 // can be viewed as: 884 // 885 // (FooOp arg0, arg1, arg2) 886 ++opArgIdx; 887 } else if (treeArg.isVariadic()) { 888 collectSymbolInVariadic(tree, treeArg, opArgIdx); 889 } else { 890 // This DAG node argument is a DAG node itself. Go inside recursively. 891 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 892 } 893 continue; 894 } 895 896 if (isSrcPattern) { 897 // We can only bind symbols to op arguments in source pattern. Those 898 // symbols are referenced in result patterns. 899 auto treeArgName = tree.getArgName(i); 900 // `$_` is a special symbol meaning ignore the current argument. 901 if (!treeArgName.empty() && treeArgName != "_") { 902 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: " 903 << treeArgName << '\n'); 904 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx), 905 treeArgName); 906 } 907 } 908 } 909 return; 910 } 911 912 if (!treeName.empty()) { 913 PrintFatalError( 914 &def, formatv("binding symbol '{0}' to non-operation/native code call " 915 "unsupported right now", 916 treeName)); 917 } 918 } 919