xref: /llvm-project/mlir/lib/TableGen/Pattern.cpp (revision 1f97219a9099b56c18d2e339db8b4aa7964825fd)
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::formatv;
30 
31 //===----------------------------------------------------------------------===//
32 // DagLeaf
33 //===----------------------------------------------------------------------===//
34 
35 bool DagLeaf::isUnspecified() const {
36   return isa_and_nonnull<llvm::UnsetInit>(def);
37 }
38 
39 bool DagLeaf::isOperandMatcher() const {
40   // Operand matchers specify a type constraint.
41   return isSubClassOf("TypeConstraint");
42 }
43 
44 bool DagLeaf::isAttrMatcher() const {
45   // Attribute matchers specify an attribute constraint.
46   return isSubClassOf("AttrConstraint");
47 }
48 
49 bool DagLeaf::isNativeCodeCall() const {
50   return isSubClassOf("NativeCodeCall");
51 }
52 
53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
54 
55 bool DagLeaf::isEnumAttrCase() const {
56   return isSubClassOf("EnumAttrCaseInfo");
57 }
58 
59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
60 
61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
86 int DagLeaf::getNumReturnsOfNativeCode() const {
87   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88   return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
89 }
90 
91 std::string DagLeaf::getStringAttr() const {
92   assert(isStringAttr() && "the DAG leaf must be string attribute");
93   return def->getAsUnquotedString();
94 }
95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97     return defInit->getDef()->isSubClassOf(superclass);
98   return false;
99 }
100 
101 void DagLeaf::print(raw_ostream &os) const {
102   if (def)
103     def->print(os);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
109 
110 bool DagNode::isNativeCodeCall() const {
111   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112     return defInit->getDef()->isSubClassOf("NativeCodeCall");
113   return false;
114 }
115 
116 bool DagNode::isOperation() const {
117   return !isNativeCodeCall() && !isReplaceWithValue() &&
118          !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
119          !isVariadic();
120 }
121 
122 llvm::StringRef DagNode::getNativeCodeTemplate() const {
123   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
124   return cast<llvm::DefInit>(node->getOperator())
125       ->getDef()
126       ->getValueAsString("expression");
127 }
128 
129 int DagNode::getNumReturnsOfNativeCode() const {
130   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
131   return cast<llvm::DefInit>(node->getOperator())
132       ->getDef()
133       ->getValueAsInt("numReturns");
134 }
135 
136 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
137 
138 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
139   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
140   auto it = mapper->find(opDef);
141   if (it != mapper->end())
142     return *it->second;
143   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
144               .first->second;
145 }
146 
147 int DagNode::getNumOps() const {
148   // We want to get number of operations recursively involved in the DAG tree.
149   // All other directives should be excluded.
150   int count = isOperation() ? 1 : 0;
151   for (int i = 0, e = getNumArgs(); i != e; ++i) {
152     if (auto child = getArgAsNestedDag(i))
153       count += child.getNumOps();
154   }
155   return count;
156 }
157 
158 int DagNode::getNumArgs() const { return node->getNumArgs(); }
159 
160 bool DagNode::isNestedDagArg(unsigned index) const {
161   return isa<llvm::DagInit>(node->getArg(index));
162 }
163 
164 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
165   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
166 }
167 
168 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
169   assert(!isNestedDagArg(index));
170   return DagLeaf(node->getArg(index));
171 }
172 
173 StringRef DagNode::getArgName(unsigned index) const {
174   return node->getArgNameStr(index);
175 }
176 
177 bool DagNode::isReplaceWithValue() const {
178   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179   return dagOpDef->getName() == "replaceWithValue";
180 }
181 
182 bool DagNode::isLocationDirective() const {
183   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184   return dagOpDef->getName() == "location";
185 }
186 
187 bool DagNode::isReturnTypeDirective() const {
188   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
189   return dagOpDef->getName() == "returnType";
190 }
191 
192 bool DagNode::isEither() const {
193   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
194   return dagOpDef->getName() == "either";
195 }
196 
197 bool DagNode::isVariadic() const {
198   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
199   return dagOpDef->getName() == "variadic";
200 }
201 
202 void DagNode::print(raw_ostream &os) const {
203   if (node)
204     node->print(os);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // SymbolInfoMap
209 //===----------------------------------------------------------------------===//
210 
211 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
212   int idx = -1;
213   auto [name, indexStr] = symbol.rsplit("__");
214 
215   if (indexStr.consumeInteger(10, idx)) {
216     // The second part is not an index; we return the whole symbol as-is.
217     return symbol;
218   }
219   if (index) {
220     *index = idx;
221   }
222   return name;
223 }
224 
225 SymbolInfoMap::SymbolInfo::SymbolInfo(
226     const Operator *op, SymbolInfo::Kind kind,
227     std::optional<DagAndConstant> dagAndConstant)
228     : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
229 
230 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
231   switch (kind) {
232   case Kind::Attr:
233   case Kind::Operand:
234   case Kind::Value:
235     return 1;
236   case Kind::Result:
237     return op->getNumResults();
238   case Kind::MultipleValues:
239     return getSize();
240   }
241   llvm_unreachable("unknown kind");
242 }
243 
244 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
245   return alternativeName ? *alternativeName : name.str();
246 }
247 
248 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
249   LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
250   switch (kind) {
251   case Kind::Attr: {
252     if (op)
253       return op->getArg(getArgIndex())
254           .get<NamedAttribute *>()
255           ->attr.getStorageType()
256           .str();
257     // TODO(suderman): Use a more exact type when available.
258     return "::mlir::Attribute";
259   }
260   case Kind::Operand: {
261     // Use operand range for captured operands (to support potential variadic
262     // operands).
263     return "::mlir::Operation::operand_range";
264   }
265   case Kind::Value: {
266     return "::mlir::Value";
267   }
268   case Kind::MultipleValues: {
269     return "::mlir::ValueRange";
270   }
271   case Kind::Result: {
272     // Use the op itself for captured results.
273     return op->getQualCppClassName();
274   }
275   }
276   llvm_unreachable("unknown kind");
277 }
278 
279 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
280   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
281   std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
282   return std::string(
283       formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
284 }
285 
286 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
287   LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
288   return std::string(
289       formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
290 }
291 
292 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293     StringRef name, int index, const char *fmt, const char *separator) const {
294   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
295   switch (kind) {
296   case Kind::Attr: {
297     assert(index < 0);
298     auto repl = formatv(fmt, name);
299     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
300     return std::string(repl);
301   }
302   case Kind::Operand: {
303     assert(index < 0);
304     auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
305     // If this operand is variadic and this SymbolInfo doesn't have a range
306     // index, then return the full variadic operand_range. Otherwise, return
307     // the value itself.
308     if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
309       auto repl = formatv(fmt, name);
310       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
311       return std::string(repl);
312     }
313     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
314     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
315     return std::string(repl);
316   }
317   case Kind::Result: {
318     // If `index` is greater than zero, then we are referencing a specific
319     // result of a multi-result op. The result can still be variadic.
320     if (index >= 0) {
321       std::string v =
322           std::string(formatv("{0}.getODSResults({1})", name, index));
323       if (!op->getResult(index).isVariadic())
324         v = std::string(formatv("(*{0}.begin())", v));
325       auto repl = formatv(fmt, v);
326       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
327       return std::string(repl);
328     }
329 
330     // If this op has no result at all but still we bind a symbol to it, it
331     // means we want to capture the op itself.
332     if (op->getNumResults() == 0) {
333       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
334       return formatv(fmt, name);
335     }
336 
337     // We are referencing all results of the multi-result op. A specific result
338     // can either be a value or a range. Then join them with `separator`.
339     SmallVector<std::string, 4> values;
340     values.reserve(op->getNumResults());
341 
342     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
343       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
344       if (!op->getResult(i).isVariadic()) {
345         v = std::string(formatv("(*{0}.begin())", v));
346       }
347       values.push_back(std::string(formatv(fmt, v)));
348     }
349     auto repl = llvm::join(values, separator);
350     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
351     return repl;
352   }
353   case Kind::Value: {
354     assert(index < 0);
355     assert(op == nullptr);
356     auto repl = formatv(fmt, name);
357     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
358     return std::string(repl);
359   }
360   case Kind::MultipleValues: {
361     assert(op == nullptr);
362     assert(index < getSize());
363     if (index >= 0) {
364       std::string repl =
365           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
366       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
367       return repl;
368     }
369     // If it doesn't specify certain element, unpack them all.
370     auto repl =
371         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
372     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
373     return std::string(repl);
374   }
375   }
376   llvm_unreachable("unknown kind");
377 }
378 
379 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
380     StringRef name, int index, const char *fmt, const char *separator) const {
381   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
382   switch (kind) {
383   case Kind::Attr:
384   case Kind::Operand: {
385     assert(index < 0 && "only allowed for symbol bound to result");
386     auto repl = formatv(fmt, name);
387     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
388     return std::string(repl);
389   }
390   case Kind::Result: {
391     if (index >= 0) {
392       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
393       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
394       return std::string(repl);
395     }
396 
397     // We are referencing all results of the multi-result op. Each result should
398     // have a value range, and then join them with `separator`.
399     SmallVector<std::string, 4> values;
400     values.reserve(op->getNumResults());
401 
402     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
403       values.push_back(std::string(
404           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
405     }
406     auto repl = llvm::join(values, separator);
407     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
408     return repl;
409   }
410   case Kind::Value: {
411     assert(index < 0 && "only allowed for symbol bound to result");
412     assert(op == nullptr);
413     auto repl = formatv(fmt, formatv("{{{0}}", name));
414     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
415     return std::string(repl);
416   }
417   case Kind::MultipleValues: {
418     assert(op == nullptr);
419     assert(index < getSize());
420     if (index >= 0) {
421       std::string repl =
422           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
423       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
424       return repl;
425     }
426     auto repl =
427         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
428     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
429     return std::string(repl);
430   }
431   }
432   llvm_unreachable("unknown kind");
433 }
434 
435 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
436                                    const Operator &op, int argIndex,
437                                    std::optional<int> variadicSubIndex) {
438   StringRef name = getValuePackName(symbol);
439   if (name != symbol) {
440     auto error = formatv(
441         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
442     PrintFatalError(loc, error);
443   }
444 
445   auto symInfo =
446       op.getArg(argIndex).is<NamedAttribute *>()
447           ? SymbolInfo::getAttr(&op, argIndex)
448           : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
449 
450   std::string key = symbol.str();
451   if (symbolInfoMap.count(key)) {
452     // Only non unique name for the operand is supported.
453     if (symInfo.kind != SymbolInfo::Kind::Operand) {
454       return false;
455     }
456 
457     // Cannot add new operand if there is already non operand with the same
458     // name.
459     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
460       return false;
461     }
462   }
463 
464   symbolInfoMap.emplace(key, symInfo);
465   return true;
466 }
467 
468 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
469   std::string name = getValuePackName(symbol).str();
470   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
471 
472   return symbolInfoMap.count(inserted->first) == 1;
473 }
474 
475 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
476   std::string name = getValuePackName(symbol).str();
477   if (numValues > 1)
478     return bindMultipleValues(name, numValues);
479   return bindValue(name);
480 }
481 
482 bool SymbolInfoMap::bindValue(StringRef symbol) {
483   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
484   return symbolInfoMap.count(inserted->first) == 1;
485 }
486 
487 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
488   std::string name = getValuePackName(symbol).str();
489   auto inserted =
490       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
491   return symbolInfoMap.count(inserted->first) == 1;
492 }
493 
494 bool SymbolInfoMap::bindAttr(StringRef symbol) {
495   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
496   return symbolInfoMap.count(inserted->first) == 1;
497 }
498 
499 bool SymbolInfoMap::contains(StringRef symbol) const {
500   return find(symbol) != symbolInfoMap.end();
501 }
502 
503 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
504   std::string name = getValuePackName(key).str();
505 
506   return symbolInfoMap.find(name);
507 }
508 
509 SymbolInfoMap::const_iterator
510 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
511                                int argIndex,
512                                std::optional<int> variadicSubIndex) const {
513   return findBoundSymbol(
514       key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
515 }
516 
517 SymbolInfoMap::const_iterator
518 SymbolInfoMap::findBoundSymbol(StringRef key,
519                                const SymbolInfo &symbolInfo) const {
520   std::string name = getValuePackName(key).str();
521   auto range = symbolInfoMap.equal_range(name);
522 
523   for (auto it = range.first; it != range.second; ++it)
524     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
525       return it;
526 
527   return symbolInfoMap.end();
528 }
529 
530 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
531 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
532   std::string name = getValuePackName(key).str();
533 
534   return symbolInfoMap.equal_range(name);
535 }
536 
537 int SymbolInfoMap::count(StringRef key) const {
538   std::string name = getValuePackName(key).str();
539   return symbolInfoMap.count(name);
540 }
541 
542 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
543   StringRef name = getValuePackName(symbol);
544   if (name != symbol) {
545     // If there is a trailing index inside symbol, it references just one
546     // static value.
547     return 1;
548   }
549   // Otherwise, find how many it represents by querying the symbol's info.
550   return find(name)->second.getStaticValueCount();
551 }
552 
553 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
554                                                const char *fmt,
555                                                const char *separator) const {
556   int index = -1;
557   StringRef name = getValuePackName(symbol, &index);
558 
559   auto it = symbolInfoMap.find(name.str());
560   if (it == symbolInfoMap.end()) {
561     auto error = formatv("referencing unbound symbol '{0}'", symbol);
562     PrintFatalError(loc, error);
563   }
564 
565   return it->second.getValueAndRangeUse(name, index, fmt, separator);
566 }
567 
568 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
569                                           const char *separator) const {
570   int index = -1;
571   StringRef name = getValuePackName(symbol, &index);
572 
573   auto it = symbolInfoMap.find(name.str());
574   if (it == symbolInfoMap.end()) {
575     auto error = formatv("referencing unbound symbol '{0}'", symbol);
576     PrintFatalError(loc, error);
577   }
578 
579   return it->second.getAllRangeUse(name, index, fmt, separator);
580 }
581 
582 void SymbolInfoMap::assignUniqueAlternativeNames() {
583   llvm::StringSet<> usedNames;
584 
585   for (auto symbolInfoIt = symbolInfoMap.begin();
586        symbolInfoIt != symbolInfoMap.end();) {
587     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
588     auto startRange = range.first;
589     auto endRange = range.second;
590 
591     auto operandName = symbolInfoIt->first;
592     int startSearchIndex = 0;
593     for (++startRange; startRange != endRange; ++startRange) {
594       // Current operand name is not unique, find a unique one
595       // and set the alternative name.
596       for (int i = startSearchIndex;; ++i) {
597         std::string alternativeName = operandName + std::to_string(i);
598         if (!usedNames.contains(alternativeName) &&
599             symbolInfoMap.count(alternativeName) == 0) {
600           usedNames.insert(alternativeName);
601           startRange->second.alternativeName = alternativeName;
602           startSearchIndex = i + 1;
603 
604           break;
605         }
606       }
607     }
608 
609     symbolInfoIt = endRange;
610   }
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Pattern
//===----------------------------------------------------------------------===//
616 
617 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
618     : def(*def), recordOpMap(mapper) {}
619 
620 DagNode Pattern::getSourcePattern() const {
621   return DagNode(def.getValueAsDag("sourcePattern"));
622 }
623 
624 int Pattern::getNumResultPatterns() const {
625   auto *results = def.getValueAsListInit("resultPatterns");
626   return results->size();
627 }
628 
629 DagNode Pattern::getResultPattern(unsigned index) const {
630   auto *results = def.getValueAsListInit("resultPatterns");
631   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
632 }
633 
634 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
635   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
636   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
637   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
638 
639   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
640   infoMap.assignUniqueAlternativeNames();
641   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
642 }
643 
644 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
645   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
646   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
647     auto pattern = getResultPattern(i);
648     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
649   }
650   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
651 }
652 
653 const Operator &Pattern::getSourceRootOp() {
654   return getSourcePattern().getDialectOp(recordOpMap);
655 }
656 
657 Operator &Pattern::getDialectOp(DagNode node) {
658   return node.getDialectOp(recordOpMap);
659 }
660 
661 std::vector<AppliedConstraint> Pattern::getConstraints() const {
662   auto *listInit = def.getValueAsListInit("constraints");
663   std::vector<AppliedConstraint> ret;
664   ret.reserve(listInit->size());
665 
666   for (auto *it : *listInit) {
667     auto *dagInit = dyn_cast<llvm::DagInit>(it);
668     if (!dagInit)
669       PrintFatalError(&def, "all elements in Pattern multi-entity "
670                             "constraints should be DAG nodes");
671 
672     std::vector<std::string> entities;
673     entities.reserve(dagInit->arg_size());
674     for (auto *argName : dagInit->getArgNames()) {
675       if (!argName) {
676         PrintFatalError(
677             &def,
678             "operands to additional constraints can only be symbol references");
679       }
680       entities.emplace_back(argName->getValue());
681     }
682 
683     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
684                      dagInit->getNameStr(), std::move(entities));
685   }
686   return ret;
687 }
688 
689 int Pattern::getNumSupplementalPatterns() const {
690   auto *results = def.getValueAsListInit("supplementalPatterns");
691   return results->size();
692 }
693 
694 DagNode Pattern::getSupplementalPattern(unsigned index) const {
695   auto *results = def.getValueAsListInit("supplementalPatterns");
696   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
697 }
698 
699 int Pattern::getBenefit() const {
700   // The initial benefit value is a heuristic with number of ops in the source
701   // pattern.
702   int initBenefit = getSourcePattern().getNumOps();
703   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
704   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705     PrintFatalError(&def,
706                     "The 'addBenefit' takes and only takes one integer value");
707   }
708   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
709 }
710 
711 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
712   std::vector<std::pair<StringRef, unsigned>> result;
713   result.reserve(def.getLoc().size());
714   for (auto loc : def.getLoc()) {
715     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
716     assert(buf && "invalid source location");
717     result.emplace_back(
718         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
719         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
720   }
721   return result;
722 }
723 
724 void Pattern::verifyBind(bool result, StringRef symbolName) {
725   if (!result) {
726     auto err = formatv("symbol '{0}' bound more than once", symbolName);
727     PrintFatalError(&def, err);
728   }
729 }
730 
731 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
732                                   bool isSrcPattern) {
733   auto treeName = tree.getSymbol();
734   auto numTreeArgs = tree.getNumArgs();
735 
736   if (tree.isNativeCodeCall()) {
737     if (!treeName.empty()) {
738       if (!isSrcPattern) {
739         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
740                                 << treeName << '\n');
741         verifyBind(
742             infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
743             treeName);
744       } else {
745         PrintFatalError(&def,
746                         formatv("binding symbol '{0}' to NativecodeCall in "
747                                 "MatchPattern is not supported",
748                                 treeName));
749       }
750     }
751 
752     for (int i = 0; i != numTreeArgs; ++i) {
753       if (auto treeArg = tree.getArgAsNestedDag(i)) {
754         // This DAG node argument is a DAG node itself. Go inside recursively.
755         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
756         continue;
757       }
758 
759       if (!isSrcPattern)
760         continue;
761 
762       // We can only bind symbols to arguments in source pattern. Those
763       // symbols are referenced in result patterns.
764       auto treeArgName = tree.getArgName(i);
765 
766       // `$_` is a special symbol meaning ignore the current argument.
767       if (!treeArgName.empty() && treeArgName != "_") {
768         DagLeaf leaf = tree.getArgAsLeaf(i);
769 
770         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
771         if (leaf.isUnspecified()) {
772           // This is case of $c, a Value without any constraints.
773           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
774         } else {
775           auto constraint = leaf.getAsConstraint();
776           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
777                         leaf.isConstantAttr() ||
778                         constraint.getKind() == Constraint::Kind::CK_Attr;
779 
780           if (isAttr) {
781             // This is case of $a, a binding to a certain attribute.
782             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
783             continue;
784           }
785 
786           // This is case of $b, a binding to a certain type.
787           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
788         }
789       }
790     }
791 
792     return;
793   }
794 
795   if (tree.isOperation()) {
796     auto &op = getDialectOp(tree);
797     auto numOpArgs = op.getNumArgs();
798     int numEither = 0;
799 
800     // We need to exclude the trailing directives and `either` directive groups
801     // two operands of the operation.
802     int numDirectives = 0;
803     for (int i = numTreeArgs - 1; i >= 0; --i) {
804       if (auto dagArg = tree.getArgAsNestedDag(i)) {
805         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
806           ++numDirectives;
807         else if (dagArg.isEither())
808           ++numEither;
809       }
810     }
811 
812     if (numOpArgs != numTreeArgs - numDirectives + numEither) {
813       auto err =
814           formatv("op '{0}' argument number mismatch: "
815                   "{1} in pattern vs. {2} in definition",
816                   op.getOperationName(), numTreeArgs + numEither, numOpArgs);
817       PrintFatalError(&def, err);
818     }
819 
820     // The name attached to the DAG node's operator is for representing the
821     // results generated from this op. It should be remembered as bound results.
822     if (!treeName.empty()) {
823       LLVM_DEBUG(llvm::dbgs()
824                  << "found symbol bound to op result: " << treeName << '\n');
825       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
826     }
827 
828     // The operand in `either` DAG should be bound to the operation in the
829     // parent DagNode.
830     auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
831                                      int opArgIdx) {
832       for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
833         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
834           collectBoundSymbols(subTree, infoMap, isSrcPattern);
835         } else {
836           auto argName = tree.getArgName(i);
837           if (!argName.empty() && argName != "_") {
838             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
839                        argName);
840           }
841         }
842       }
843     };
844 
845     // The operand in `variadic` DAG should be bound to the operation in the
846     // parent DagNode. The range index must be included as well to distinguish
847     // (potentially) repeating argName within the `variadic` DAG.
848     auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
849                                        int opArgIdx) {
850       auto treeName = tree.getSymbol();
851       if (!treeName.empty()) {
852         // If treeName is specified, bind to the full variadic operand_range.
853         verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
854                                           std::nullopt),
855                    treeName);
856       }
857 
858       for (int i = 0; i < tree.getNumArgs(); ++i) {
859         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
860           collectBoundSymbols(subTree, infoMap, isSrcPattern);
861         } else {
862           auto argName = tree.getArgName(i);
863           if (!argName.empty() && argName != "_") {
864             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
865                                               /*variadicSubIndex=*/i),
866                        argName);
867           }
868         }
869       }
870     };
871 
872     for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
873       if (auto treeArg = tree.getArgAsNestedDag(i)) {
874         if (treeArg.isEither()) {
875           collectSymbolInEither(tree, treeArg, opArgIdx);
876           // `either` DAG is *flattened*. For example,
877           //
878           //  (FooOp (either arg0, arg1), arg2)
879           //
880           //  can be viewed as:
881           //
882           //  (FooOp arg0, arg1, arg2)
883           ++opArgIdx;
884         } else if (treeArg.isVariadic()) {
885           collectSymbolInVariadic(tree, treeArg, opArgIdx);
886         } else {
887           // This DAG node argument is a DAG node itself. Go inside recursively.
888           collectBoundSymbols(treeArg, infoMap, isSrcPattern);
889         }
890         continue;
891       }
892 
893       if (isSrcPattern) {
894         // We can only bind symbols to op arguments in source pattern. Those
895         // symbols are referenced in result patterns.
896         auto treeArgName = tree.getArgName(i);
897         // `$_` is a special symbol meaning ignore the current argument.
898         if (!treeArgName.empty() && treeArgName != "_") {
899           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
900                                   << treeArgName << '\n');
901           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
902                      treeArgName);
903         }
904       }
905     }
906     return;
907   }
908 
909   if (!treeName.empty()) {
910     PrintFatalError(
911         &def, formatv("binding symbol '{0}' to non-operation/native code call "
912                       "unsupported right now",
913                       treeName));
914   }
915 }
916