1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "mlir/TableGen/Pattern.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/FormatVariadic.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 24 #define DEBUG_TYPE "mlir-tblgen-pattern" 25 26 using namespace mlir; 27 using namespace tblgen; 28 29 using llvm::formatv; 30 31 //===----------------------------------------------------------------------===// 32 // DagLeaf 33 //===----------------------------------------------------------------------===// 34 35 bool DagLeaf::isUnspecified() const { 36 return isa_and_nonnull<llvm::UnsetInit>(def); 37 } 38 39 bool DagLeaf::isOperandMatcher() const { 40 // Operand matchers specify a type constraint. 41 return isSubClassOf("TypeConstraint"); 42 } 43 44 bool DagLeaf::isAttrMatcher() const { 45 // Attribute matchers specify an attribute constraint. 46 return isSubClassOf("AttrConstraint"); 47 } 48 49 bool DagLeaf::isNativeCodeCall() const { 50 return isSubClassOf("NativeCodeCall"); 51 } 52 53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 54 55 bool DagLeaf::isEnumAttrCase() const { 56 return isSubClassOf("EnumAttrCaseInfo"); 57 } 58 59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); } 60 61 Constraint DagLeaf::getAsConstraint() const { 62 assert((isOperandMatcher() || isAttrMatcher()) && 63 "the DAG leaf must be operand or attribute"); 64 return Constraint(cast<llvm::DefInit>(def)->getDef()); 65 } 66 67 ConstantAttr DagLeaf::getAsConstantAttr() const { 68 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 69 return ConstantAttr(cast<llvm::DefInit>(def)); 70 } 71 72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 74 return EnumAttrCase(cast<llvm::DefInit>(def)); 75 } 76 77 std::string DagLeaf::getConditionTemplate() const { 78 return getAsConstraint().getConditionTemplate(); 79 } 80 81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 84 } 85 86 int DagLeaf::getNumReturnsOfNativeCode() const { 87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns"); 89 } 90 91 std::string DagLeaf::getStringAttr() const { 92 assert(isStringAttr() && "the DAG leaf must be string attribute"); 93 return def->getAsUnquotedString(); 94 } 95 bool DagLeaf::isSubClassOf(StringRef superclass) const { 96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 97 return defInit->getDef()->isSubClassOf(superclass); 98 return false; 99 } 100 101 void DagLeaf::print(raw_ostream &os) const { 102 if (def) 103 def->print(os); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // DagNode 108 //===----------------------------------------------------------------------===// 109 110 bool DagNode::isNativeCodeCall() const { 111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 112 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 113 return false; 114 } 115 116 bool DagNode::isOperation() const { 117 return !isNativeCodeCall() && !isReplaceWithValue() && 118 !isLocationDirective() && !isReturnTypeDirective() && !isEither() && 119 !isVariadic(); 120 } 121 122 llvm::StringRef DagNode::getNativeCodeTemplate() const { 123 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 124 return cast<llvm::DefInit>(node->getOperator()) 125 ->getDef() 126 ->getValueAsString("expression"); 127 } 128 129 int DagNode::getNumReturnsOfNativeCode() const { 130 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 131 return cast<llvm::DefInit>(node->getOperator()) 132 ->getDef() 133 ->getValueAsInt("numReturns"); 134 } 135 136 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 137 138 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 139 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 140 auto it = mapper->find(opDef); 141 if (it != mapper->end()) 142 return *it->second; 143 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 144 .first->second; 145 } 146 147 int DagNode::getNumOps() const { 148 // We want to get number of operations recursively involved in the DAG tree. 149 // All other directives should be excluded. 150 int count = isOperation() ? 1 : 0; 151 for (int i = 0, e = getNumArgs(); i != e; ++i) { 152 if (auto child = getArgAsNestedDag(i)) 153 count += child.getNumOps(); 154 } 155 return count; 156 } 157 158 int DagNode::getNumArgs() const { return node->getNumArgs(); } 159 160 bool DagNode::isNestedDagArg(unsigned index) const { 161 return isa<llvm::DagInit>(node->getArg(index)); 162 } 163 164 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 165 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 166 } 167 168 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 169 assert(!isNestedDagArg(index)); 170 return DagLeaf(node->getArg(index)); 171 } 172 173 StringRef DagNode::getArgName(unsigned index) const { 174 return node->getArgNameStr(index); 175 } 176 177 bool DagNode::isReplaceWithValue() const { 178 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 179 return dagOpDef->getName() == "replaceWithValue"; 180 } 181 182 bool DagNode::isLocationDirective() const { 183 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 184 return dagOpDef->getName() == "location"; 185 } 186 187 bool DagNode::isReturnTypeDirective() const { 188 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 189 return dagOpDef->getName() == "returnType"; 190 } 191 192 bool DagNode::isEither() const { 193 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 194 return dagOpDef->getName() == "either"; 195 } 196 197 bool DagNode::isVariadic() const { 198 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 199 return dagOpDef->getName() == "variadic"; 200 } 201 202 void DagNode::print(raw_ostream &os) const { 203 if (node) 204 node->print(os); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // SymbolInfoMap 209 //===----------------------------------------------------------------------===// 210 211 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 212 int idx = -1; 213 auto [name, indexStr] = symbol.rsplit("__"); 214 215 if (indexStr.consumeInteger(10, idx)) { 216 // The second part is not an index; we return the whole symbol as-is. 217 return symbol; 218 } 219 if (index) { 220 *index = idx; 221 } 222 return name; 223 } 224 225 SymbolInfoMap::SymbolInfo::SymbolInfo( 226 const Operator *op, SymbolInfo::Kind kind, 227 std::optional<DagAndConstant> dagAndConstant) 228 : op(op), kind(kind), dagAndConstant(dagAndConstant) {} 229 230 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 231 switch (kind) { 232 case Kind::Attr: 233 case Kind::Operand: 234 case Kind::Value: 235 return 1; 236 case Kind::Result: 237 return op->getNumResults(); 238 case Kind::MultipleValues: 239 return getSize(); 240 } 241 llvm_unreachable("unknown kind"); 242 } 243 244 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 245 return alternativeName ? *alternativeName : name.str(); 246 } 247 248 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { 249 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': "); 250 switch (kind) { 251 case Kind::Attr: { 252 if (op) 253 return op->getArg(getArgIndex()) 254 .get<NamedAttribute *>() 255 ->attr.getStorageType() 256 .str(); 257 // TODO(suderman): Use a more exact type when available. 258 return "::mlir::Attribute"; 259 } 260 case Kind::Operand: { 261 // Use operand range for captured operands (to support potential variadic 262 // operands). 263 return "::mlir::Operation::operand_range"; 264 } 265 case Kind::Value: { 266 return "::mlir::Value"; 267 } 268 case Kind::MultipleValues: { 269 return "::mlir::ValueRange"; 270 } 271 case Kind::Result: { 272 // Use the op itself for captured results. 273 return op->getQualCppClassName(); 274 } 275 } 276 llvm_unreachable("unknown kind"); 277 } 278 279 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 280 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 281 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; 282 return std::string( 283 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); 284 } 285 286 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { 287 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': "); 288 return std::string( 289 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); 290 } 291 292 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 293 StringRef name, int index, const char *fmt, const char *separator) const { 294 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 295 switch (kind) { 296 case Kind::Attr: { 297 assert(index < 0); 298 auto repl = formatv(fmt, name); 299 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 300 return std::string(repl); 301 } 302 case Kind::Operand: { 303 assert(index < 0); 304 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); 305 // If this operand is variadic and this SymbolInfo doesn't have a range 306 // index, then return the full variadic operand_range. Otherwise, return 307 // the value itself. 308 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { 309 auto repl = formatv(fmt, name); 310 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 311 return std::string(repl); 312 } 313 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 314 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 315 return std::string(repl); 316 } 317 case Kind::Result: { 318 // If `index` is greater than zero, then we are referencing a specific 319 // result of a multi-result op. The result can still be variadic. 320 if (index >= 0) { 321 std::string v = 322 std::string(formatv("{0}.getODSResults({1})", name, index)); 323 if (!op->getResult(index).isVariadic()) 324 v = std::string(formatv("(*{0}.begin())", v)); 325 auto repl = formatv(fmt, v); 326 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 327 return std::string(repl); 328 } 329 330 // If this op has no result at all but still we bind a symbol to it, it 331 // means we want to capture the op itself. 332 if (op->getNumResults() == 0) { 333 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 334 return formatv(fmt, name); 335 } 336 337 // We are referencing all results of the multi-result op. A specific result 338 // can either be a value or a range. Then join them with `separator`. 339 SmallVector<std::string, 4> values; 340 values.reserve(op->getNumResults()); 341 342 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 343 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 344 if (!op->getResult(i).isVariadic()) { 345 v = std::string(formatv("(*{0}.begin())", v)); 346 } 347 values.push_back(std::string(formatv(fmt, v))); 348 } 349 auto repl = llvm::join(values, separator); 350 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 351 return repl; 352 } 353 case Kind::Value: { 354 assert(index < 0); 355 assert(op == nullptr); 356 auto repl = formatv(fmt, name); 357 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 358 return std::string(repl); 359 } 360 case Kind::MultipleValues: { 361 assert(op == nullptr); 362 assert(index < getSize()); 363 if (index >= 0) { 364 std::string repl = 365 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 366 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 367 return repl; 368 } 369 // If it doesn't specify certain element, unpack them all. 370 auto repl = 371 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 372 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 373 return std::string(repl); 374 } 375 } 376 llvm_unreachable("unknown kind"); 377 } 378 379 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 380 StringRef name, int index, const char *fmt, const char *separator) const { 381 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 382 switch (kind) { 383 case Kind::Attr: 384 case Kind::Operand: { 385 assert(index < 0 && "only allowed for symbol bound to result"); 386 auto repl = formatv(fmt, name); 387 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 388 return std::string(repl); 389 } 390 case Kind::Result: { 391 if (index >= 0) { 392 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 393 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 394 return std::string(repl); 395 } 396 397 // We are referencing all results of the multi-result op. Each result should 398 // have a value range, and then join them with `separator`. 399 SmallVector<std::string, 4> values; 400 values.reserve(op->getNumResults()); 401 402 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 403 values.push_back(std::string( 404 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 405 } 406 auto repl = llvm::join(values, separator); 407 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 408 return repl; 409 } 410 case Kind::Value: { 411 assert(index < 0 && "only allowed for symbol bound to result"); 412 assert(op == nullptr); 413 auto repl = formatv(fmt, formatv("{{{0}}", name)); 414 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 415 return std::string(repl); 416 } 417 case Kind::MultipleValues: { 418 assert(op == nullptr); 419 assert(index < getSize()); 420 if (index >= 0) { 421 std::string repl = 422 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 423 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 424 return repl; 425 } 426 auto repl = 427 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 428 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 429 return std::string(repl); 430 } 431 } 432 llvm_unreachable("unknown kind"); 433 } 434 435 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 436 const Operator &op, int argIndex, 437 std::optional<int> variadicSubIndex) { 438 StringRef name = getValuePackName(symbol); 439 if (name != symbol) { 440 auto error = formatv( 441 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 442 PrintFatalError(loc, error); 443 } 444 445 auto symInfo = 446 op.getArg(argIndex).is<NamedAttribute *>() 447 ? SymbolInfo::getAttr(&op, argIndex) 448 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex); 449 450 std::string key = symbol.str(); 451 if (symbolInfoMap.count(key)) { 452 // Only non unique name for the operand is supported. 453 if (symInfo.kind != SymbolInfo::Kind::Operand) { 454 return false; 455 } 456 457 // Cannot add new operand if there is already non operand with the same 458 // name. 459 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 460 return false; 461 } 462 } 463 464 symbolInfoMap.emplace(key, symInfo); 465 return true; 466 } 467 468 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 469 std::string name = getValuePackName(symbol).str(); 470 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 471 472 return symbolInfoMap.count(inserted->first) == 1; 473 } 474 475 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 476 std::string name = getValuePackName(symbol).str(); 477 if (numValues > 1) 478 return bindMultipleValues(name, numValues); 479 return bindValue(name); 480 } 481 482 bool SymbolInfoMap::bindValue(StringRef symbol) { 483 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 484 return symbolInfoMap.count(inserted->first) == 1; 485 } 486 487 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 488 std::string name = getValuePackName(symbol).str(); 489 auto inserted = 490 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 491 return symbolInfoMap.count(inserted->first) == 1; 492 } 493 494 bool SymbolInfoMap::bindAttr(StringRef symbol) { 495 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 496 return symbolInfoMap.count(inserted->first) == 1; 497 } 498 499 bool SymbolInfoMap::contains(StringRef symbol) const { 500 return find(symbol) != symbolInfoMap.end(); 501 } 502 503 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 504 std::string name = getValuePackName(key).str(); 505 506 return symbolInfoMap.find(name); 507 } 508 509 SymbolInfoMap::const_iterator 510 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 511 int argIndex, 512 std::optional<int> variadicSubIndex) const { 513 return findBoundSymbol( 514 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex)); 515 } 516 517 SymbolInfoMap::const_iterator 518 SymbolInfoMap::findBoundSymbol(StringRef key, 519 const SymbolInfo &symbolInfo) const { 520 std::string name = getValuePackName(key).str(); 521 auto range = symbolInfoMap.equal_range(name); 522 523 for (auto it = range.first; it != range.second; ++it) 524 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 525 return it; 526 527 return symbolInfoMap.end(); 528 } 529 530 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 531 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 532 std::string name = getValuePackName(key).str(); 533 534 return symbolInfoMap.equal_range(name); 535 } 536 537 int SymbolInfoMap::count(StringRef key) const { 538 std::string name = getValuePackName(key).str(); 539 return symbolInfoMap.count(name); 540 } 541 542 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 543 StringRef name = getValuePackName(symbol); 544 if (name != symbol) { 545 // If there is a trailing index inside symbol, it references just one 546 // static value. 547 return 1; 548 } 549 // Otherwise, find how many it represents by querying the symbol's info. 550 return find(name)->second.getStaticValueCount(); 551 } 552 553 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 554 const char *fmt, 555 const char *separator) const { 556 int index = -1; 557 StringRef name = getValuePackName(symbol, &index); 558 559 auto it = symbolInfoMap.find(name.str()); 560 if (it == symbolInfoMap.end()) { 561 auto error = formatv("referencing unbound symbol '{0}'", symbol); 562 PrintFatalError(loc, error); 563 } 564 565 return it->second.getValueAndRangeUse(name, index, fmt, separator); 566 } 567 568 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 569 const char *separator) const { 570 int index = -1; 571 StringRef name = getValuePackName(symbol, &index); 572 573 auto it = symbolInfoMap.find(name.str()); 574 if (it == symbolInfoMap.end()) { 575 auto error = formatv("referencing unbound symbol '{0}'", symbol); 576 PrintFatalError(loc, error); 577 } 578 579 return it->second.getAllRangeUse(name, index, fmt, separator); 580 } 581 582 void SymbolInfoMap::assignUniqueAlternativeNames() { 583 llvm::StringSet<> usedNames; 584 585 for (auto symbolInfoIt = symbolInfoMap.begin(); 586 symbolInfoIt != symbolInfoMap.end();) { 587 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 588 auto startRange = range.first; 589 auto endRange = range.second; 590 591 auto operandName = symbolInfoIt->first; 592 int startSearchIndex = 0; 593 for (++startRange; startRange != endRange; ++startRange) { 594 // Current operand name is not unique, find a unique one 595 // and set the alternative name. 596 for (int i = startSearchIndex;; ++i) { 597 std::string alternativeName = operandName + std::to_string(i); 598 if (!usedNames.contains(alternativeName) && 599 symbolInfoMap.count(alternativeName) == 0) { 600 usedNames.insert(alternativeName); 601 startRange->second.alternativeName = alternativeName; 602 startSearchIndex = i + 1; 603 604 break; 605 } 606 } 607 } 608 609 symbolInfoIt = endRange; 610 } 611 } 612 613 //===----------------------------------------------------------------------===// 614 // Pattern 615 //==----------------------------------------------------------------------===// 616 617 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 618 : def(*def), recordOpMap(mapper) {} 619 620 DagNode Pattern::getSourcePattern() const { 621 return DagNode(def.getValueAsDag("sourcePattern")); 622 } 623 624 int Pattern::getNumResultPatterns() const { 625 auto *results = def.getValueAsListInit("resultPatterns"); 626 return results->size(); 627 } 628 629 DagNode Pattern::getResultPattern(unsigned index) const { 630 auto *results = def.getValueAsListInit("resultPatterns"); 631 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 632 } 633 634 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 635 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 636 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 637 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 638 639 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 640 infoMap.assignUniqueAlternativeNames(); 641 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 642 } 643 644 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 645 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 646 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 647 auto pattern = getResultPattern(i); 648 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 649 } 650 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 651 } 652 653 const Operator &Pattern::getSourceRootOp() { 654 return getSourcePattern().getDialectOp(recordOpMap); 655 } 656 657 Operator &Pattern::getDialectOp(DagNode node) { 658 return node.getDialectOp(recordOpMap); 659 } 660 661 std::vector<AppliedConstraint> Pattern::getConstraints() const { 662 auto *listInit = def.getValueAsListInit("constraints"); 663 std::vector<AppliedConstraint> ret; 664 ret.reserve(listInit->size()); 665 666 for (auto *it : *listInit) { 667 auto *dagInit = dyn_cast<llvm::DagInit>(it); 668 if (!dagInit) 669 PrintFatalError(&def, "all elements in Pattern multi-entity " 670 "constraints should be DAG nodes"); 671 672 std::vector<std::string> entities; 673 entities.reserve(dagInit->arg_size()); 674 for (auto *argName : dagInit->getArgNames()) { 675 if (!argName) { 676 PrintFatalError( 677 &def, 678 "operands to additional constraints can only be symbol references"); 679 } 680 entities.emplace_back(argName->getValue()); 681 } 682 683 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 684 dagInit->getNameStr(), std::move(entities)); 685 } 686 return ret; 687 } 688 689 int Pattern::getNumSupplementalPatterns() const { 690 auto *results = def.getValueAsListInit("supplementalPatterns"); 691 return results->size(); 692 } 693 694 DagNode Pattern::getSupplementalPattern(unsigned index) const { 695 auto *results = def.getValueAsListInit("supplementalPatterns"); 696 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 697 } 698 699 int Pattern::getBenefit() const { 700 // The initial benefit value is a heuristic with number of ops in the source 701 // pattern. 702 int initBenefit = getSourcePattern().getNumOps(); 703 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 704 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 705 PrintFatalError(&def, 706 "The 'addBenefit' takes and only takes one integer value"); 707 } 708 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 709 } 710 711 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 712 std::vector<std::pair<StringRef, unsigned>> result; 713 result.reserve(def.getLoc().size()); 714 for (auto loc : def.getLoc()) { 715 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 716 assert(buf && "invalid source location"); 717 result.emplace_back( 718 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 719 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 720 } 721 return result; 722 } 723 724 void Pattern::verifyBind(bool result, StringRef symbolName) { 725 if (!result) { 726 auto err = formatv("symbol '{0}' bound more than once", symbolName); 727 PrintFatalError(&def, err); 728 } 729 } 730 731 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 732 bool isSrcPattern) { 733 auto treeName = tree.getSymbol(); 734 auto numTreeArgs = tree.getNumArgs(); 735 736 if (tree.isNativeCodeCall()) { 737 if (!treeName.empty()) { 738 if (!isSrcPattern) { 739 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 740 << treeName << '\n'); 741 verifyBind( 742 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 743 treeName); 744 } else { 745 PrintFatalError(&def, 746 formatv("binding symbol '{0}' to NativecodeCall in " 747 "MatchPattern is not supported", 748 treeName)); 749 } 750 } 751 752 for (int i = 0; i != numTreeArgs; ++i) { 753 if (auto treeArg = tree.getArgAsNestedDag(i)) { 754 // This DAG node argument is a DAG node itself. Go inside recursively. 755 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 756 continue; 757 } 758 759 if (!isSrcPattern) 760 continue; 761 762 // We can only bind symbols to arguments in source pattern. Those 763 // symbols are referenced in result patterns. 764 auto treeArgName = tree.getArgName(i); 765 766 // `$_` is a special symbol meaning ignore the current argument. 767 if (!treeArgName.empty() && treeArgName != "_") { 768 DagLeaf leaf = tree.getArgAsLeaf(i); 769 770 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 771 if (leaf.isUnspecified()) { 772 // This is case of $c, a Value without any constraints. 773 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 774 } else { 775 auto constraint = leaf.getAsConstraint(); 776 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 777 leaf.isConstantAttr() || 778 constraint.getKind() == Constraint::Kind::CK_Attr; 779 780 if (isAttr) { 781 // This is case of $a, a binding to a certain attribute. 782 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 783 continue; 784 } 785 786 // This is case of $b, a binding to a certain type. 787 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 788 } 789 } 790 } 791 792 return; 793 } 794 795 if (tree.isOperation()) { 796 auto &op = getDialectOp(tree); 797 auto numOpArgs = op.getNumArgs(); 798 int numEither = 0; 799 800 // We need to exclude the trailing directives and `either` directive groups 801 // two operands of the operation. 802 int numDirectives = 0; 803 for (int i = numTreeArgs - 1; i >= 0; --i) { 804 if (auto dagArg = tree.getArgAsNestedDag(i)) { 805 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) 806 ++numDirectives; 807 else if (dagArg.isEither()) 808 ++numEither; 809 } 810 } 811 812 if (numOpArgs != numTreeArgs - numDirectives + numEither) { 813 auto err = 814 formatv("op '{0}' argument number mismatch: " 815 "{1} in pattern vs. {2} in definition", 816 op.getOperationName(), numTreeArgs + numEither, numOpArgs); 817 PrintFatalError(&def, err); 818 } 819 820 // The name attached to the DAG node's operator is for representing the 821 // results generated from this op. It should be remembered as bound results. 822 if (!treeName.empty()) { 823 LLVM_DEBUG(llvm::dbgs() 824 << "found symbol bound to op result: " << treeName << '\n'); 825 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 826 } 827 828 // The operand in `either` DAG should be bound to the operation in the 829 // parent DagNode. 830 auto collectSymbolInEither = [&](DagNode parent, DagNode tree, 831 int opArgIdx) { 832 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { 833 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 834 collectBoundSymbols(subTree, infoMap, isSrcPattern); 835 } else { 836 auto argName = tree.getArgName(i); 837 if (!argName.empty() && argName != "_") { 838 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx), 839 argName); 840 } 841 } 842 } 843 }; 844 845 // The operand in `variadic` DAG should be bound to the operation in the 846 // parent DagNode. The range index must be included as well to distinguish 847 // (potentially) repeating argName within the `variadic` DAG. 848 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, 849 int opArgIdx) { 850 auto treeName = tree.getSymbol(); 851 if (!treeName.empty()) { 852 // If treeName is specified, bind to the full variadic operand_range. 853 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx, 854 std::nullopt), 855 treeName); 856 } 857 858 for (int i = 0; i < tree.getNumArgs(); ++i) { 859 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 860 collectBoundSymbols(subTree, infoMap, isSrcPattern); 861 } else { 862 auto argName = tree.getArgName(i); 863 if (!argName.empty() && argName != "_") { 864 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx, 865 /*variadicSubIndex=*/i), 866 argName); 867 } 868 } 869 } 870 }; 871 872 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { 873 if (auto treeArg = tree.getArgAsNestedDag(i)) { 874 if (treeArg.isEither()) { 875 collectSymbolInEither(tree, treeArg, opArgIdx); 876 // `either` DAG is *flattened*. For example, 877 // 878 // (FooOp (either arg0, arg1), arg2) 879 // 880 // can be viewed as: 881 // 882 // (FooOp arg0, arg1, arg2) 883 ++opArgIdx; 884 } else if (treeArg.isVariadic()) { 885 collectSymbolInVariadic(tree, treeArg, opArgIdx); 886 } else { 887 // This DAG node argument is a DAG node itself. Go inside recursively. 888 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 889 } 890 continue; 891 } 892 893 if (isSrcPattern) { 894 // We can only bind symbols to op arguments in source pattern. Those 895 // symbols are referenced in result patterns. 896 auto treeArgName = tree.getArgName(i); 897 // `$_` is a special symbol meaning ignore the current argument. 898 if (!treeArgName.empty() && treeArgName != "_") { 899 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 900 << treeArgName << '\n'); 901 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx), 902 treeArgName); 903 } 904 } 905 } 906 return; 907 } 908 909 if (!treeName.empty()) { 910 PrintFatalError( 911 &def, formatv("binding symbol '{0}' to non-operation/native code call " 912 "unsupported right now", 913 treeName)); 914 } 915 } 916