1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "mlir/TableGen/Pattern.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/FormatVariadic.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 24 #define DEBUG_TYPE "mlir-tblgen-pattern" 25 26 using namespace mlir; 27 using namespace tblgen; 28 29 using llvm::formatv; 30 31 //===----------------------------------------------------------------------===// 32 // DagLeaf 33 //===----------------------------------------------------------------------===// 34 35 bool DagLeaf::isUnspecified() const { 36 return dyn_cast_or_null<llvm::UnsetInit>(def); 37 } 38 39 bool DagLeaf::isOperandMatcher() const { 40 // Operand matchers specify a type constraint. 41 return isSubClassOf("TypeConstraint"); 42 } 43 44 bool DagLeaf::isAttrMatcher() const { 45 // Attribute matchers specify an attribute constraint. 46 return isSubClassOf("AttrConstraint"); 47 } 48 49 bool DagLeaf::isNativeCodeCall() const { 50 return isSubClassOf("NativeCodeCall"); 51 } 52 53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 54 55 bool DagLeaf::isEnumAttrCase() const { 56 return isSubClassOf("EnumAttrCaseInfo"); 57 } 58 59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); } 60 61 Constraint DagLeaf::getAsConstraint() const { 62 assert((isOperandMatcher() || isAttrMatcher()) && 63 "the DAG leaf must be operand or attribute"); 64 return Constraint(cast<llvm::DefInit>(def)->getDef()); 65 } 66 67 ConstantAttr DagLeaf::getAsConstantAttr() const { 68 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 69 return ConstantAttr(cast<llvm::DefInit>(def)); 70 } 71 72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 74 return EnumAttrCase(cast<llvm::DefInit>(def)); 75 } 76 77 std::string DagLeaf::getConditionTemplate() const { 78 return getAsConstraint().getConditionTemplate(); 79 } 80 81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 84 } 85 86 int DagLeaf::getNumReturnsOfNativeCode() const { 87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns"); 89 } 90 91 std::string DagLeaf::getStringAttr() const { 92 assert(isStringAttr() && "the DAG leaf must be string attribute"); 93 return def->getAsUnquotedString(); 94 } 95 bool DagLeaf::isSubClassOf(StringRef superclass) const { 96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 97 return defInit->getDef()->isSubClassOf(superclass); 98 return false; 99 } 100 101 void DagLeaf::print(raw_ostream &os) const { 102 if (def) 103 def->print(os); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // DagNode 108 //===----------------------------------------------------------------------===// 109 110 bool DagNode::isNativeCodeCall() const { 111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 112 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 113 return false; 114 } 115 116 bool DagNode::isOperation() const { 117 return !isNativeCodeCall() && !isReplaceWithValue() && 118 !isLocationDirective() && !isReturnTypeDirective() && !isEither(); 119 } 120 121 llvm::StringRef DagNode::getNativeCodeTemplate() const { 122 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 123 return cast<llvm::DefInit>(node->getOperator()) 124 ->getDef() 125 ->getValueAsString("expression"); 126 } 127 128 int DagNode::getNumReturnsOfNativeCode() const { 129 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 130 return cast<llvm::DefInit>(node->getOperator()) 131 ->getDef() 132 ->getValueAsInt("numReturns"); 133 } 134 135 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 136 137 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 138 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 139 auto it = mapper->find(opDef); 140 if (it != mapper->end()) 141 return *it->second; 142 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 143 .first->second; 144 } 145 146 int DagNode::getNumOps() const { 147 // We want to get number of operations recursively involved in the DAG tree. 148 // All other directives should be excluded. 149 int count = isOperation() ? 1 : 0; 150 for (int i = 0, e = getNumArgs(); i != e; ++i) { 151 if (auto child = getArgAsNestedDag(i)) 152 count += child.getNumOps(); 153 } 154 return count; 155 } 156 157 int DagNode::getNumArgs() const { return node->getNumArgs(); } 158 159 bool DagNode::isNestedDagArg(unsigned index) const { 160 return isa<llvm::DagInit>(node->getArg(index)); 161 } 162 163 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 164 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 165 } 166 167 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 168 assert(!isNestedDagArg(index)); 169 return DagLeaf(node->getArg(index)); 170 } 171 172 StringRef DagNode::getArgName(unsigned index) const { 173 return node->getArgNameStr(index); 174 } 175 176 bool DagNode::isReplaceWithValue() const { 177 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 178 return dagOpDef->getName() == "replaceWithValue"; 179 } 180 181 bool DagNode::isLocationDirective() const { 182 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 183 return dagOpDef->getName() == "location"; 184 } 185 186 bool DagNode::isReturnTypeDirective() const { 187 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 188 return dagOpDef->getName() == "returnType"; 189 } 190 191 bool DagNode::isEither() const { 192 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 193 return dagOpDef->getName() == "either"; 194 } 195 196 void DagNode::print(raw_ostream &os) const { 197 if (node) 198 node->print(os); 199 } 200 201 //===----------------------------------------------------------------------===// 202 // SymbolInfoMap 203 //===----------------------------------------------------------------------===// 204 205 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 206 StringRef name, indexStr; 207 int idx = -1; 208 std::tie(name, indexStr) = symbol.rsplit("__"); 209 210 if (indexStr.consumeInteger(10, idx)) { 211 // The second part is not an index; we return the whole symbol as-is. 212 return symbol; 213 } 214 if (index) { 215 *index = idx; 216 } 217 return name; 218 } 219 220 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 221 Optional<DagAndConstant> dagAndConstant) 222 : op(op), kind(kind), dagAndConstant(std::move(dagAndConstant)) {} 223 224 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 225 switch (kind) { 226 case Kind::Attr: 227 case Kind::Operand: 228 case Kind::Value: 229 return 1; 230 case Kind::Result: 231 return op->getNumResults(); 232 case Kind::MultipleValues: 233 return getSize(); 234 } 235 llvm_unreachable("unknown kind"); 236 } 237 238 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 239 return alternativeName ? *alternativeName : name.str(); 240 } 241 242 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { 243 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': "); 244 switch (kind) { 245 case Kind::Attr: { 246 if (op) 247 return op->getArg(getArgIndex()) 248 .get<NamedAttribute *>() 249 ->attr.getStorageType() 250 .str(); 251 // TODO(suderman): Use a more exact type when available. 252 return "::mlir::Attribute"; 253 } 254 case Kind::Operand: { 255 // Use operand range for captured operands (to support potential variadic 256 // operands). 257 return "::mlir::Operation::operand_range"; 258 } 259 case Kind::Value: { 260 return "::mlir::Value"; 261 } 262 case Kind::MultipleValues: { 263 return "::mlir::ValueRange"; 264 } 265 case Kind::Result: { 266 // Use the op itself for captured results. 267 return op->getQualCppClassName(); 268 } 269 } 270 llvm_unreachable("unknown kind"); 271 } 272 273 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 274 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 275 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; 276 return std::string( 277 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); 278 } 279 280 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { 281 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': "); 282 return std::string( 283 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); 284 } 285 286 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 287 StringRef name, int index, const char *fmt, const char *separator) const { 288 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 289 switch (kind) { 290 case Kind::Attr: { 291 assert(index < 0); 292 auto repl = formatv(fmt, name); 293 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 294 return std::string(repl); 295 } 296 case Kind::Operand: { 297 assert(index < 0); 298 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); 299 // If this operand is variadic, then return a range. Otherwise, return the 300 // value itself. 301 if (operand->isVariableLength()) { 302 auto repl = formatv(fmt, name); 303 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 304 return std::string(repl); 305 } 306 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 307 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 308 return std::string(repl); 309 } 310 case Kind::Result: { 311 // If `index` is greater than zero, then we are referencing a specific 312 // result of a multi-result op. The result can still be variadic. 313 if (index >= 0) { 314 std::string v = 315 std::string(formatv("{0}.getODSResults({1})", name, index)); 316 if (!op->getResult(index).isVariadic()) 317 v = std::string(formatv("(*{0}.begin())", v)); 318 auto repl = formatv(fmt, v); 319 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 320 return std::string(repl); 321 } 322 323 // If this op has no result at all but still we bind a symbol to it, it 324 // means we want to capture the op itself. 325 if (op->getNumResults() == 0) { 326 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 327 return formatv(fmt, name); 328 } 329 330 // We are referencing all results of the multi-result op. A specific result 331 // can either be a value or a range. Then join them with `separator`. 332 SmallVector<std::string, 4> values; 333 values.reserve(op->getNumResults()); 334 335 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 336 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 337 if (!op->getResult(i).isVariadic()) { 338 v = std::string(formatv("(*{0}.begin())", v)); 339 } 340 values.push_back(std::string(formatv(fmt, v))); 341 } 342 auto repl = llvm::join(values, separator); 343 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 344 return repl; 345 } 346 case Kind::Value: { 347 assert(index < 0); 348 assert(op == nullptr); 349 auto repl = formatv(fmt, name); 350 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 351 return std::string(repl); 352 } 353 case Kind::MultipleValues: { 354 assert(op == nullptr); 355 assert(index < getSize()); 356 if (index >= 0) { 357 std::string repl = 358 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 359 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 360 return repl; 361 } 362 // If it doesn't specify certain element, unpack them all. 363 auto repl = 364 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 365 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 366 return std::string(repl); 367 } 368 } 369 llvm_unreachable("unknown kind"); 370 } 371 372 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 373 StringRef name, int index, const char *fmt, const char *separator) const { 374 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 375 switch (kind) { 376 case Kind::Attr: 377 case Kind::Operand: { 378 assert(index < 0 && "only allowed for symbol bound to result"); 379 auto repl = formatv(fmt, name); 380 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 381 return std::string(repl); 382 } 383 case Kind::Result: { 384 if (index >= 0) { 385 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 386 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 387 return std::string(repl); 388 } 389 390 // We are referencing all results of the multi-result op. Each result should 391 // have a value range, and then join them with `separator`. 392 SmallVector<std::string, 4> values; 393 values.reserve(op->getNumResults()); 394 395 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 396 values.push_back(std::string( 397 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 398 } 399 auto repl = llvm::join(values, separator); 400 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 401 return repl; 402 } 403 case Kind::Value: { 404 assert(index < 0 && "only allowed for symbol bound to result"); 405 assert(op == nullptr); 406 auto repl = formatv(fmt, formatv("{{{0}}", name)); 407 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 408 return std::string(repl); 409 } 410 case Kind::MultipleValues: { 411 assert(op == nullptr); 412 assert(index < getSize()); 413 if (index >= 0) { 414 std::string repl = 415 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 416 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 417 return repl; 418 } 419 auto repl = 420 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 421 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 422 return std::string(repl); 423 } 424 } 425 llvm_unreachable("unknown kind"); 426 } 427 428 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 429 const Operator &op, int argIndex) { 430 StringRef name = getValuePackName(symbol); 431 if (name != symbol) { 432 auto error = formatv( 433 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 434 PrintFatalError(loc, error); 435 } 436 437 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 438 ? SymbolInfo::getAttr(&op, argIndex) 439 : SymbolInfo::getOperand(node, &op, argIndex); 440 441 std::string key = symbol.str(); 442 if (symbolInfoMap.count(key)) { 443 // Only non unique name for the operand is supported. 444 if (symInfo.kind != SymbolInfo::Kind::Operand) { 445 return false; 446 } 447 448 // Cannot add new operand if there is already non operand with the same 449 // name. 450 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 451 return false; 452 } 453 } 454 455 symbolInfoMap.emplace(key, symInfo); 456 return true; 457 } 458 459 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 460 std::string name = getValuePackName(symbol).str(); 461 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 462 463 return symbolInfoMap.count(inserted->first) == 1; 464 } 465 466 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 467 std::string name = getValuePackName(symbol).str(); 468 if (numValues > 1) 469 return bindMultipleValues(name, numValues); 470 return bindValue(name); 471 } 472 473 bool SymbolInfoMap::bindValue(StringRef symbol) { 474 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 475 return symbolInfoMap.count(inserted->first) == 1; 476 } 477 478 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 479 std::string name = getValuePackName(symbol).str(); 480 auto inserted = 481 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 482 return symbolInfoMap.count(inserted->first) == 1; 483 } 484 485 bool SymbolInfoMap::bindAttr(StringRef symbol) { 486 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 487 return symbolInfoMap.count(inserted->first) == 1; 488 } 489 490 bool SymbolInfoMap::contains(StringRef symbol) const { 491 return find(symbol) != symbolInfoMap.end(); 492 } 493 494 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 495 std::string name = getValuePackName(key).str(); 496 497 return symbolInfoMap.find(name); 498 } 499 500 SymbolInfoMap::const_iterator 501 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 502 int argIndex) const { 503 return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex)); 504 } 505 506 SymbolInfoMap::const_iterator 507 SymbolInfoMap::findBoundSymbol(StringRef key, 508 const SymbolInfo &symbolInfo) const { 509 std::string name = getValuePackName(key).str(); 510 auto range = symbolInfoMap.equal_range(name); 511 512 for (auto it = range.first; it != range.second; ++it) 513 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 514 return it; 515 516 return symbolInfoMap.end(); 517 } 518 519 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 520 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 521 std::string name = getValuePackName(key).str(); 522 523 return symbolInfoMap.equal_range(name); 524 } 525 526 int SymbolInfoMap::count(StringRef key) const { 527 std::string name = getValuePackName(key).str(); 528 return symbolInfoMap.count(name); 529 } 530 531 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 532 StringRef name = getValuePackName(symbol); 533 if (name != symbol) { 534 // If there is a trailing index inside symbol, it references just one 535 // static value. 536 return 1; 537 } 538 // Otherwise, find how many it represents by querying the symbol's info. 539 return find(name)->second.getStaticValueCount(); 540 } 541 542 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 543 const char *fmt, 544 const char *separator) const { 545 int index = -1; 546 StringRef name = getValuePackName(symbol, &index); 547 548 auto it = symbolInfoMap.find(name.str()); 549 if (it == symbolInfoMap.end()) { 550 auto error = formatv("referencing unbound symbol '{0}'", symbol); 551 PrintFatalError(loc, error); 552 } 553 554 return it->second.getValueAndRangeUse(name, index, fmt, separator); 555 } 556 557 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 558 const char *separator) const { 559 int index = -1; 560 StringRef name = getValuePackName(symbol, &index); 561 562 auto it = symbolInfoMap.find(name.str()); 563 if (it == symbolInfoMap.end()) { 564 auto error = formatv("referencing unbound symbol '{0}'", symbol); 565 PrintFatalError(loc, error); 566 } 567 568 return it->second.getAllRangeUse(name, index, fmt, separator); 569 } 570 571 void SymbolInfoMap::assignUniqueAlternativeNames() { 572 llvm::StringSet<> usedNames; 573 574 for (auto symbolInfoIt = symbolInfoMap.begin(); 575 symbolInfoIt != symbolInfoMap.end();) { 576 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 577 auto startRange = range.first; 578 auto endRange = range.second; 579 580 auto operandName = symbolInfoIt->first; 581 int startSearchIndex = 0; 582 for (++startRange; startRange != endRange; ++startRange) { 583 // Current operand name is not unique, find a unique one 584 // and set the alternative name. 585 for (int i = startSearchIndex;; ++i) { 586 std::string alternativeName = operandName + std::to_string(i); 587 if (!usedNames.contains(alternativeName) && 588 symbolInfoMap.count(alternativeName) == 0) { 589 usedNames.insert(alternativeName); 590 startRange->second.alternativeName = alternativeName; 591 startSearchIndex = i + 1; 592 593 break; 594 } 595 } 596 } 597 598 symbolInfoIt = endRange; 599 } 600 } 601 602 //===----------------------------------------------------------------------===// 603 // Pattern 604 //==----------------------------------------------------------------------===// 605 606 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 607 : def(*def), recordOpMap(mapper) {} 608 609 DagNode Pattern::getSourcePattern() const { 610 return DagNode(def.getValueAsDag("sourcePattern")); 611 } 612 613 int Pattern::getNumResultPatterns() const { 614 auto *results = def.getValueAsListInit("resultPatterns"); 615 return results->size(); 616 } 617 618 DagNode Pattern::getResultPattern(unsigned index) const { 619 auto *results = def.getValueAsListInit("resultPatterns"); 620 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 621 } 622 623 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 624 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 625 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 626 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 627 628 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 629 infoMap.assignUniqueAlternativeNames(); 630 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 631 } 632 633 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 634 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 635 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 636 auto pattern = getResultPattern(i); 637 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 638 } 639 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 640 } 641 642 const Operator &Pattern::getSourceRootOp() { 643 return getSourcePattern().getDialectOp(recordOpMap); 644 } 645 646 Operator &Pattern::getDialectOp(DagNode node) { 647 return node.getDialectOp(recordOpMap); 648 } 649 650 std::vector<AppliedConstraint> Pattern::getConstraints() const { 651 auto *listInit = def.getValueAsListInit("constraints"); 652 std::vector<AppliedConstraint> ret; 653 ret.reserve(listInit->size()); 654 655 for (auto *it : *listInit) { 656 auto *dagInit = dyn_cast<llvm::DagInit>(it); 657 if (!dagInit) 658 PrintFatalError(&def, "all elements in Pattern multi-entity " 659 "constraints should be DAG nodes"); 660 661 std::vector<std::string> entities; 662 entities.reserve(dagInit->arg_size()); 663 for (auto *argName : dagInit->getArgNames()) { 664 if (!argName) { 665 PrintFatalError( 666 &def, 667 "operands to additional constraints can only be symbol references"); 668 } 669 entities.emplace_back(argName->getValue()); 670 } 671 672 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 673 dagInit->getNameStr(), std::move(entities)); 674 } 675 return ret; 676 } 677 678 int Pattern::getBenefit() const { 679 // The initial benefit value is a heuristic with number of ops in the source 680 // pattern. 681 int initBenefit = getSourcePattern().getNumOps(); 682 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 683 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 684 PrintFatalError(&def, 685 "The 'addBenefit' takes and only takes one integer value"); 686 } 687 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 688 } 689 690 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 691 std::vector<std::pair<StringRef, unsigned>> result; 692 result.reserve(def.getLoc().size()); 693 for (auto loc : def.getLoc()) { 694 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 695 assert(buf && "invalid source location"); 696 result.emplace_back( 697 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 698 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 699 } 700 return result; 701 } 702 703 void Pattern::verifyBind(bool result, StringRef symbolName) { 704 if (!result) { 705 auto err = formatv("symbol '{0}' bound more than once", symbolName); 706 PrintFatalError(&def, err); 707 } 708 } 709 710 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 711 bool isSrcPattern) { 712 auto treeName = tree.getSymbol(); 713 auto numTreeArgs = tree.getNumArgs(); 714 715 if (tree.isNativeCodeCall()) { 716 if (!treeName.empty()) { 717 if (!isSrcPattern) { 718 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 719 << treeName << '\n'); 720 verifyBind( 721 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 722 treeName); 723 } else { 724 PrintFatalError(&def, 725 formatv("binding symbol '{0}' to NativecodeCall in " 726 "MatchPattern is not supported", 727 treeName)); 728 } 729 } 730 731 for (int i = 0; i != numTreeArgs; ++i) { 732 if (auto treeArg = tree.getArgAsNestedDag(i)) { 733 // This DAG node argument is a DAG node itself. Go inside recursively. 734 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 735 continue; 736 } 737 738 if (!isSrcPattern) 739 continue; 740 741 // We can only bind symbols to arguments in source pattern. Those 742 // symbols are referenced in result patterns. 743 auto treeArgName = tree.getArgName(i); 744 745 // `$_` is a special symbol meaning ignore the current argument. 746 if (!treeArgName.empty() && treeArgName != "_") { 747 DagLeaf leaf = tree.getArgAsLeaf(i); 748 749 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 750 if (leaf.isUnspecified()) { 751 // This is case of $c, a Value without any constraints. 752 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 753 } else { 754 auto constraint = leaf.getAsConstraint(); 755 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 756 leaf.isConstantAttr() || 757 constraint.getKind() == Constraint::Kind::CK_Attr; 758 759 if (isAttr) { 760 // This is case of $a, a binding to a certain attribute. 761 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 762 continue; 763 } 764 765 // This is case of $b, a binding to a certain type. 766 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 767 } 768 } 769 } 770 771 return; 772 } 773 774 if (tree.isOperation()) { 775 auto &op = getDialectOp(tree); 776 auto numOpArgs = op.getNumArgs(); 777 int numEither = 0; 778 779 // We need to exclude the trailing directives and `either` directive groups 780 // two operands of the operation. 781 int numDirectives = 0; 782 for (int i = numTreeArgs - 1; i >= 0; --i) { 783 if (auto dagArg = tree.getArgAsNestedDag(i)) { 784 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) 785 ++numDirectives; 786 else if (dagArg.isEither()) 787 ++numEither; 788 } 789 } 790 791 if (numOpArgs != numTreeArgs - numDirectives + numEither) { 792 auto err = 793 formatv("op '{0}' argument number mismatch: " 794 "{1} in pattern vs. {2} in definition", 795 op.getOperationName(), numTreeArgs + numEither, numOpArgs); 796 PrintFatalError(&def, err); 797 } 798 799 // The name attached to the DAG node's operator is for representing the 800 // results generated from this op. It should be remembered as bound results. 801 if (!treeName.empty()) { 802 LLVM_DEBUG(llvm::dbgs() 803 << "found symbol bound to op result: " << treeName << '\n'); 804 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 805 } 806 807 // The operand in `either` DAG should be bound to the operation in the 808 // parent DagNode. 809 auto collectSymbolInEither = [&](DagNode parent, DagNode tree, 810 int &opArgIdx) { 811 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { 812 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 813 collectBoundSymbols(subTree, infoMap, isSrcPattern); 814 } else { 815 auto argName = tree.getArgName(i); 816 if (!argName.empty() && argName != "_") 817 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx), 818 argName); 819 } 820 } 821 }; 822 823 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { 824 if (auto treeArg = tree.getArgAsNestedDag(i)) { 825 if (treeArg.isEither()) { 826 collectSymbolInEither(tree, treeArg, opArgIdx); 827 } else { 828 // This DAG node argument is a DAG node itself. Go inside recursively. 829 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 830 } 831 continue; 832 } 833 834 if (isSrcPattern) { 835 // We can only bind symbols to op arguments in source pattern. Those 836 // symbols are referenced in result patterns. 837 auto treeArgName = tree.getArgName(i); 838 // `$_` is a special symbol meaning ignore the current argument. 839 if (!treeArgName.empty() && treeArgName != "_") { 840 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 841 << treeArgName << '\n'); 842 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx), 843 treeArgName); 844 } 845 } 846 } 847 return; 848 } 849 850 if (!treeName.empty()) { 851 PrintFatalError( 852 &def, formatv("binding symbol '{0}' to non-operation/native code call " 853 "unsupported right now", 854 treeName)); 855 } 856 } 857