1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 using namespace tblgen; 26 27 using llvm::formatv; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 52 53 bool DagLeaf::isEnumAttrCase() const { 54 return isSubClassOf("EnumAttrCaseInfo"); 55 } 56 57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); } 58 59 Constraint DagLeaf::getAsConstraint() const { 60 assert((isOperandMatcher() || isAttrMatcher()) && 61 "the DAG leaf must be operand or attribute"); 62 return Constraint(cast<llvm::DefInit>(def)->getDef()); 63 } 64 65 ConstantAttr DagLeaf::getAsConstantAttr() const { 66 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 67 return ConstantAttr(cast<llvm::DefInit>(def)); 68 } 69 70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 71 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 72 return EnumAttrCase(cast<llvm::DefInit>(def)); 73 } 74 75 std::string DagLeaf::getConditionTemplate() const { 76 return getAsConstraint().getConditionTemplate(); 77 } 78 79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 80 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 81 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 82 } 83 84 int DagLeaf::getNumReturnsOfNativeCode() const { 85 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 86 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns"); 87 } 88 89 std::string DagLeaf::getStringAttr() const { 90 assert(isStringAttr() && "the DAG leaf must be string attribute"); 91 return def->getAsUnquotedString(); 92 } 93 bool DagLeaf::isSubClassOf(StringRef superclass) const { 94 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 95 return defInit->getDef()->isSubClassOf(superclass); 96 return false; 97 } 98 99 void DagLeaf::print(raw_ostream &os) const { 100 if (def) 101 def->print(os); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // DagNode 106 //===----------------------------------------------------------------------===// 107 108 bool DagNode::isNativeCodeCall() const { 109 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 110 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 111 return false; 112 } 113 114 bool DagNode::isOperation() const { 115 return !isNativeCodeCall() && !isReplaceWithValue() && 116 !isLocationDirective() && !isReturnTypeDirective() && !isEither(); 117 } 118 119 llvm::StringRef DagNode::getNativeCodeTemplate() const { 120 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 121 return cast<llvm::DefInit>(node->getOperator()) 122 ->getDef() 123 ->getValueAsString("expression"); 124 } 125 126 int DagNode::getNumReturnsOfNativeCode() const { 127 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 128 return cast<llvm::DefInit>(node->getOperator()) 129 ->getDef() 130 ->getValueAsInt("numReturns"); 131 } 132 133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 134 135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 136 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 137 auto it = mapper->find(opDef); 138 if (it != mapper->end()) 139 return *it->second; 140 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 141 .first->second; 142 } 143 144 int DagNode::getNumOps() const { 145 // We want to get number of operations recursively involved in the DAG tree. 146 // All other directives should be excluded. 147 int count = isOperation() ? 1 : 0; 148 for (int i = 0, e = getNumArgs(); i != e; ++i) { 149 if (auto child = getArgAsNestedDag(i)) 150 count += child.getNumOps(); 151 } 152 return count; 153 } 154 155 int DagNode::getNumArgs() const { return node->getNumArgs(); } 156 157 bool DagNode::isNestedDagArg(unsigned index) const { 158 return isa<llvm::DagInit>(node->getArg(index)); 159 } 160 161 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 162 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 163 } 164 165 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 166 assert(!isNestedDagArg(index)); 167 return DagLeaf(node->getArg(index)); 168 } 169 170 StringRef DagNode::getArgName(unsigned index) const { 171 return node->getArgNameStr(index); 172 } 173 174 bool DagNode::isReplaceWithValue() const { 175 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 176 return dagOpDef->getName() == "replaceWithValue"; 177 } 178 179 bool DagNode::isLocationDirective() const { 180 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 181 return dagOpDef->getName() == "location"; 182 } 183 184 bool DagNode::isReturnTypeDirective() const { 185 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 186 return dagOpDef->getName() == "returnType"; 187 } 188 189 bool DagNode::isEither() const { 190 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 191 return dagOpDef->getName() == "either"; 192 } 193 194 void DagNode::print(raw_ostream &os) const { 195 if (node) 196 node->print(os); 197 } 198 199 //===----------------------------------------------------------------------===// 200 // SymbolInfoMap 201 //===----------------------------------------------------------------------===// 202 203 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 204 StringRef name, indexStr; 205 int idx = -1; 206 std::tie(name, indexStr) = symbol.rsplit("__"); 207 208 if (indexStr.consumeInteger(10, idx)) { 209 // The second part is not an index; we return the whole symbol as-is. 210 return symbol; 211 } 212 if (index) { 213 *index = idx; 214 } 215 return name; 216 } 217 218 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 219 Optional<DagAndConstant> dagAndConstant) 220 : op(op), kind(kind), dagAndConstant(dagAndConstant) {} 221 222 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 223 switch (kind) { 224 case Kind::Attr: 225 case Kind::Operand: 226 case Kind::Value: 227 return 1; 228 case Kind::Result: 229 return op->getNumResults(); 230 case Kind::MultipleValues: 231 return getSize(); 232 } 233 llvm_unreachable("unknown kind"); 234 } 235 236 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 237 return alternativeName.hasValue() ? alternativeName.getValue() : name.str(); 238 } 239 240 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { 241 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': "); 242 switch (kind) { 243 case Kind::Attr: { 244 if (op) 245 return op->getArg(getArgIndex()) 246 .get<NamedAttribute *>() 247 ->attr.getStorageType() 248 .str(); 249 // TODO(suderman): Use a more exact type when available. 250 return "Attribute"; 251 } 252 case Kind::Operand: { 253 // Use operand range for captured operands (to support potential variadic 254 // operands). 255 return "::mlir::Operation::operand_range"; 256 } 257 case Kind::Value: { 258 return "::mlir::Value"; 259 } 260 case Kind::MultipleValues: { 261 return "::mlir::ValueRange"; 262 } 263 case Kind::Result: { 264 // Use the op itself for captured results. 265 return op->getQualCppClassName(); 266 } 267 } 268 llvm_unreachable("unknown kind"); 269 } 270 271 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 272 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 273 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; 274 return std::string( 275 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); 276 } 277 278 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { 279 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': "); 280 return std::string( 281 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); 282 } 283 284 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 285 StringRef name, int index, const char *fmt, const char *separator) const { 286 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 287 switch (kind) { 288 case Kind::Attr: { 289 assert(index < 0); 290 auto repl = formatv(fmt, name); 291 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 292 return std::string(repl); 293 } 294 case Kind::Operand: { 295 assert(index < 0); 296 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); 297 // If this operand is variadic, then return a range. Otherwise, return the 298 // value itself. 299 if (operand->isVariableLength()) { 300 auto repl = formatv(fmt, name); 301 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 302 return std::string(repl); 303 } 304 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 305 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 306 return std::string(repl); 307 } 308 case Kind::Result: { 309 // If `index` is greater than zero, then we are referencing a specific 310 // result of a multi-result op. The result can still be variadic. 311 if (index >= 0) { 312 std::string v = 313 std::string(formatv("{0}.getODSResults({1})", name, index)); 314 if (!op->getResult(index).isVariadic()) 315 v = std::string(formatv("(*{0}.begin())", v)); 316 auto repl = formatv(fmt, v); 317 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 318 return std::string(repl); 319 } 320 321 // If this op has no result at all but still we bind a symbol to it, it 322 // means we want to capture the op itself. 323 if (op->getNumResults() == 0) { 324 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 325 return std::string(name); 326 } 327 328 // We are referencing all results of the multi-result op. A specific result 329 // can either be a value or a range. Then join them with `separator`. 330 SmallVector<std::string, 4> values; 331 values.reserve(op->getNumResults()); 332 333 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 334 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 335 if (!op->getResult(i).isVariadic()) { 336 v = std::string(formatv("(*{0}.begin())", v)); 337 } 338 values.push_back(std::string(formatv(fmt, v))); 339 } 340 auto repl = llvm::join(values, separator); 341 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 342 return repl; 343 } 344 case Kind::Value: { 345 assert(index < 0); 346 assert(op == nullptr); 347 auto repl = formatv(fmt, name); 348 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 349 return std::string(repl); 350 } 351 case Kind::MultipleValues: { 352 assert(op == nullptr); 353 assert(index < getSize()); 354 if (index >= 0) { 355 std::string repl = 356 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 357 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 358 return repl; 359 } 360 // If it doesn't specify certain element, unpack them all. 361 auto repl = 362 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 363 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 364 return std::string(repl); 365 } 366 } 367 llvm_unreachable("unknown kind"); 368 } 369 370 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 371 StringRef name, int index, const char *fmt, const char *separator) const { 372 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 373 switch (kind) { 374 case Kind::Attr: 375 case Kind::Operand: { 376 assert(index < 0 && "only allowed for symbol bound to result"); 377 auto repl = formatv(fmt, name); 378 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 379 return std::string(repl); 380 } 381 case Kind::Result: { 382 if (index >= 0) { 383 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 384 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 385 return std::string(repl); 386 } 387 388 // We are referencing all results of the multi-result op. Each result should 389 // have a value range, and then join them with `separator`. 390 SmallVector<std::string, 4> values; 391 values.reserve(op->getNumResults()); 392 393 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 394 values.push_back(std::string( 395 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 396 } 397 auto repl = llvm::join(values, separator); 398 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 399 return repl; 400 } 401 case Kind::Value: { 402 assert(index < 0 && "only allowed for symbol bound to result"); 403 assert(op == nullptr); 404 auto repl = formatv(fmt, formatv("{{{0}}", name)); 405 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 406 return std::string(repl); 407 } 408 case Kind::MultipleValues: { 409 assert(op == nullptr); 410 assert(index < getSize()); 411 if (index >= 0) { 412 std::string repl = 413 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 414 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 415 return repl; 416 } 417 auto repl = 418 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 419 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 420 return std::string(repl); 421 } 422 } 423 llvm_unreachable("unknown kind"); 424 } 425 426 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 427 const Operator &op, int argIndex) { 428 StringRef name = getValuePackName(symbol); 429 if (name != symbol) { 430 auto error = formatv( 431 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 432 PrintFatalError(loc, error); 433 } 434 435 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 436 ? SymbolInfo::getAttr(&op, argIndex) 437 : SymbolInfo::getOperand(node, &op, argIndex); 438 439 std::string key = symbol.str(); 440 if (symbolInfoMap.count(key)) { 441 // Only non unique name for the operand is supported. 442 if (symInfo.kind != SymbolInfo::Kind::Operand) { 443 return false; 444 } 445 446 // Cannot add new operand if there is already non operand with the same 447 // name. 448 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 449 return false; 450 } 451 } 452 453 symbolInfoMap.emplace(key, symInfo); 454 return true; 455 } 456 457 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 458 std::string name = getValuePackName(symbol).str(); 459 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 460 461 return symbolInfoMap.count(inserted->first) == 1; 462 } 463 464 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 465 std::string name = getValuePackName(symbol).str(); 466 if (numValues > 1) 467 return bindMultipleValues(name, numValues); 468 return bindValue(name); 469 } 470 471 bool SymbolInfoMap::bindValue(StringRef symbol) { 472 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 473 return symbolInfoMap.count(inserted->first) == 1; 474 } 475 476 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 477 std::string name = getValuePackName(symbol).str(); 478 auto inserted = 479 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 480 return symbolInfoMap.count(inserted->first) == 1; 481 } 482 483 bool SymbolInfoMap::bindAttr(StringRef symbol) { 484 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 485 return symbolInfoMap.count(inserted->first) == 1; 486 } 487 488 bool SymbolInfoMap::contains(StringRef symbol) const { 489 return find(symbol) != symbolInfoMap.end(); 490 } 491 492 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 493 std::string name = getValuePackName(key).str(); 494 495 return symbolInfoMap.find(name); 496 } 497 498 SymbolInfoMap::const_iterator 499 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 500 int argIndex) const { 501 return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex)); 502 } 503 504 SymbolInfoMap::const_iterator 505 SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const { 506 std::string name = getValuePackName(key).str(); 507 auto range = symbolInfoMap.equal_range(name); 508 509 for (auto it = range.first; it != range.second; ++it) 510 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 511 return it; 512 513 return symbolInfoMap.end(); 514 } 515 516 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 517 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 518 std::string name = getValuePackName(key).str(); 519 520 return symbolInfoMap.equal_range(name); 521 } 522 523 int SymbolInfoMap::count(StringRef key) const { 524 std::string name = getValuePackName(key).str(); 525 return symbolInfoMap.count(name); 526 } 527 528 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 529 StringRef name = getValuePackName(symbol); 530 if (name != symbol) { 531 // If there is a trailing index inside symbol, it references just one 532 // static value. 533 return 1; 534 } 535 // Otherwise, find how many it represents by querying the symbol's info. 536 return find(name)->second.getStaticValueCount(); 537 } 538 539 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 540 const char *fmt, 541 const char *separator) const { 542 int index = -1; 543 StringRef name = getValuePackName(symbol, &index); 544 545 auto it = symbolInfoMap.find(name.str()); 546 if (it == symbolInfoMap.end()) { 547 auto error = formatv("referencing unbound symbol '{0}'", symbol); 548 PrintFatalError(loc, error); 549 } 550 551 return it->second.getValueAndRangeUse(name, index, fmt, separator); 552 } 553 554 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 555 const char *separator) const { 556 int index = -1; 557 StringRef name = getValuePackName(symbol, &index); 558 559 auto it = symbolInfoMap.find(name.str()); 560 if (it == symbolInfoMap.end()) { 561 auto error = formatv("referencing unbound symbol '{0}'", symbol); 562 PrintFatalError(loc, error); 563 } 564 565 return it->second.getAllRangeUse(name, index, fmt, separator); 566 } 567 568 void SymbolInfoMap::assignUniqueAlternativeNames() { 569 llvm::StringSet<> usedNames; 570 571 for (auto symbolInfoIt = symbolInfoMap.begin(); 572 symbolInfoIt != symbolInfoMap.end();) { 573 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 574 auto startRange = range.first; 575 auto endRange = range.second; 576 577 auto operandName = symbolInfoIt->first; 578 int startSearchIndex = 0; 579 for (++startRange; startRange != endRange; ++startRange) { 580 // Current operand name is not unique, find a unique one 581 // and set the alternative name. 582 for (int i = startSearchIndex;; ++i) { 583 std::string alternativeName = operandName + std::to_string(i); 584 if (!usedNames.contains(alternativeName) && 585 symbolInfoMap.count(alternativeName) == 0) { 586 usedNames.insert(alternativeName); 587 startRange->second.alternativeName = alternativeName; 588 startSearchIndex = i + 1; 589 590 break; 591 } 592 } 593 } 594 595 symbolInfoIt = endRange; 596 } 597 } 598 599 //===----------------------------------------------------------------------===// 600 // Pattern 601 //==----------------------------------------------------------------------===// 602 603 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 604 : def(*def), recordOpMap(mapper) {} 605 606 DagNode Pattern::getSourcePattern() const { 607 return DagNode(def.getValueAsDag("sourcePattern")); 608 } 609 610 int Pattern::getNumResultPatterns() const { 611 auto *results = def.getValueAsListInit("resultPatterns"); 612 return results->size(); 613 } 614 615 DagNode Pattern::getResultPattern(unsigned index) const { 616 auto *results = def.getValueAsListInit("resultPatterns"); 617 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 618 } 619 620 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 621 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 622 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 623 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 624 625 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 626 infoMap.assignUniqueAlternativeNames(); 627 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 628 } 629 630 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 631 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 632 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 633 auto pattern = getResultPattern(i); 634 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 635 } 636 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 637 } 638 639 const Operator &Pattern::getSourceRootOp() { 640 return getSourcePattern().getDialectOp(recordOpMap); 641 } 642 643 Operator &Pattern::getDialectOp(DagNode node) { 644 return node.getDialectOp(recordOpMap); 645 } 646 647 std::vector<AppliedConstraint> Pattern::getConstraints() const { 648 auto *listInit = def.getValueAsListInit("constraints"); 649 std::vector<AppliedConstraint> ret; 650 ret.reserve(listInit->size()); 651 652 for (auto *it : *listInit) { 653 auto *dagInit = dyn_cast<llvm::DagInit>(it); 654 if (!dagInit) 655 PrintFatalError(&def, "all elements in Pattern multi-entity " 656 "constraints should be DAG nodes"); 657 658 std::vector<std::string> entities; 659 entities.reserve(dagInit->arg_size()); 660 for (auto *argName : dagInit->getArgNames()) { 661 if (!argName) { 662 PrintFatalError( 663 &def, 664 "operands to additional constraints can only be symbol references"); 665 } 666 entities.emplace_back(argName->getValue()); 667 } 668 669 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 670 dagInit->getNameStr(), std::move(entities)); 671 } 672 return ret; 673 } 674 675 int Pattern::getBenefit() const { 676 // The initial benefit value is a heuristic with number of ops in the source 677 // pattern. 678 int initBenefit = getSourcePattern().getNumOps(); 679 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 680 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 681 PrintFatalError(&def, 682 "The 'addBenefit' takes and only takes one integer value"); 683 } 684 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 685 } 686 687 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 688 std::vector<std::pair<StringRef, unsigned>> result; 689 result.reserve(def.getLoc().size()); 690 for (auto loc : def.getLoc()) { 691 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 692 assert(buf && "invalid source location"); 693 result.emplace_back( 694 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 695 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 696 } 697 return result; 698 } 699 700 void Pattern::verifyBind(bool result, StringRef symbolName) { 701 if (!result) { 702 auto err = formatv("symbol '{0}' bound more than once", symbolName); 703 PrintFatalError(&def, err); 704 } 705 } 706 707 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 708 bool isSrcPattern) { 709 auto treeName = tree.getSymbol(); 710 auto numTreeArgs = tree.getNumArgs(); 711 712 if (tree.isNativeCodeCall()) { 713 if (!treeName.empty()) { 714 if (!isSrcPattern) { 715 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 716 << treeName << '\n'); 717 verifyBind( 718 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 719 treeName); 720 } else { 721 PrintFatalError(&def, 722 formatv("binding symbol '{0}' to NativecodeCall in " 723 "MatchPattern is not supported", 724 treeName)); 725 } 726 } 727 728 for (int i = 0; i != numTreeArgs; ++i) { 729 if (auto treeArg = tree.getArgAsNestedDag(i)) { 730 // This DAG node argument is a DAG node itself. Go inside recursively. 731 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 732 continue; 733 } 734 735 if (!isSrcPattern) 736 continue; 737 738 // We can only bind symbols to arguments in source pattern. Those 739 // symbols are referenced in result patterns. 740 auto treeArgName = tree.getArgName(i); 741 742 // `$_` is a special symbol meaning ignore the current argument. 743 if (!treeArgName.empty() && treeArgName != "_") { 744 DagLeaf leaf = tree.getArgAsLeaf(i); 745 746 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 747 if (leaf.isUnspecified()) { 748 // This is case of $c, a Value without any constraints. 749 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 750 } else { 751 auto constraint = leaf.getAsConstraint(); 752 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 753 leaf.isConstantAttr() || 754 constraint.getKind() == Constraint::Kind::CK_Attr; 755 756 if (isAttr) { 757 // This is case of $a, a binding to a certain attribute. 758 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 759 continue; 760 } 761 762 // This is case of $b, a binding to a certain type. 763 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 764 } 765 } 766 } 767 768 return; 769 } 770 771 if (tree.isOperation()) { 772 auto &op = getDialectOp(tree); 773 auto numOpArgs = op.getNumArgs(); 774 int numEither = 0; 775 776 // We need to exclude the trailing directives and `either` directive groups 777 // two operands of the operation. 778 int numDirectives = 0; 779 for (int i = numTreeArgs - 1; i >= 0; --i) { 780 if (auto dagArg = tree.getArgAsNestedDag(i)) { 781 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) 782 ++numDirectives; 783 else if (dagArg.isEither()) 784 ++numEither; 785 } 786 } 787 788 if (numOpArgs != numTreeArgs - numDirectives + numEither) { 789 auto err = 790 formatv("op '{0}' argument number mismatch: " 791 "{1} in pattern vs. {2} in definition", 792 op.getOperationName(), numTreeArgs + numEither, numOpArgs); 793 PrintFatalError(&def, err); 794 } 795 796 // The name attached to the DAG node's operator is for representing the 797 // results generated from this op. It should be remembered as bound results. 798 if (!treeName.empty()) { 799 LLVM_DEBUG(llvm::dbgs() 800 << "found symbol bound to op result: " << treeName << '\n'); 801 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 802 } 803 804 // The operand in `either` DAG should be bound to the operation in the 805 // parent DagNode. 806 auto collectSymbolInEither = [&](DagNode parent, DagNode tree, 807 int &opArgIdx) { 808 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { 809 if (DagNode subTree = tree.getArgAsNestedDag(i)) { 810 collectBoundSymbols(subTree, infoMap, isSrcPattern); 811 } else { 812 auto argName = tree.getArgName(i); 813 if (!argName.empty() && argName != "_") 814 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx), 815 argName); 816 } 817 } 818 }; 819 820 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { 821 if (auto treeArg = tree.getArgAsNestedDag(i)) { 822 if (treeArg.isEither()) { 823 collectSymbolInEither(tree, treeArg, opArgIdx); 824 } else { 825 // This DAG node argument is a DAG node itself. Go inside recursively. 826 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 827 } 828 continue; 829 } 830 831 if (isSrcPattern) { 832 // We can only bind symbols to op arguments in source pattern. Those 833 // symbols are referenced in result patterns. 834 auto treeArgName = tree.getArgName(i); 835 // `$_` is a special symbol meaning ignore the current argument. 836 if (!treeArgName.empty() && treeArgName != "_") { 837 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 838 << treeArgName << '\n'); 839 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx), 840 treeArgName); 841 } 842 } 843 } 844 return; 845 } 846 847 if (!treeName.empty()) { 848 PrintFatalError( 849 &def, formatv("binding symbol '{0}' to non-operation/native code call " 850 "unsupported right now", 851 treeName)); 852 } 853 } 854