xref: /llvm-project/mlir/lib/TableGen/Pattern.cpp (revision 064a08cd955019da9130f1109bfa534e79b8ec7c)
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::formatv;
30 
31 //===----------------------------------------------------------------------===//
32 // DagLeaf
33 //===----------------------------------------------------------------------===//
34 
35 bool DagLeaf::isUnspecified() const {
36   return dyn_cast_or_null<llvm::UnsetInit>(def);
37 }
38 
39 bool DagLeaf::isOperandMatcher() const {
40   // Operand matchers specify a type constraint.
41   return isSubClassOf("TypeConstraint");
42 }
43 
44 bool DagLeaf::isAttrMatcher() const {
45   // Attribute matchers specify an attribute constraint.
46   return isSubClassOf("AttrConstraint");
47 }
48 
49 bool DagLeaf::isNativeCodeCall() const {
50   return isSubClassOf("NativeCodeCall");
51 }
52 
53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
54 
55 bool DagLeaf::isEnumAttrCase() const {
56   return isSubClassOf("EnumAttrCaseInfo");
57 }
58 
59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
60 
61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
86 int DagLeaf::getNumReturnsOfNativeCode() const {
87   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88   return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
89 }
90 
91 std::string DagLeaf::getStringAttr() const {
92   assert(isStringAttr() && "the DAG leaf must be string attribute");
93   return def->getAsUnquotedString();
94 }
95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97     return defInit->getDef()->isSubClassOf(superclass);
98   return false;
99 }
100 
101 void DagLeaf::print(raw_ostream &os) const {
102   if (def)
103     def->print(os);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
109 
110 bool DagNode::isNativeCodeCall() const {
111   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112     return defInit->getDef()->isSubClassOf("NativeCodeCall");
113   return false;
114 }
115 
116 bool DagNode::isOperation() const {
117   return !isNativeCodeCall() && !isReplaceWithValue() &&
118          !isLocationDirective() && !isReturnTypeDirective() && !isEither();
119 }
120 
121 llvm::StringRef DagNode::getNativeCodeTemplate() const {
122   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
123   return cast<llvm::DefInit>(node->getOperator())
124       ->getDef()
125       ->getValueAsString("expression");
126 }
127 
128 int DagNode::getNumReturnsOfNativeCode() const {
129   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
130   return cast<llvm::DefInit>(node->getOperator())
131       ->getDef()
132       ->getValueAsInt("numReturns");
133 }
134 
135 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
136 
137 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
138   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
139   auto it = mapper->find(opDef);
140   if (it != mapper->end())
141     return *it->second;
142   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
143               .first->second;
144 }
145 
146 int DagNode::getNumOps() const {
147   // We want to get number of operations recursively involved in the DAG tree.
148   // All other directives should be excluded.
149   int count = isOperation() ? 1 : 0;
150   for (int i = 0, e = getNumArgs(); i != e; ++i) {
151     if (auto child = getArgAsNestedDag(i))
152       count += child.getNumOps();
153   }
154   return count;
155 }
156 
157 int DagNode::getNumArgs() const { return node->getNumArgs(); }
158 
159 bool DagNode::isNestedDagArg(unsigned index) const {
160   return isa<llvm::DagInit>(node->getArg(index));
161 }
162 
163 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
164   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
165 }
166 
167 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
168   assert(!isNestedDagArg(index));
169   return DagLeaf(node->getArg(index));
170 }
171 
172 StringRef DagNode::getArgName(unsigned index) const {
173   return node->getArgNameStr(index);
174 }
175 
176 bool DagNode::isReplaceWithValue() const {
177   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
178   return dagOpDef->getName() == "replaceWithValue";
179 }
180 
181 bool DagNode::isLocationDirective() const {
182   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
183   return dagOpDef->getName() == "location";
184 }
185 
186 bool DagNode::isReturnTypeDirective() const {
187   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
188   return dagOpDef->getName() == "returnType";
189 }
190 
191 bool DagNode::isEither() const {
192   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
193   return dagOpDef->getName() == "either";
194 }
195 
196 void DagNode::print(raw_ostream &os) const {
197   if (node)
198     node->print(os);
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // SymbolInfoMap
203 //===----------------------------------------------------------------------===//
204 
205 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
206   StringRef name, indexStr;
207   int idx = -1;
208   std::tie(name, indexStr) = symbol.rsplit("__");
209 
210   if (indexStr.consumeInteger(10, idx)) {
211     // The second part is not an index; we return the whole symbol as-is.
212     return symbol;
213   }
214   if (index) {
215     *index = idx;
216   }
217   return name;
218 }
219 
220 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
221                                       Optional<DagAndConstant> dagAndConstant)
222     : op(op), kind(kind), dagAndConstant(std::move(dagAndConstant)) {}
223 
224 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
225   switch (kind) {
226   case Kind::Attr:
227   case Kind::Operand:
228   case Kind::Value:
229     return 1;
230   case Kind::Result:
231     return op->getNumResults();
232   case Kind::MultipleValues:
233     return getSize();
234   }
235   llvm_unreachable("unknown kind");
236 }
237 
238 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
239   return alternativeName ? *alternativeName : name.str();
240 }
241 
242 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
243   LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
244   switch (kind) {
245   case Kind::Attr: {
246     if (op)
247       return op->getArg(getArgIndex())
248           .get<NamedAttribute *>()
249           ->attr.getStorageType()
250           .str();
251     // TODO(suderman): Use a more exact type when available.
252     return "::mlir::Attribute";
253   }
254   case Kind::Operand: {
255     // Use operand range for captured operands (to support potential variadic
256     // operands).
257     return "::mlir::Operation::operand_range";
258   }
259   case Kind::Value: {
260     return "::mlir::Value";
261   }
262   case Kind::MultipleValues: {
263     return "::mlir::ValueRange";
264   }
265   case Kind::Result: {
266     // Use the op itself for captured results.
267     return op->getQualCppClassName();
268   }
269   }
270   llvm_unreachable("unknown kind");
271 }
272 
273 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
274   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
275   std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
276   return std::string(
277       formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
278 }
279 
280 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
281   LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
282   return std::string(
283       formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
284 }
285 
286 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
287     StringRef name, int index, const char *fmt, const char *separator) const {
288   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
289   switch (kind) {
290   case Kind::Attr: {
291     assert(index < 0);
292     auto repl = formatv(fmt, name);
293     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
294     return std::string(repl);
295   }
296   case Kind::Operand: {
297     assert(index < 0);
298     auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
299     // If this operand is variadic, then return a range. Otherwise, return the
300     // value itself.
301     if (operand->isVariableLength()) {
302       auto repl = formatv(fmt, name);
303       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
304       return std::string(repl);
305     }
306     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
307     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
308     return std::string(repl);
309   }
310   case Kind::Result: {
311     // If `index` is greater than zero, then we are referencing a specific
312     // result of a multi-result op. The result can still be variadic.
313     if (index >= 0) {
314       std::string v =
315           std::string(formatv("{0}.getODSResults({1})", name, index));
316       if (!op->getResult(index).isVariadic())
317         v = std::string(formatv("(*{0}.begin())", v));
318       auto repl = formatv(fmt, v);
319       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
320       return std::string(repl);
321     }
322 
323     // If this op has no result at all but still we bind a symbol to it, it
324     // means we want to capture the op itself.
325     if (op->getNumResults() == 0) {
326       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
327       return formatv(fmt, name);
328     }
329 
330     // We are referencing all results of the multi-result op. A specific result
331     // can either be a value or a range. Then join them with `separator`.
332     SmallVector<std::string, 4> values;
333     values.reserve(op->getNumResults());
334 
335     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
336       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
337       if (!op->getResult(i).isVariadic()) {
338         v = std::string(formatv("(*{0}.begin())", v));
339       }
340       values.push_back(std::string(formatv(fmt, v)));
341     }
342     auto repl = llvm::join(values, separator);
343     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
344     return repl;
345   }
346   case Kind::Value: {
347     assert(index < 0);
348     assert(op == nullptr);
349     auto repl = formatv(fmt, name);
350     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
351     return std::string(repl);
352   }
353   case Kind::MultipleValues: {
354     assert(op == nullptr);
355     assert(index < getSize());
356     if (index >= 0) {
357       std::string repl =
358           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
359       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
360       return repl;
361     }
362     // If it doesn't specify certain element, unpack them all.
363     auto repl =
364         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
365     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
366     return std::string(repl);
367   }
368   }
369   llvm_unreachable("unknown kind");
370 }
371 
372 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
373     StringRef name, int index, const char *fmt, const char *separator) const {
374   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
375   switch (kind) {
376   case Kind::Attr:
377   case Kind::Operand: {
378     assert(index < 0 && "only allowed for symbol bound to result");
379     auto repl = formatv(fmt, name);
380     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
381     return std::string(repl);
382   }
383   case Kind::Result: {
384     if (index >= 0) {
385       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
386       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
387       return std::string(repl);
388     }
389 
390     // We are referencing all results of the multi-result op. Each result should
391     // have a value range, and then join them with `separator`.
392     SmallVector<std::string, 4> values;
393     values.reserve(op->getNumResults());
394 
395     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
396       values.push_back(std::string(
397           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
398     }
399     auto repl = llvm::join(values, separator);
400     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
401     return repl;
402   }
403   case Kind::Value: {
404     assert(index < 0 && "only allowed for symbol bound to result");
405     assert(op == nullptr);
406     auto repl = formatv(fmt, formatv("{{{0}}", name));
407     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
408     return std::string(repl);
409   }
410   case Kind::MultipleValues: {
411     assert(op == nullptr);
412     assert(index < getSize());
413     if (index >= 0) {
414       std::string repl =
415           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
416       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
417       return repl;
418     }
419     auto repl =
420         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
421     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
422     return std::string(repl);
423   }
424   }
425   llvm_unreachable("unknown kind");
426 }
427 
428 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
429                                    const Operator &op, int argIndex) {
430   StringRef name = getValuePackName(symbol);
431   if (name != symbol) {
432     auto error = formatv(
433         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
434     PrintFatalError(loc, error);
435   }
436 
437   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
438                      ? SymbolInfo::getAttr(&op, argIndex)
439                      : SymbolInfo::getOperand(node, &op, argIndex);
440 
441   std::string key = symbol.str();
442   if (symbolInfoMap.count(key)) {
443     // Only non unique name for the operand is supported.
444     if (symInfo.kind != SymbolInfo::Kind::Operand) {
445       return false;
446     }
447 
448     // Cannot add new operand if there is already non operand with the same
449     // name.
450     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
451       return false;
452     }
453   }
454 
455   symbolInfoMap.emplace(key, symInfo);
456   return true;
457 }
458 
459 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
460   std::string name = getValuePackName(symbol).str();
461   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
462 
463   return symbolInfoMap.count(inserted->first) == 1;
464 }
465 
466 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
467   std::string name = getValuePackName(symbol).str();
468   if (numValues > 1)
469     return bindMultipleValues(name, numValues);
470   return bindValue(name);
471 }
472 
473 bool SymbolInfoMap::bindValue(StringRef symbol) {
474   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
475   return symbolInfoMap.count(inserted->first) == 1;
476 }
477 
478 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
479   std::string name = getValuePackName(symbol).str();
480   auto inserted =
481       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
482   return symbolInfoMap.count(inserted->first) == 1;
483 }
484 
485 bool SymbolInfoMap::bindAttr(StringRef symbol) {
486   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
487   return symbolInfoMap.count(inserted->first) == 1;
488 }
489 
490 bool SymbolInfoMap::contains(StringRef symbol) const {
491   return find(symbol) != symbolInfoMap.end();
492 }
493 
494 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
495   std::string name = getValuePackName(key).str();
496 
497   return symbolInfoMap.find(name);
498 }
499 
500 SymbolInfoMap::const_iterator
501 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
502                                int argIndex) const {
503   return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
504 }
505 
506 SymbolInfoMap::const_iterator
507 SymbolInfoMap::findBoundSymbol(StringRef key,
508                                const SymbolInfo &symbolInfo) const {
509   std::string name = getValuePackName(key).str();
510   auto range = symbolInfoMap.equal_range(name);
511 
512   for (auto it = range.first; it != range.second; ++it)
513     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
514       return it;
515 
516   return symbolInfoMap.end();
517 }
518 
519 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
520 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
521   std::string name = getValuePackName(key).str();
522 
523   return symbolInfoMap.equal_range(name);
524 }
525 
526 int SymbolInfoMap::count(StringRef key) const {
527   std::string name = getValuePackName(key).str();
528   return symbolInfoMap.count(name);
529 }
530 
531 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
532   StringRef name = getValuePackName(symbol);
533   if (name != symbol) {
534     // If there is a trailing index inside symbol, it references just one
535     // static value.
536     return 1;
537   }
538   // Otherwise, find how many it represents by querying the symbol's info.
539   return find(name)->second.getStaticValueCount();
540 }
541 
542 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
543                                                const char *fmt,
544                                                const char *separator) const {
545   int index = -1;
546   StringRef name = getValuePackName(symbol, &index);
547 
548   auto it = symbolInfoMap.find(name.str());
549   if (it == symbolInfoMap.end()) {
550     auto error = formatv("referencing unbound symbol '{0}'", symbol);
551     PrintFatalError(loc, error);
552   }
553 
554   return it->second.getValueAndRangeUse(name, index, fmt, separator);
555 }
556 
557 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
558                                           const char *separator) const {
559   int index = -1;
560   StringRef name = getValuePackName(symbol, &index);
561 
562   auto it = symbolInfoMap.find(name.str());
563   if (it == symbolInfoMap.end()) {
564     auto error = formatv("referencing unbound symbol '{0}'", symbol);
565     PrintFatalError(loc, error);
566   }
567 
568   return it->second.getAllRangeUse(name, index, fmt, separator);
569 }
570 
571 void SymbolInfoMap::assignUniqueAlternativeNames() {
572   llvm::StringSet<> usedNames;
573 
574   for (auto symbolInfoIt = symbolInfoMap.begin();
575        symbolInfoIt != symbolInfoMap.end();) {
576     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
577     auto startRange = range.first;
578     auto endRange = range.second;
579 
580     auto operandName = symbolInfoIt->first;
581     int startSearchIndex = 0;
582     for (++startRange; startRange != endRange; ++startRange) {
583       // Current operand name is not unique, find a unique one
584       // and set the alternative name.
585       for (int i = startSearchIndex;; ++i) {
586         std::string alternativeName = operandName + std::to_string(i);
587         if (!usedNames.contains(alternativeName) &&
588             symbolInfoMap.count(alternativeName) == 0) {
589           usedNames.insert(alternativeName);
590           startRange->second.alternativeName = alternativeName;
591           startSearchIndex = i + 1;
592 
593           break;
594         }
595       }
596     }
597 
598     symbolInfoIt = endRange;
599   }
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // Pattern
//===----------------------------------------------------------------------===//
605 
606 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
607     : def(*def), recordOpMap(mapper) {}
608 
609 DagNode Pattern::getSourcePattern() const {
610   return DagNode(def.getValueAsDag("sourcePattern"));
611 }
612 
613 int Pattern::getNumResultPatterns() const {
614   auto *results = def.getValueAsListInit("resultPatterns");
615   return results->size();
616 }
617 
618 DagNode Pattern::getResultPattern(unsigned index) const {
619   auto *results = def.getValueAsListInit("resultPatterns");
620   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
621 }
622 
623 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
624   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
625   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
626   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
627 
628   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
629   infoMap.assignUniqueAlternativeNames();
630   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
631 }
632 
633 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
634   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
635   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
636     auto pattern = getResultPattern(i);
637     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
638   }
639   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
640 }
641 
642 const Operator &Pattern::getSourceRootOp() {
643   return getSourcePattern().getDialectOp(recordOpMap);
644 }
645 
646 Operator &Pattern::getDialectOp(DagNode node) {
647   return node.getDialectOp(recordOpMap);
648 }
649 
650 std::vector<AppliedConstraint> Pattern::getConstraints() const {
651   auto *listInit = def.getValueAsListInit("constraints");
652   std::vector<AppliedConstraint> ret;
653   ret.reserve(listInit->size());
654 
655   for (auto *it : *listInit) {
656     auto *dagInit = dyn_cast<llvm::DagInit>(it);
657     if (!dagInit)
658       PrintFatalError(&def, "all elements in Pattern multi-entity "
659                             "constraints should be DAG nodes");
660 
661     std::vector<std::string> entities;
662     entities.reserve(dagInit->arg_size());
663     for (auto *argName : dagInit->getArgNames()) {
664       if (!argName) {
665         PrintFatalError(
666             &def,
667             "operands to additional constraints can only be symbol references");
668       }
669       entities.emplace_back(argName->getValue());
670     }
671 
672     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
673                      dagInit->getNameStr(), std::move(entities));
674   }
675   return ret;
676 }
677 
678 int Pattern::getBenefit() const {
679   // The initial benefit value is a heuristic with number of ops in the source
680   // pattern.
681   int initBenefit = getSourcePattern().getNumOps();
682   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
683   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
684     PrintFatalError(&def,
685                     "The 'addBenefit' takes and only takes one integer value");
686   }
687   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
688 }
689 
690 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
691   std::vector<std::pair<StringRef, unsigned>> result;
692   result.reserve(def.getLoc().size());
693   for (auto loc : def.getLoc()) {
694     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
695     assert(buf && "invalid source location");
696     result.emplace_back(
697         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
698         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
699   }
700   return result;
701 }
702 
703 void Pattern::verifyBind(bool result, StringRef symbolName) {
704   if (!result) {
705     auto err = formatv("symbol '{0}' bound more than once", symbolName);
706     PrintFatalError(&def, err);
707   }
708 }
709 
710 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
711                                   bool isSrcPattern) {
712   auto treeName = tree.getSymbol();
713   auto numTreeArgs = tree.getNumArgs();
714 
715   if (tree.isNativeCodeCall()) {
716     if (!treeName.empty()) {
717       if (!isSrcPattern) {
718         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
719                                 << treeName << '\n');
720         verifyBind(
721             infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
722             treeName);
723       } else {
724         PrintFatalError(&def,
725                         formatv("binding symbol '{0}' to NativecodeCall in "
726                                 "MatchPattern is not supported",
727                                 treeName));
728       }
729     }
730 
731     for (int i = 0; i != numTreeArgs; ++i) {
732       if (auto treeArg = tree.getArgAsNestedDag(i)) {
733         // This DAG node argument is a DAG node itself. Go inside recursively.
734         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
735         continue;
736       }
737 
738       if (!isSrcPattern)
739         continue;
740 
741       // We can only bind symbols to arguments in source pattern. Those
742       // symbols are referenced in result patterns.
743       auto treeArgName = tree.getArgName(i);
744 
745       // `$_` is a special symbol meaning ignore the current argument.
746       if (!treeArgName.empty() && treeArgName != "_") {
747         DagLeaf leaf = tree.getArgAsLeaf(i);
748 
749         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
750         if (leaf.isUnspecified()) {
751           // This is case of $c, a Value without any constraints.
752           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
753         } else {
754           auto constraint = leaf.getAsConstraint();
755           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
756                         leaf.isConstantAttr() ||
757                         constraint.getKind() == Constraint::Kind::CK_Attr;
758 
759           if (isAttr) {
760             // This is case of $a, a binding to a certain attribute.
761             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
762             continue;
763           }
764 
765           // This is case of $b, a binding to a certain type.
766           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
767         }
768       }
769     }
770 
771     return;
772   }
773 
774   if (tree.isOperation()) {
775     auto &op = getDialectOp(tree);
776     auto numOpArgs = op.getNumArgs();
777     int numEither = 0;
778 
779     // We need to exclude the trailing directives and `either` directive groups
780     // two operands of the operation.
781     int numDirectives = 0;
782     for (int i = numTreeArgs - 1; i >= 0; --i) {
783       if (auto dagArg = tree.getArgAsNestedDag(i)) {
784         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
785           ++numDirectives;
786         else if (dagArg.isEither())
787           ++numEither;
788       }
789     }
790 
791     if (numOpArgs != numTreeArgs - numDirectives + numEither) {
792       auto err =
793           formatv("op '{0}' argument number mismatch: "
794                   "{1} in pattern vs. {2} in definition",
795                   op.getOperationName(), numTreeArgs + numEither, numOpArgs);
796       PrintFatalError(&def, err);
797     }
798 
799     // The name attached to the DAG node's operator is for representing the
800     // results generated from this op. It should be remembered as bound results.
801     if (!treeName.empty()) {
802       LLVM_DEBUG(llvm::dbgs()
803                  << "found symbol bound to op result: " << treeName << '\n');
804       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
805     }
806 
807     // The operand in `either` DAG should be bound to the operation in the
808     // parent DagNode.
809     auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
810                                      int &opArgIdx) {
811       for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
812         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
813           collectBoundSymbols(subTree, infoMap, isSrcPattern);
814         } else {
815           auto argName = tree.getArgName(i);
816           if (!argName.empty() && argName != "_")
817             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
818                        argName);
819         }
820       }
821     };
822 
823     for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
824       if (auto treeArg = tree.getArgAsNestedDag(i)) {
825         if (treeArg.isEither()) {
826           collectSymbolInEither(tree, treeArg, opArgIdx);
827         } else {
828           // This DAG node argument is a DAG node itself. Go inside recursively.
829           collectBoundSymbols(treeArg, infoMap, isSrcPattern);
830         }
831         continue;
832       }
833 
834       if (isSrcPattern) {
835         // We can only bind symbols to op arguments in source pattern. Those
836         // symbols are referenced in result patterns.
837         auto treeArgName = tree.getArgName(i);
838         // `$_` is a special symbol meaning ignore the current argument.
839         if (!treeArgName.empty() && treeArgName != "_") {
840           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
841                                   << treeArgName << '\n');
842           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
843                      treeArgName);
844         }
845       }
846     }
847     return;
848   }
849 
850   if (!treeName.empty()) {
851     PrintFatalError(
852         &def, formatv("binding symbol '{0}' to non-operation/native code call "
853                       "unsupported right now",
854                       treeName));
855   }
856 }
857