xref: /llvm-project/mlir/lib/TableGen/Pattern.cpp (revision 26d513d197e14b824dd9d353aff38af1925c3770)
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41   return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
44 bool DagLeaf::isOperandMatcher() const {
45   // Operand matchers specify a type constraint.
46   return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50   // Attribute matchers specify an attribute constraint.
51   return isSubClassOf("AttrConstraint");
52 }
53 
54 bool DagLeaf::isNativeCodeCall() const {
55   return isSubClassOf("NativeCodeCall");
56 }
57 
58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
59 
60 bool DagLeaf::isEnumAttrCase() const {
61   return isSubClassOf("EnumAttrCaseInfo");
62 }
63 
64 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
65 
66 Constraint DagLeaf::getAsConstraint() const {
67   assert((isOperandMatcher() || isAttrMatcher()) &&
68          "the DAG leaf must be operand or attribute");
69   return Constraint(cast<DefInit>(def)->getDef());
70 }
71 
72 ConstantAttr DagLeaf::getAsConstantAttr() const {
73   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
74   return ConstantAttr(cast<DefInit>(def));
75 }
76 
77 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
78   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
79   return EnumAttrCase(cast<DefInit>(def));
80 }
81 
82 std::string DagLeaf::getConditionTemplate() const {
83   return getAsConstraint().getConditionTemplate();
84 }
85 
86 StringRef DagLeaf::getNativeCodeTemplate() const {
87   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88   return cast<DefInit>(def)->getDef()->getValueAsString("expression");
89 }
90 
91 int DagLeaf::getNumReturnsOfNativeCode() const {
92   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
93   return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
94 }
95 
96 std::string DagLeaf::getStringAttr() const {
97   assert(isStringAttr() && "the DAG leaf must be string attribute");
98   return def->getAsUnquotedString();
99 }
100 bool DagLeaf::isSubClassOf(StringRef superclass) const {
101   if (auto *defInit = dyn_cast_or_null<DefInit>(def))
102     return defInit->getDef()->isSubClassOf(superclass);
103   return false;
104 }
105 
106 void DagLeaf::print(raw_ostream &os) const {
107   if (def)
108     def->print(os);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // DagNode
113 //===----------------------------------------------------------------------===//
114 
115 bool DagNode::isNativeCodeCall() const {
116   if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
117     return defInit->getDef()->isSubClassOf("NativeCodeCall");
118   return false;
119 }
120 
121 bool DagNode::isOperation() const {
122   return !isNativeCodeCall() && !isReplaceWithValue() &&
123          !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
124          !isVariadic();
125 }
126 
127 StringRef DagNode::getNativeCodeTemplate() const {
128   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
129   return cast<DefInit>(node->getOperator())
130       ->getDef()
131       ->getValueAsString("expression");
132 }
133 
134 int DagNode::getNumReturnsOfNativeCode() const {
135   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
136   return cast<DefInit>(node->getOperator())
137       ->getDef()
138       ->getValueAsInt("numReturns");
139 }
140 
141 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
142 
143 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
144   const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
145   auto [it, inserted] = mapper->try_emplace(opDef);
146   if (inserted)
147     it->second = std::make_unique<Operator>(opDef);
148   return *it->second;
149 }
150 
151 int DagNode::getNumOps() const {
152   // We want to get number of operations recursively involved in the DAG tree.
153   // All other directives should be excluded.
154   int count = isOperation() ? 1 : 0;
155   for (int i = 0, e = getNumArgs(); i != e; ++i) {
156     if (auto child = getArgAsNestedDag(i))
157       count += child.getNumOps();
158   }
159   return count;
160 }
161 
162 int DagNode::getNumArgs() const { return node->getNumArgs(); }
163 
164 bool DagNode::isNestedDagArg(unsigned index) const {
165   return isa<DagInit>(node->getArg(index));
166 }
167 
168 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
169   return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
170 }
171 
172 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
173   assert(!isNestedDagArg(index));
174   return DagLeaf(node->getArg(index));
175 }
176 
177 StringRef DagNode::getArgName(unsigned index) const {
178   return node->getArgNameStr(index);
179 }
180 
181 bool DagNode::isReplaceWithValue() const {
182   auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
183   return dagOpDef->getName() == "replaceWithValue";
184 }
185 
186 bool DagNode::isLocationDirective() const {
187   auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
188   return dagOpDef->getName() == "location";
189 }
190 
191 bool DagNode::isReturnTypeDirective() const {
192   auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
193   return dagOpDef->getName() == "returnType";
194 }
195 
196 bool DagNode::isEither() const {
197   auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
198   return dagOpDef->getName() == "either";
199 }
200 
201 bool DagNode::isVariadic() const {
202   auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
203   return dagOpDef->getName() == "variadic";
204 }
205 
206 void DagNode::print(raw_ostream &os) const {
207   if (node)
208     node->print(os);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // SymbolInfoMap
213 //===----------------------------------------------------------------------===//
214 
215 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
216   int idx = -1;
217   auto [name, indexStr] = symbol.rsplit("__");
218 
219   if (indexStr.consumeInteger(10, idx)) {
220     // The second part is not an index; we return the whole symbol as-is.
221     return symbol;
222   }
223   if (index) {
224     *index = idx;
225   }
226   return name;
227 }
228 
229 SymbolInfoMap::SymbolInfo::SymbolInfo(
230     const Operator *op, SymbolInfo::Kind kind,
231     std::optional<DagAndConstant> dagAndConstant)
232     : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
233 
234 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
235   switch (kind) {
236   case Kind::Attr:
237   case Kind::Operand:
238   case Kind::Value:
239     return 1;
240   case Kind::Result:
241     return op->getNumResults();
242   case Kind::MultipleValues:
243     return getSize();
244   }
245   llvm_unreachable("unknown kind");
246 }
247 
248 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
249   return alternativeName ? *alternativeName : name.str();
250 }
251 
252 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
253   LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
254   switch (kind) {
255   case Kind::Attr: {
256     if (op)
257       return cast<NamedAttribute *>(op->getArg(getArgIndex()))
258           ->attr.getStorageType()
259           .str();
260     // TODO(suderman): Use a more exact type when available.
261     return "::mlir::Attribute";
262   }
263   case Kind::Operand: {
264     // Use operand range for captured operands (to support potential variadic
265     // operands).
266     return "::mlir::Operation::operand_range";
267   }
268   case Kind::Value: {
269     return "::mlir::Value";
270   }
271   case Kind::MultipleValues: {
272     return "::mlir::ValueRange";
273   }
274   case Kind::Result: {
275     // Use the op itself for captured results.
276     return op->getQualCppClassName();
277   }
278   }
279   llvm_unreachable("unknown kind");
280 }
281 
282 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
283   LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
284   std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
285   return std::string(
286       formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
287 }
288 
289 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
290   LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
291   return std::string(
292       formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
293 }
294 
295 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
296     StringRef name, int index, const char *fmt, const char *separator) const {
297   LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
298   switch (kind) {
299   case Kind::Attr: {
300     assert(index < 0);
301     auto repl = formatv(fmt, name);
302     LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
303     return std::string(repl);
304   }
305   case Kind::Operand: {
306     assert(index < 0);
307     auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
308     // If this operand is variadic and this SymbolInfo doesn't have a range
309     // index, then return the full variadic operand_range. Otherwise, return
310     // the value itself.
311     if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
312       auto repl = formatv(fmt, name);
313       LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
314       return std::string(repl);
315     }
316     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
317     LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
318     return std::string(repl);
319   }
320   case Kind::Result: {
321     // If `index` is greater than zero, then we are referencing a specific
322     // result of a multi-result op. The result can still be variadic.
323     if (index >= 0) {
324       std::string v =
325           std::string(formatv("{0}.getODSResults({1})", name, index));
326       if (!op->getResult(index).isVariadic())
327         v = std::string(formatv("(*{0}.begin())", v));
328       auto repl = formatv(fmt, v);
329       LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
330       return std::string(repl);
331     }
332 
333     // If this op has no result at all but still we bind a symbol to it, it
334     // means we want to capture the op itself.
335     if (op->getNumResults() == 0) {
336       LLVM_DEBUG(dbgs() << name << " (Op)\n");
337       return formatv(fmt, name);
338     }
339 
340     // We are referencing all results of the multi-result op. A specific result
341     // can either be a value or a range. Then join them with `separator`.
342     SmallVector<std::string, 4> values;
343     values.reserve(op->getNumResults());
344 
345     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
346       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
347       if (!op->getResult(i).isVariadic()) {
348         v = std::string(formatv("(*{0}.begin())", v));
349       }
350       values.push_back(std::string(formatv(fmt, v)));
351     }
352     auto repl = llvm::join(values, separator);
353     LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
354     return repl;
355   }
356   case Kind::Value: {
357     assert(index < 0);
358     assert(op == nullptr);
359     auto repl = formatv(fmt, name);
360     LLVM_DEBUG(dbgs() << repl << " (Value)\n");
361     return std::string(repl);
362   }
363   case Kind::MultipleValues: {
364     assert(op == nullptr);
365     assert(index < getSize());
366     if (index >= 0) {
367       std::string repl =
368           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
369       LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
370       return repl;
371     }
372     // If it doesn't specify certain element, unpack them all.
373     auto repl =
374         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
375     LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
376     return std::string(repl);
377   }
378   }
379   llvm_unreachable("unknown kind");
380 }
381 
382 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
383     StringRef name, int index, const char *fmt, const char *separator) const {
384   LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
385   switch (kind) {
386   case Kind::Attr:
387   case Kind::Operand: {
388     assert(index < 0 && "only allowed for symbol bound to result");
389     auto repl = formatv(fmt, name);
390     LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
391     return std::string(repl);
392   }
393   case Kind::Result: {
394     if (index >= 0) {
395       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
396       LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
397       return std::string(repl);
398     }
399 
400     // We are referencing all results of the multi-result op. Each result should
401     // have a value range, and then join them with `separator`.
402     SmallVector<std::string, 4> values;
403     values.reserve(op->getNumResults());
404 
405     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
406       values.push_back(std::string(
407           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
408     }
409     auto repl = llvm::join(values, separator);
410     LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
411     return repl;
412   }
413   case Kind::Value: {
414     assert(index < 0 && "only allowed for symbol bound to result");
415     assert(op == nullptr);
416     auto repl = formatv(fmt, formatv("{{{0}}", name));
417     LLVM_DEBUG(dbgs() << repl << " (Value)\n");
418     return std::string(repl);
419   }
420   case Kind::MultipleValues: {
421     assert(op == nullptr);
422     assert(index < getSize());
423     if (index >= 0) {
424       std::string repl =
425           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
426       LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
427       return repl;
428     }
429     auto repl =
430         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
431     LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
432     return std::string(repl);
433   }
434   }
435   llvm_unreachable("unknown kind");
436 }
437 
438 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
439                                    const Operator &op, int argIndex,
440                                    std::optional<int> variadicSubIndex) {
441   StringRef name = getValuePackName(symbol);
442   if (name != symbol) {
443     auto error = formatv(
444         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
445     PrintFatalError(loc, error);
446   }
447 
448   auto symInfo =
449       isa<NamedAttribute *>(op.getArg(argIndex))
450           ? SymbolInfo::getAttr(&op, argIndex)
451           : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
452 
453   std::string key = symbol.str();
454   if (symbolInfoMap.count(key)) {
455     // Only non unique name for the operand is supported.
456     if (symInfo.kind != SymbolInfo::Kind::Operand) {
457       return false;
458     }
459 
460     // Cannot add new operand if there is already non operand with the same
461     // name.
462     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
463       return false;
464     }
465   }
466 
467   symbolInfoMap.emplace(key, symInfo);
468   return true;
469 }
470 
471 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
472   std::string name = getValuePackName(symbol).str();
473   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
474 
475   return symbolInfoMap.count(inserted->first) == 1;
476 }
477 
478 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
479   std::string name = getValuePackName(symbol).str();
480   if (numValues > 1)
481     return bindMultipleValues(name, numValues);
482   return bindValue(name);
483 }
484 
485 bool SymbolInfoMap::bindValue(StringRef symbol) {
486   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
487   return symbolInfoMap.count(inserted->first) == 1;
488 }
489 
490 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
491   std::string name = getValuePackName(symbol).str();
492   auto inserted =
493       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
494   return symbolInfoMap.count(inserted->first) == 1;
495 }
496 
497 bool SymbolInfoMap::bindAttr(StringRef symbol) {
498   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
499   return symbolInfoMap.count(inserted->first) == 1;
500 }
501 
502 bool SymbolInfoMap::contains(StringRef symbol) const {
503   return find(symbol) != symbolInfoMap.end();
504 }
505 
506 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
507   std::string name = getValuePackName(key).str();
508 
509   return symbolInfoMap.find(name);
510 }
511 
512 SymbolInfoMap::const_iterator
513 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
514                                int argIndex,
515                                std::optional<int> variadicSubIndex) const {
516   return findBoundSymbol(
517       key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
518 }
519 
520 SymbolInfoMap::const_iterator
521 SymbolInfoMap::findBoundSymbol(StringRef key,
522                                const SymbolInfo &symbolInfo) const {
523   std::string name = getValuePackName(key).str();
524   auto range = symbolInfoMap.equal_range(name);
525 
526   for (auto it = range.first; it != range.second; ++it)
527     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
528       return it;
529 
530   return symbolInfoMap.end();
531 }
532 
533 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
534 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
535   std::string name = getValuePackName(key).str();
536 
537   return symbolInfoMap.equal_range(name);
538 }
539 
540 int SymbolInfoMap::count(StringRef key) const {
541   std::string name = getValuePackName(key).str();
542   return symbolInfoMap.count(name);
543 }
544 
545 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
546   StringRef name = getValuePackName(symbol);
547   if (name != symbol) {
548     // If there is a trailing index inside symbol, it references just one
549     // static value.
550     return 1;
551   }
552   // Otherwise, find how many it represents by querying the symbol's info.
553   return find(name)->second.getStaticValueCount();
554 }
555 
556 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
557                                                const char *fmt,
558                                                const char *separator) const {
559   int index = -1;
560   StringRef name = getValuePackName(symbol, &index);
561 
562   auto it = symbolInfoMap.find(name.str());
563   if (it == symbolInfoMap.end()) {
564     auto error = formatv("referencing unbound symbol '{0}'", symbol);
565     PrintFatalError(loc, error);
566   }
567 
568   return it->second.getValueAndRangeUse(name, index, fmt, separator);
569 }
570 
571 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
572                                           const char *separator) const {
573   int index = -1;
574   StringRef name = getValuePackName(symbol, &index);
575 
576   auto it = symbolInfoMap.find(name.str());
577   if (it == symbolInfoMap.end()) {
578     auto error = formatv("referencing unbound symbol '{0}'", symbol);
579     PrintFatalError(loc, error);
580   }
581 
582   return it->second.getAllRangeUse(name, index, fmt, separator);
583 }
584 
585 void SymbolInfoMap::assignUniqueAlternativeNames() {
586   llvm::StringSet<> usedNames;
587 
588   for (auto symbolInfoIt = symbolInfoMap.begin();
589        symbolInfoIt != symbolInfoMap.end();) {
590     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
591     auto startRange = range.first;
592     auto endRange = range.second;
593 
594     auto operandName = symbolInfoIt->first;
595     int startSearchIndex = 0;
596     for (++startRange; startRange != endRange; ++startRange) {
597       // Current operand name is not unique, find a unique one
598       // and set the alternative name.
599       for (int i = startSearchIndex;; ++i) {
600         std::string alternativeName = operandName + std::to_string(i);
601         if (!usedNames.contains(alternativeName) &&
602             symbolInfoMap.count(alternativeName) == 0) {
603           usedNames.insert(alternativeName);
604           startRange->second.alternativeName = alternativeName;
605           startSearchIndex = i + 1;
606 
607           break;
608         }
609       }
610     }
611 
612     symbolInfoIt = endRange;
613   }
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // Pattern
//===----------------------------------------------------------------------===//
619 
620 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
621     : def(*def), recordOpMap(mapper) {}
622 
623 DagNode Pattern::getSourcePattern() const {
624   return DagNode(def.getValueAsDag("sourcePattern"));
625 }
626 
627 int Pattern::getNumResultPatterns() const {
628   auto *results = def.getValueAsListInit("resultPatterns");
629   return results->size();
630 }
631 
632 DagNode Pattern::getResultPattern(unsigned index) const {
633   auto *results = def.getValueAsListInit("resultPatterns");
634   return DagNode(cast<DagInit>(results->getElement(index)));
635 }
636 
637 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
638   LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
639   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
640   LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
641 
642   LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
643   infoMap.assignUniqueAlternativeNames();
644   LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
645 }
646 
647 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
648   LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
649   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
650     auto pattern = getResultPattern(i);
651     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
652   }
653   LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
654 }
655 
656 const Operator &Pattern::getSourceRootOp() {
657   return getSourcePattern().getDialectOp(recordOpMap);
658 }
659 
660 Operator &Pattern::getDialectOp(DagNode node) {
661   return node.getDialectOp(recordOpMap);
662 }
663 
664 std::vector<AppliedConstraint> Pattern::getConstraints() const {
665   auto *listInit = def.getValueAsListInit("constraints");
666   std::vector<AppliedConstraint> ret;
667   ret.reserve(listInit->size());
668 
669   for (auto *it : *listInit) {
670     auto *dagInit = dyn_cast<DagInit>(it);
671     if (!dagInit)
672       PrintFatalError(&def, "all elements in Pattern multi-entity "
673                             "constraints should be DAG nodes");
674 
675     std::vector<std::string> entities;
676     entities.reserve(dagInit->arg_size());
677     for (auto *argName : dagInit->getArgNames()) {
678       if (!argName) {
679         PrintFatalError(
680             &def,
681             "operands to additional constraints can only be symbol references");
682       }
683       entities.emplace_back(argName->getValue());
684     }
685 
686     ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
687                      dagInit->getNameStr(), std::move(entities));
688   }
689   return ret;
690 }
691 
692 int Pattern::getNumSupplementalPatterns() const {
693   auto *results = def.getValueAsListInit("supplementalPatterns");
694   return results->size();
695 }
696 
697 DagNode Pattern::getSupplementalPattern(unsigned index) const {
698   auto *results = def.getValueAsListInit("supplementalPatterns");
699   return DagNode(cast<DagInit>(results->getElement(index)));
700 }
701 
702 int Pattern::getBenefit() const {
703   // The initial benefit value is a heuristic with number of ops in the source
704   // pattern.
705   int initBenefit = getSourcePattern().getNumOps();
706   const DagInit *delta = def.getValueAsDag("benefitDelta");
707   if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
708     PrintFatalError(&def,
709                     "The 'addBenefit' takes and only takes one integer value");
710   }
711   return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
712 }
713 
714 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
715   std::vector<std::pair<StringRef, unsigned>> result;
716   result.reserve(def.getLoc().size());
717   for (auto loc : def.getLoc()) {
718     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
719     assert(buf && "invalid source location");
720     result.emplace_back(
721         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
722         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
723   }
724   return result;
725 }
726 
727 void Pattern::verifyBind(bool result, StringRef symbolName) {
728   if (!result) {
729     auto err = formatv("symbol '{0}' bound more than once", symbolName);
730     PrintFatalError(&def, err);
731   }
732 }
733 
734 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
735                                   bool isSrcPattern) {
736   auto treeName = tree.getSymbol();
737   auto numTreeArgs = tree.getNumArgs();
738 
739   if (tree.isNativeCodeCall()) {
740     if (!treeName.empty()) {
741       if (!isSrcPattern) {
742         LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
743                           << treeName << '\n');
744         verifyBind(
745             infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
746             treeName);
747       } else {
748         PrintFatalError(&def,
749                         formatv("binding symbol '{0}' to NativecodeCall in "
750                                 "MatchPattern is not supported",
751                                 treeName));
752       }
753     }
754 
755     for (int i = 0; i != numTreeArgs; ++i) {
756       if (auto treeArg = tree.getArgAsNestedDag(i)) {
757         // This DAG node argument is a DAG node itself. Go inside recursively.
758         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
759         continue;
760       }
761 
762       if (!isSrcPattern)
763         continue;
764 
765       // We can only bind symbols to arguments in source pattern. Those
766       // symbols are referenced in result patterns.
767       auto treeArgName = tree.getArgName(i);
768 
769       // `$_` is a special symbol meaning ignore the current argument.
770       if (!treeArgName.empty() && treeArgName != "_") {
771         DagLeaf leaf = tree.getArgAsLeaf(i);
772 
773         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
774         if (leaf.isUnspecified()) {
775           // This is case of $c, a Value without any constraints.
776           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
777         } else {
778           auto constraint = leaf.getAsConstraint();
779           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
780                         leaf.isConstantAttr() ||
781                         constraint.getKind() == Constraint::Kind::CK_Attr;
782 
783           if (isAttr) {
784             // This is case of $a, a binding to a certain attribute.
785             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
786             continue;
787           }
788 
789           // This is case of $b, a binding to a certain type.
790           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
791         }
792       }
793     }
794 
795     return;
796   }
797 
798   if (tree.isOperation()) {
799     auto &op = getDialectOp(tree);
800     auto numOpArgs = op.getNumArgs();
801     int numEither = 0;
802 
803     // We need to exclude the trailing directives and `either` directive groups
804     // two operands of the operation.
805     int numDirectives = 0;
806     for (int i = numTreeArgs - 1; i >= 0; --i) {
807       if (auto dagArg = tree.getArgAsNestedDag(i)) {
808         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
809           ++numDirectives;
810         else if (dagArg.isEither())
811           ++numEither;
812       }
813     }
814 
815     if (numOpArgs != numTreeArgs - numDirectives + numEither) {
816       auto err =
817           formatv("op '{0}' argument number mismatch: "
818                   "{1} in pattern vs. {2} in definition",
819                   op.getOperationName(), numTreeArgs + numEither, numOpArgs);
820       PrintFatalError(&def, err);
821     }
822 
823     // The name attached to the DAG node's operator is for representing the
824     // results generated from this op. It should be remembered as bound results.
825     if (!treeName.empty()) {
826       LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
827                         << '\n');
828       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
829     }
830 
831     // The operand in `either` DAG should be bound to the operation in the
832     // parent DagNode.
833     auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
834                                      int opArgIdx) {
835       for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
836         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
837           collectBoundSymbols(subTree, infoMap, isSrcPattern);
838         } else {
839           auto argName = tree.getArgName(i);
840           if (!argName.empty() && argName != "_") {
841             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
842                        argName);
843           }
844         }
845       }
846     };
847 
848     // The operand in `variadic` DAG should be bound to the operation in the
849     // parent DagNode. The range index must be included as well to distinguish
850     // (potentially) repeating argName within the `variadic` DAG.
851     auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
852                                        int opArgIdx) {
853       auto treeName = tree.getSymbol();
854       if (!treeName.empty()) {
855         // If treeName is specified, bind to the full variadic operand_range.
856         verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
857                                           std::nullopt),
858                    treeName);
859       }
860 
861       for (int i = 0; i < tree.getNumArgs(); ++i) {
862         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
863           collectBoundSymbols(subTree, infoMap, isSrcPattern);
864         } else {
865           auto argName = tree.getArgName(i);
866           if (!argName.empty() && argName != "_") {
867             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
868                                               /*variadicSubIndex=*/i),
869                        argName);
870           }
871         }
872       }
873     };
874 
875     for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
876       if (auto treeArg = tree.getArgAsNestedDag(i)) {
877         if (treeArg.isEither()) {
878           collectSymbolInEither(tree, treeArg, opArgIdx);
879           // `either` DAG is *flattened*. For example,
880           //
881           //  (FooOp (either arg0, arg1), arg2)
882           //
883           //  can be viewed as:
884           //
885           //  (FooOp arg0, arg1, arg2)
886           ++opArgIdx;
887         } else if (treeArg.isVariadic()) {
888           collectSymbolInVariadic(tree, treeArg, opArgIdx);
889         } else {
890           // This DAG node argument is a DAG node itself. Go inside recursively.
891           collectBoundSymbols(treeArg, infoMap, isSrcPattern);
892         }
893         continue;
894       }
895 
896       if (isSrcPattern) {
897         // We can only bind symbols to op arguments in source pattern. Those
898         // symbols are referenced in result patterns.
899         auto treeArgName = tree.getArgName(i);
900         // `$_` is a special symbol meaning ignore the current argument.
901         if (!treeArgName.empty() && treeArgName != "_") {
902           LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
903                             << treeArgName << '\n');
904           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
905                      treeArgName);
906         }
907       }
908     }
909     return;
910   }
911 
912   if (!treeName.empty()) {
913     PrintFatalError(
914         &def, formatv("binding symbol '{0}' to non-operation/native code call "
915                       "unsupported right now",
916                       treeName));
917   }
918 }
919