xref: /llvm-project/mlir/lib/TableGen/Pattern.cpp (revision e5639b3fa45f68c53e8e676749b0f1aacd16b41d)
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
58 
59 Constraint DagLeaf::getAsConstraint() const {
60   assert((isOperandMatcher() || isAttrMatcher()) &&
61          "the DAG leaf must be operand or attribute");
62   return Constraint(cast<llvm::DefInit>(def)->getDef());
63 }
64 
65 ConstantAttr DagLeaf::getAsConstantAttr() const {
66   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
67   return ConstantAttr(cast<llvm::DefInit>(def));
68 }
69 
70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
71   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
72   return EnumAttrCase(cast<llvm::DefInit>(def));
73 }
74 
75 std::string DagLeaf::getConditionTemplate() const {
76   return getAsConstraint().getConditionTemplate();
77 }
78 
79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
80   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
81   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
82 }
83 
84 int DagLeaf::getNumReturnsOfNativeCode() const {
85   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86   return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
87 }
88 
89 std::string DagLeaf::getStringAttr() const {
90   assert(isStringAttr() && "the DAG leaf must be string attribute");
91   return def->getAsUnquotedString();
92 }
93 bool DagLeaf::isSubClassOf(StringRef superclass) const {
94   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
95     return defInit->getDef()->isSubClassOf(superclass);
96   return false;
97 }
98 
99 void DagLeaf::print(raw_ostream &os) const {
100   if (def)
101     def->print(os);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // DagNode
106 //===----------------------------------------------------------------------===//
107 
108 bool DagNode::isNativeCodeCall() const {
109   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
110     return defInit->getDef()->isSubClassOf("NativeCodeCall");
111   return false;
112 }
113 
114 bool DagNode::isOperation() const {
115   return !isNativeCodeCall() && !isReplaceWithValue() &&
116          !isLocationDirective() && !isReturnTypeDirective() && !isEither();
117 }
118 
119 llvm::StringRef DagNode::getNativeCodeTemplate() const {
120   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
121   return cast<llvm::DefInit>(node->getOperator())
122       ->getDef()
123       ->getValueAsString("expression");
124 }
125 
126 int DagNode::getNumReturnsOfNativeCode() const {
127   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
128   return cast<llvm::DefInit>(node->getOperator())
129       ->getDef()
130       ->getValueAsInt("numReturns");
131 }
132 
133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
134 
135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
136   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
137   auto it = mapper->find(opDef);
138   if (it != mapper->end())
139     return *it->second;
140   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
141               .first->second;
142 }
143 
144 int DagNode::getNumOps() const {
145   // We want to get number of operations recursively involved in the DAG tree.
146   // All other directives should be excluded.
147   int count = isOperation() ? 1 : 0;
148   for (int i = 0, e = getNumArgs(); i != e; ++i) {
149     if (auto child = getArgAsNestedDag(i))
150       count += child.getNumOps();
151   }
152   return count;
153 }
154 
155 int DagNode::getNumArgs() const { return node->getNumArgs(); }
156 
157 bool DagNode::isNestedDagArg(unsigned index) const {
158   return isa<llvm::DagInit>(node->getArg(index));
159 }
160 
161 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
162   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
163 }
164 
165 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
166   assert(!isNestedDagArg(index));
167   return DagLeaf(node->getArg(index));
168 }
169 
170 StringRef DagNode::getArgName(unsigned index) const {
171   return node->getArgNameStr(index);
172 }
173 
174 bool DagNode::isReplaceWithValue() const {
175   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
176   return dagOpDef->getName() == "replaceWithValue";
177 }
178 
179 bool DagNode::isLocationDirective() const {
180   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
181   return dagOpDef->getName() == "location";
182 }
183 
184 bool DagNode::isReturnTypeDirective() const {
185   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
186   return dagOpDef->getName() == "returnType";
187 }
188 
189 bool DagNode::isEither() const {
190   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
191   return dagOpDef->getName() == "either";
192 }
193 
194 void DagNode::print(raw_ostream &os) const {
195   if (node)
196     node->print(os);
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // SymbolInfoMap
201 //===----------------------------------------------------------------------===//
202 
203 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
204   StringRef name, indexStr;
205   int idx = -1;
206   std::tie(name, indexStr) = symbol.rsplit("__");
207 
208   if (indexStr.consumeInteger(10, idx)) {
209     // The second part is not an index; we return the whole symbol as-is.
210     return symbol;
211   }
212   if (index) {
213     *index = idx;
214   }
215   return name;
216 }
217 
218 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
219                                       Optional<DagAndConstant> dagAndConstant)
220     : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
221 
222 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
223   switch (kind) {
224   case Kind::Attr:
225   case Kind::Operand:
226   case Kind::Value:
227     return 1;
228   case Kind::Result:
229     return op->getNumResults();
230   case Kind::MultipleValues:
231     return getSize();
232   }
233   llvm_unreachable("unknown kind");
234 }
235 
236 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
237   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
238 }
239 
240 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
241   LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
242   switch (kind) {
243   case Kind::Attr: {
244     if (op)
245       return op->getArg(getArgIndex())
246           .get<NamedAttribute *>()
247           ->attr.getStorageType()
248           .str();
249     // TODO(suderman): Use a more exact type when available.
250     return "Attribute";
251   }
252   case Kind::Operand: {
253     // Use operand range for captured operands (to support potential variadic
254     // operands).
255     return "::mlir::Operation::operand_range";
256   }
257   case Kind::Value: {
258     return "::mlir::Value";
259   }
260   case Kind::MultipleValues: {
261     return "::mlir::ValueRange";
262   }
263   case Kind::Result: {
264     // Use the op itself for captured results.
265     return op->getQualCppClassName();
266   }
267   }
268   llvm_unreachable("unknown kind");
269 }
270 
271 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
272   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
273   std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
274   return std::string(
275       formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
276 }
277 
278 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
279   LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
280   return std::string(
281       formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
282 }
283 
284 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
285     StringRef name, int index, const char *fmt, const char *separator) const {
286   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
287   switch (kind) {
288   case Kind::Attr: {
289     assert(index < 0);
290     auto repl = formatv(fmt, name);
291     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
292     return std::string(repl);
293   }
294   case Kind::Operand: {
295     assert(index < 0);
296     auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
297     // If this operand is variadic, then return a range. Otherwise, return the
298     // value itself.
299     if (operand->isVariableLength()) {
300       auto repl = formatv(fmt, name);
301       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
302       return std::string(repl);
303     }
304     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
305     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
306     return std::string(repl);
307   }
308   case Kind::Result: {
309     // If `index` is greater than zero, then we are referencing a specific
310     // result of a multi-result op. The result can still be variadic.
311     if (index >= 0) {
312       std::string v =
313           std::string(formatv("{0}.getODSResults({1})", name, index));
314       if (!op->getResult(index).isVariadic())
315         v = std::string(formatv("(*{0}.begin())", v));
316       auto repl = formatv(fmt, v);
317       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
318       return std::string(repl);
319     }
320 
321     // If this op has no result at all but still we bind a symbol to it, it
322     // means we want to capture the op itself.
323     if (op->getNumResults() == 0) {
324       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
325       return std::string(name);
326     }
327 
328     // We are referencing all results of the multi-result op. A specific result
329     // can either be a value or a range. Then join them with `separator`.
330     SmallVector<std::string, 4> values;
331     values.reserve(op->getNumResults());
332 
333     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
334       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
335       if (!op->getResult(i).isVariadic()) {
336         v = std::string(formatv("(*{0}.begin())", v));
337       }
338       values.push_back(std::string(formatv(fmt, v)));
339     }
340     auto repl = llvm::join(values, separator);
341     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
342     return repl;
343   }
344   case Kind::Value: {
345     assert(index < 0);
346     assert(op == nullptr);
347     auto repl = formatv(fmt, name);
348     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
349     return std::string(repl);
350   }
351   case Kind::MultipleValues: {
352     assert(op == nullptr);
353     assert(index < getSize());
354     if (index >= 0) {
355       std::string repl =
356           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
357       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
358       return repl;
359     }
360     // If it doesn't specify certain element, unpack them all.
361     auto repl =
362         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
363     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
364     return std::string(repl);
365   }
366   }
367   llvm_unreachable("unknown kind");
368 }
369 
370 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
371     StringRef name, int index, const char *fmt, const char *separator) const {
372   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
373   switch (kind) {
374   case Kind::Attr:
375   case Kind::Operand: {
376     assert(index < 0 && "only allowed for symbol bound to result");
377     auto repl = formatv(fmt, name);
378     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
379     return std::string(repl);
380   }
381   case Kind::Result: {
382     if (index >= 0) {
383       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
384       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
385       return std::string(repl);
386     }
387 
388     // We are referencing all results of the multi-result op. Each result should
389     // have a value range, and then join them with `separator`.
390     SmallVector<std::string, 4> values;
391     values.reserve(op->getNumResults());
392 
393     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
394       values.push_back(std::string(
395           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
396     }
397     auto repl = llvm::join(values, separator);
398     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
399     return repl;
400   }
401   case Kind::Value: {
402     assert(index < 0 && "only allowed for symbol bound to result");
403     assert(op == nullptr);
404     auto repl = formatv(fmt, formatv("{{{0}}", name));
405     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
406     return std::string(repl);
407   }
408   case Kind::MultipleValues: {
409     assert(op == nullptr);
410     assert(index < getSize());
411     if (index >= 0) {
412       std::string repl =
413           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
414       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
415       return repl;
416     }
417     auto repl =
418         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
419     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
420     return std::string(repl);
421   }
422   }
423   llvm_unreachable("unknown kind");
424 }
425 
426 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
427                                    const Operator &op, int argIndex) {
428   StringRef name = getValuePackName(symbol);
429   if (name != symbol) {
430     auto error = formatv(
431         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
432     PrintFatalError(loc, error);
433   }
434 
435   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
436                      ? SymbolInfo::getAttr(&op, argIndex)
437                      : SymbolInfo::getOperand(node, &op, argIndex);
438 
439   std::string key = symbol.str();
440   if (symbolInfoMap.count(key)) {
441     // Only non unique name for the operand is supported.
442     if (symInfo.kind != SymbolInfo::Kind::Operand) {
443       return false;
444     }
445 
446     // Cannot add new operand if there is already non operand with the same
447     // name.
448     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
449       return false;
450     }
451   }
452 
453   symbolInfoMap.emplace(key, symInfo);
454   return true;
455 }
456 
457 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
458   std::string name = getValuePackName(symbol).str();
459   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
460 
461   return symbolInfoMap.count(inserted->first) == 1;
462 }
463 
464 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
465   std::string name = getValuePackName(symbol).str();
466   if (numValues > 1)
467     return bindMultipleValues(name, numValues);
468   return bindValue(name);
469 }
470 
471 bool SymbolInfoMap::bindValue(StringRef symbol) {
472   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
473   return symbolInfoMap.count(inserted->first) == 1;
474 }
475 
476 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
477   std::string name = getValuePackName(symbol).str();
478   auto inserted =
479       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
480   return symbolInfoMap.count(inserted->first) == 1;
481 }
482 
483 bool SymbolInfoMap::bindAttr(StringRef symbol) {
484   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
485   return symbolInfoMap.count(inserted->first) == 1;
486 }
487 
488 bool SymbolInfoMap::contains(StringRef symbol) const {
489   return find(symbol) != symbolInfoMap.end();
490 }
491 
492 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
493   std::string name = getValuePackName(key).str();
494 
495   return symbolInfoMap.find(name);
496 }
497 
498 SymbolInfoMap::const_iterator
499 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
500                                int argIndex) const {
501   return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
502 }
503 
504 SymbolInfoMap::const_iterator
505 SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
506   std::string name = getValuePackName(key).str();
507   auto range = symbolInfoMap.equal_range(name);
508 
509   for (auto it = range.first; it != range.second; ++it)
510     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
511       return it;
512 
513   return symbolInfoMap.end();
514 }
515 
516 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
517 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
518   std::string name = getValuePackName(key).str();
519 
520   return symbolInfoMap.equal_range(name);
521 }
522 
523 int SymbolInfoMap::count(StringRef key) const {
524   std::string name = getValuePackName(key).str();
525   return symbolInfoMap.count(name);
526 }
527 
528 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
529   StringRef name = getValuePackName(symbol);
530   if (name != symbol) {
531     // If there is a trailing index inside symbol, it references just one
532     // static value.
533     return 1;
534   }
535   // Otherwise, find how many it represents by querying the symbol's info.
536   return find(name)->second.getStaticValueCount();
537 }
538 
539 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
540                                                const char *fmt,
541                                                const char *separator) const {
542   int index = -1;
543   StringRef name = getValuePackName(symbol, &index);
544 
545   auto it = symbolInfoMap.find(name.str());
546   if (it == symbolInfoMap.end()) {
547     auto error = formatv("referencing unbound symbol '{0}'", symbol);
548     PrintFatalError(loc, error);
549   }
550 
551   return it->second.getValueAndRangeUse(name, index, fmt, separator);
552 }
553 
554 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
555                                           const char *separator) const {
556   int index = -1;
557   StringRef name = getValuePackName(symbol, &index);
558 
559   auto it = symbolInfoMap.find(name.str());
560   if (it == symbolInfoMap.end()) {
561     auto error = formatv("referencing unbound symbol '{0}'", symbol);
562     PrintFatalError(loc, error);
563   }
564 
565   return it->second.getAllRangeUse(name, index, fmt, separator);
566 }
567 
568 void SymbolInfoMap::assignUniqueAlternativeNames() {
569   llvm::StringSet<> usedNames;
570 
571   for (auto symbolInfoIt = symbolInfoMap.begin();
572        symbolInfoIt != symbolInfoMap.end();) {
573     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
574     auto startRange = range.first;
575     auto endRange = range.second;
576 
577     auto operandName = symbolInfoIt->first;
578     int startSearchIndex = 0;
579     for (++startRange; startRange != endRange; ++startRange) {
580       // Current operand name is not unique, find a unique one
581       // and set the alternative name.
582       for (int i = startSearchIndex;; ++i) {
583         std::string alternativeName = operandName + std::to_string(i);
584         if (!usedNames.contains(alternativeName) &&
585             symbolInfoMap.count(alternativeName) == 0) {
586           usedNames.insert(alternativeName);
587           startRange->second.alternativeName = alternativeName;
588           startSearchIndex = i + 1;
589 
590           break;
591         }
592       }
593     }
594 
595     symbolInfoIt = endRange;
596   }
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // Pattern
601 //==----------------------------------------------------------------------===//
602 
603 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
604     : def(*def), recordOpMap(mapper) {}
605 
606 DagNode Pattern::getSourcePattern() const {
607   return DagNode(def.getValueAsDag("sourcePattern"));
608 }
609 
610 int Pattern::getNumResultPatterns() const {
611   auto *results = def.getValueAsListInit("resultPatterns");
612   return results->size();
613 }
614 
615 DagNode Pattern::getResultPattern(unsigned index) const {
616   auto *results = def.getValueAsListInit("resultPatterns");
617   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
618 }
619 
620 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
621   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
622   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
623   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
624 
625   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
626   infoMap.assignUniqueAlternativeNames();
627   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
628 }
629 
630 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
631   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
632   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
633     auto pattern = getResultPattern(i);
634     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
635   }
636   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
637 }
638 
639 const Operator &Pattern::getSourceRootOp() {
640   return getSourcePattern().getDialectOp(recordOpMap);
641 }
642 
643 Operator &Pattern::getDialectOp(DagNode node) {
644   return node.getDialectOp(recordOpMap);
645 }
646 
647 std::vector<AppliedConstraint> Pattern::getConstraints() const {
648   auto *listInit = def.getValueAsListInit("constraints");
649   std::vector<AppliedConstraint> ret;
650   ret.reserve(listInit->size());
651 
652   for (auto *it : *listInit) {
653     auto *dagInit = dyn_cast<llvm::DagInit>(it);
654     if (!dagInit)
655       PrintFatalError(&def, "all elements in Pattern multi-entity "
656                             "constraints should be DAG nodes");
657 
658     std::vector<std::string> entities;
659     entities.reserve(dagInit->arg_size());
660     for (auto *argName : dagInit->getArgNames()) {
661       if (!argName) {
662         PrintFatalError(
663             &def,
664             "operands to additional constraints can only be symbol references");
665       }
666       entities.emplace_back(argName->getValue());
667     }
668 
669     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
670                      dagInit->getNameStr(), std::move(entities));
671   }
672   return ret;
673 }
674 
675 int Pattern::getBenefit() const {
676   // The initial benefit value is a heuristic with number of ops in the source
677   // pattern.
678   int initBenefit = getSourcePattern().getNumOps();
679   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
680   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
681     PrintFatalError(&def,
682                     "The 'addBenefit' takes and only takes one integer value");
683   }
684   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
685 }
686 
687 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
688   std::vector<std::pair<StringRef, unsigned>> result;
689   result.reserve(def.getLoc().size());
690   for (auto loc : def.getLoc()) {
691     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
692     assert(buf && "invalid source location");
693     result.emplace_back(
694         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
695         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
696   }
697   return result;
698 }
699 
700 void Pattern::verifyBind(bool result, StringRef symbolName) {
701   if (!result) {
702     auto err = formatv("symbol '{0}' bound more than once", symbolName);
703     PrintFatalError(&def, err);
704   }
705 }
706 
707 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
708                                   bool isSrcPattern) {
709   auto treeName = tree.getSymbol();
710   auto numTreeArgs = tree.getNumArgs();
711 
712   if (tree.isNativeCodeCall()) {
713     if (!treeName.empty()) {
714       if (!isSrcPattern) {
715         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
716                                 << treeName << '\n');
717         verifyBind(
718             infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
719             treeName);
720       } else {
721         PrintFatalError(&def,
722                         formatv("binding symbol '{0}' to NativecodeCall in "
723                                 "MatchPattern is not supported",
724                                 treeName));
725       }
726     }
727 
728     for (int i = 0; i != numTreeArgs; ++i) {
729       if (auto treeArg = tree.getArgAsNestedDag(i)) {
730         // This DAG node argument is a DAG node itself. Go inside recursively.
731         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
732         continue;
733       }
734 
735       if (!isSrcPattern)
736         continue;
737 
738       // We can only bind symbols to arguments in source pattern. Those
739       // symbols are referenced in result patterns.
740       auto treeArgName = tree.getArgName(i);
741 
742       // `$_` is a special symbol meaning ignore the current argument.
743       if (!treeArgName.empty() && treeArgName != "_") {
744         DagLeaf leaf = tree.getArgAsLeaf(i);
745 
746         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
747         if (leaf.isUnspecified()) {
748           // This is case of $c, a Value without any constraints.
749           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
750         } else {
751           auto constraint = leaf.getAsConstraint();
752           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
753                         leaf.isConstantAttr() ||
754                         constraint.getKind() == Constraint::Kind::CK_Attr;
755 
756           if (isAttr) {
757             // This is case of $a, a binding to a certain attribute.
758             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
759             continue;
760           }
761 
762           // This is case of $b, a binding to a certain type.
763           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
764         }
765       }
766     }
767 
768     return;
769   }
770 
771   if (tree.isOperation()) {
772     auto &op = getDialectOp(tree);
773     auto numOpArgs = op.getNumArgs();
774     int numEither = 0;
775 
776     // We need to exclude the trailing directives and `either` directive groups
777     // two operands of the operation.
778     int numDirectives = 0;
779     for (int i = numTreeArgs - 1; i >= 0; --i) {
780       if (auto dagArg = tree.getArgAsNestedDag(i)) {
781         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
782           ++numDirectives;
783         else if (dagArg.isEither())
784           ++numEither;
785       }
786     }
787 
788     if (numOpArgs != numTreeArgs - numDirectives + numEither) {
789       auto err =
790           formatv("op '{0}' argument number mismatch: "
791                   "{1} in pattern vs. {2} in definition",
792                   op.getOperationName(), numTreeArgs + numEither, numOpArgs);
793       PrintFatalError(&def, err);
794     }
795 
796     // The name attached to the DAG node's operator is for representing the
797     // results generated from this op. It should be remembered as bound results.
798     if (!treeName.empty()) {
799       LLVM_DEBUG(llvm::dbgs()
800                  << "found symbol bound to op result: " << treeName << '\n');
801       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
802     }
803 
804     // The operand in `either` DAG should be bound to the operation in the
805     // parent DagNode.
806     auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
807                                      int &opArgIdx) {
808       for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
809         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
810           collectBoundSymbols(subTree, infoMap, isSrcPattern);
811         } else {
812           auto argName = tree.getArgName(i);
813           if (!argName.empty() && argName != "_")
814             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
815                        argName);
816         }
817       }
818     };
819 
820     for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
821       if (auto treeArg = tree.getArgAsNestedDag(i)) {
822         if (treeArg.isEither()) {
823           collectSymbolInEither(tree, treeArg, opArgIdx);
824         } else {
825           // This DAG node argument is a DAG node itself. Go inside recursively.
826           collectBoundSymbols(treeArg, infoMap, isSrcPattern);
827         }
828         continue;
829       }
830 
831       if (isSrcPattern) {
832         // We can only bind symbols to op arguments in source pattern. Those
833         // symbols are referenced in result patterns.
834         auto treeArgName = tree.getArgName(i);
835         // `$_` is a special symbol meaning ignore the current argument.
836         if (!treeArgName.empty() && treeArgName != "_") {
837           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
838                                   << treeArgName << '\n');
839           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
840                      treeArgName);
841         }
842       }
843     }
844     return;
845   }
846 
847   if (!treeName.empty()) {
848     PrintFatalError(
849         &def, formatv("binding symbol '{0}' to non-operation/native code call "
850                       "unsupported right now",
851                       treeName));
852   }
853 }
854