xref: /llvm-project/mlir/lib/TableGen/Pattern.cpp (revision 56222a0694e4caf35e892d70591417c39fef1185)
1cde4d5a6SJacques Pienaar //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2eb753f4aSLei Zhang //
3*56222a06SMehdi Amini // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4*56222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5*56222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6eb753f4aSLei Zhang //
7*56222a06SMehdi Amini //===----------------------------------------------------------------------===//
8eb753f4aSLei Zhang //
9eb753f4aSLei Zhang // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10eb753f4aSLei Zhang // Pattern.
11eb753f4aSLei Zhang //
12eb753f4aSLei Zhang //===----------------------------------------------------------------------===//
13eb753f4aSLei Zhang 
14eb753f4aSLei Zhang #include "mlir/TableGen/Pattern.h"
15eb753f4aSLei Zhang #include "llvm/ADT/Twine.h"
161358df19SLei Zhang #include "llvm/Support/Debug.h"
1704b6d2f3SLei Zhang #include "llvm/Support/FormatVariadic.h"
188f5fa566SLei Zhang #include "llvm/TableGen/Error.h"
19eb753f4aSLei Zhang #include "llvm/TableGen/Record.h"
20eb753f4aSLei Zhang 
211358df19SLei Zhang #define DEBUG_TYPE "mlir-tblgen-pattern"
221358df19SLei Zhang 
23eb753f4aSLei Zhang using namespace mlir;
24eb753f4aSLei Zhang 
250ea6154bSJacques Pienaar using llvm::formatv;
26eb753f4aSLei Zhang using mlir::tblgen::Operator;
27eb753f4aSLei Zhang 
28ac68637bSLei Zhang //===----------------------------------------------------------------------===//
29ac68637bSLei Zhang // DagLeaf
30ac68637bSLei Zhang //===----------------------------------------------------------------------===//
31ac68637bSLei Zhang 
32e0774c00SLei Zhang bool tblgen::DagLeaf::isUnspecified() const {
33b9e38a79SLei Zhang   return dyn_cast_or_null<llvm::UnsetInit>(def);
34e0774c00SLei Zhang }
35e0774c00SLei Zhang 
36e0774c00SLei Zhang bool tblgen::DagLeaf::isOperandMatcher() const {
37e0774c00SLei Zhang   // Operand matchers specify a type constraint.
38b9e38a79SLei Zhang   return isSubClassOf("TypeConstraint");
39e0774c00SLei Zhang }
40e0774c00SLei Zhang 
41e0774c00SLei Zhang bool tblgen::DagLeaf::isAttrMatcher() const {
42c52a8127SFeng Liu   // Attribute matchers specify an attribute constraint.
43b9e38a79SLei Zhang   return isSubClassOf("AttrConstraint");
44e0774c00SLei Zhang }
45e0774c00SLei Zhang 
46d0e2019dSLei Zhang bool tblgen::DagLeaf::isNativeCodeCall() const {
47d0e2019dSLei Zhang   return isSubClassOf("NativeCodeCall");
48e0774c00SLei Zhang }
49e0774c00SLei Zhang 
50e0774c00SLei Zhang bool tblgen::DagLeaf::isConstantAttr() const {
51b9e38a79SLei Zhang   return isSubClassOf("ConstantAttr");
52b9e38a79SLei Zhang }
53b9e38a79SLei Zhang 
54b9e38a79SLei Zhang bool tblgen::DagLeaf::isEnumAttrCase() const {
559dd182e0SLei Zhang   return isSubClassOf("EnumAttrCaseInfo");
56e0774c00SLei Zhang }
57e0774c00SLei Zhang 
588f5fa566SLei Zhang tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
598f5fa566SLei Zhang   assert((isOperandMatcher() || isAttrMatcher()) &&
608f5fa566SLei Zhang          "the DAG leaf must be operand or attribute");
618f5fa566SLei Zhang   return Constraint(cast<llvm::DefInit>(def)->getDef());
62e0774c00SLei Zhang }
63e0774c00SLei Zhang 
64e0774c00SLei Zhang tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
65e0774c00SLei Zhang   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
66e0774c00SLei Zhang   return ConstantAttr(cast<llvm::DefInit>(def));
67e0774c00SLei Zhang }
68e0774c00SLei Zhang 
69b9e38a79SLei Zhang tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
70b9e38a79SLei Zhang   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
71b9e38a79SLei Zhang   return EnumAttrCase(cast<llvm::DefInit>(def));
72b9e38a79SLei Zhang }
73b9e38a79SLei Zhang 
74e0774c00SLei Zhang std::string tblgen::DagLeaf::getConditionTemplate() const {
758f5fa566SLei Zhang   return getAsConstraint().getConditionTemplate();
76e0774c00SLei Zhang }
77e0774c00SLei Zhang 
78d0e2019dSLei Zhang llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
79d0e2019dSLei Zhang   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
80d0e2019dSLei Zhang   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
81eb753f4aSLei Zhang }
82eb753f4aSLei Zhang 
83b9e38a79SLei Zhang bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
84b9e38a79SLei Zhang   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
85b9e38a79SLei Zhang     return defInit->getDef()->isSubClassOf(superclass);
86b9e38a79SLei Zhang   return false;
87b9e38a79SLei Zhang }
88b9e38a79SLei Zhang 
891358df19SLei Zhang void tblgen::DagLeaf::print(raw_ostream &os) const {
901358df19SLei Zhang   if (def)
911358df19SLei Zhang     def->print(os);
921358df19SLei Zhang }
931358df19SLei Zhang 
94ac68637bSLei Zhang //===----------------------------------------------------------------------===//
95ac68637bSLei Zhang // DagNode
96ac68637bSLei Zhang //===----------------------------------------------------------------------===//
97ac68637bSLei Zhang 
98d0e2019dSLei Zhang bool tblgen::DagNode::isNativeCodeCall() const {
99d0e2019dSLei Zhang   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
100d0e2019dSLei Zhang     return defInit->getDef()->isSubClassOf("NativeCodeCall");
101c52a8127SFeng Liu   return false;
102c52a8127SFeng Liu }
103c52a8127SFeng Liu 
104647f8cabSRiver Riddle bool tblgen::DagNode::isOperation() const {
105c72d849eSLei Zhang   return !(isNativeCodeCall() || isReplaceWithValue());
106647f8cabSRiver Riddle }
107647f8cabSRiver Riddle 
108d0e2019dSLei Zhang llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
109d0e2019dSLei Zhang   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
110c52a8127SFeng Liu   return cast<llvm::DefInit>(node->getOperator())
111c52a8127SFeng Liu       ->getDef()
112d0e2019dSLei Zhang       ->getValueAsString("expression");
113c52a8127SFeng Liu }
114c52a8127SFeng Liu 
115e032d0dcSLei Zhang llvm::StringRef tblgen::DagNode::getSymbol() const {
116388fb375SJacques Pienaar   return node->getNameStr();
117388fb375SJacques Pienaar }
118388fb375SJacques Pienaar 
119eb753f4aSLei Zhang Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
120eb753f4aSLei Zhang   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
1212dc6d205SLei Zhang   auto it = mapper->find(opDef);
1222dc6d205SLei Zhang   if (it != mapper->end())
1232dc6d205SLei Zhang     return *it->second;
12479f53b0cSJacques Pienaar   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
1252dc6d205SLei Zhang               .first->second;
126eb753f4aSLei Zhang }
127eb753f4aSLei Zhang 
1282fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumOps() const {
1292fe8ae4fSJacques Pienaar   int count = isReplaceWithValue() ? 0 : 1;
1302fe8ae4fSJacques Pienaar   for (int i = 0, e = getNumArgs(); i != e; ++i) {
131eb753f4aSLei Zhang     if (auto child = getArgAsNestedDag(i))
132eb753f4aSLei Zhang       count += child.getNumOps();
133eb753f4aSLei Zhang   }
134eb753f4aSLei Zhang   return count;
135eb753f4aSLei Zhang }
136eb753f4aSLei Zhang 
1372fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
138eb753f4aSLei Zhang 
139eb753f4aSLei Zhang bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
140eb753f4aSLei Zhang   return isa<llvm::DagInit>(node->getArg(index));
141eb753f4aSLei Zhang }
142eb753f4aSLei Zhang 
143eb753f4aSLei Zhang tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
144eb753f4aSLei Zhang   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
145eb753f4aSLei Zhang }
146eb753f4aSLei Zhang 
147e0774c00SLei Zhang tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
148e0774c00SLei Zhang   assert(!isNestedDagArg(index));
149e0774c00SLei Zhang   return DagLeaf(node->getArg(index));
150eb753f4aSLei Zhang }
151eb753f4aSLei Zhang 
152eb753f4aSLei Zhang StringRef tblgen::DagNode::getArgName(unsigned index) const {
153eb753f4aSLei Zhang   return node->getArgNameStr(index);
154eb753f4aSLei Zhang }
155eb753f4aSLei Zhang 
156eb753f4aSLei Zhang bool tblgen::DagNode::isReplaceWithValue() const {
157eb753f4aSLei Zhang   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
158eb753f4aSLei Zhang   return dagOpDef->getName() == "replaceWithValue";
159eb753f4aSLei Zhang }
160eb753f4aSLei Zhang 
1611358df19SLei Zhang void tblgen::DagNode::print(raw_ostream &os) const {
1621358df19SLei Zhang   if (node)
1631358df19SLei Zhang     node->print(os);
1641358df19SLei Zhang }
1651358df19SLei Zhang 
166ac68637bSLei Zhang //===----------------------------------------------------------------------===//
167ac68637bSLei Zhang // SymbolInfoMap
168ac68637bSLei Zhang //===----------------------------------------------------------------------===//
169ac68637bSLei Zhang 
170ac68637bSLei Zhang StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
171ac68637bSLei Zhang                                                   int *index) {
172ac68637bSLei Zhang   StringRef name, indexStr;
173ac68637bSLei Zhang   int idx = -1;
174ac68637bSLei Zhang   std::tie(name, indexStr) = symbol.rsplit("__");
175ac68637bSLei Zhang 
176ac68637bSLei Zhang   if (indexStr.consumeInteger(10, idx)) {
177ac68637bSLei Zhang     // The second part is not an index; we return the whole symbol as-is.
178ac68637bSLei Zhang     return symbol;
179eb753f4aSLei Zhang   }
180ac68637bSLei Zhang   if (index) {
181ac68637bSLei Zhang     *index = idx;
182ac68637bSLei Zhang   }
183ac68637bSLei Zhang   return name;
184ac68637bSLei Zhang }
185ac68637bSLei Zhang 
186ac68637bSLei Zhang tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
187ac68637bSLei Zhang                                               SymbolInfo::Kind kind,
188ac68637bSLei Zhang                                               Optional<int> index)
189ac68637bSLei Zhang     : op(op), kind(kind), argIndex(index) {}
190ac68637bSLei Zhang 
191ac68637bSLei Zhang int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
192ac68637bSLei Zhang   switch (kind) {
193ac68637bSLei Zhang   case Kind::Attr:
194ac68637bSLei Zhang   case Kind::Operand:
195ac68637bSLei Zhang   case Kind::Value:
196ac68637bSLei Zhang     return 1;
197ac68637bSLei Zhang   case Kind::Result:
198ac68637bSLei Zhang     return op->getNumResults();
199ac68637bSLei Zhang   }
20012ff145eSjpienaar   llvm_unreachable("unknown kind");
201ac68637bSLei Zhang }
202ac68637bSLei Zhang 
203ac68637bSLei Zhang std::string
204ac68637bSLei Zhang tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
205796ca609SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
206ac68637bSLei Zhang   switch (kind) {
207ac68637bSLei Zhang   case Kind::Attr: {
208ac68637bSLei Zhang     auto type =
209ac68637bSLei Zhang         op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
210ac68637bSLei Zhang     return formatv("{0} {1};\n", type, name);
211ac68637bSLei Zhang   }
21231cfee60SLei Zhang   case Kind::Operand: {
21331cfee60SLei Zhang     // Use operand range for captured operands (to support potential variadic
21431cfee60SLei Zhang     // operands).
21531cfee60SLei Zhang     return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
21631cfee60SLei Zhang   }
217ac68637bSLei Zhang   case Kind::Value: {
21835807bc4SRiver Riddle     return formatv("ArrayRef<ValuePtr> {0};\n", name);
219ac68637bSLei Zhang   }
220ac68637bSLei Zhang   case Kind::Result: {
22131cfee60SLei Zhang     // Use the op itself for captured results.
222ac68637bSLei Zhang     return formatv("{0} {1};\n", op->getQualCppClassName(), name);
223ac68637bSLei Zhang   }
224ac68637bSLei Zhang   }
22512ff145eSjpienaar   llvm_unreachable("unknown kind");
226ac68637bSLei Zhang }
227ac68637bSLei Zhang 
22831cfee60SLei Zhang std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
22931cfee60SLei Zhang     StringRef name, int index, const char *fmt, const char *separator) const {
2301358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
23131cfee60SLei Zhang   switch (kind) {
23231cfee60SLei Zhang   case Kind::Attr: {
23331cfee60SLei Zhang     assert(index < 0);
2341358df19SLei Zhang     auto repl = formatv(fmt, name);
2351358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
2361358df19SLei Zhang     return repl;
23731cfee60SLei Zhang   }
23831cfee60SLei Zhang   case Kind::Operand: {
23931cfee60SLei Zhang     assert(index < 0);
24031cfee60SLei Zhang     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
24131cfee60SLei Zhang     // If this operand is variadic, then return a range. Otherwise, return the
24231cfee60SLei Zhang     // value itself.
24331cfee60SLei Zhang     if (operand->isVariadic()) {
2441358df19SLei Zhang       auto repl = formatv(fmt, name);
2451358df19SLei Zhang       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
2461358df19SLei Zhang       return repl;
24731cfee60SLei Zhang     }
2481358df19SLei Zhang     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
2491358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
2501358df19SLei Zhang     return repl;
25131cfee60SLei Zhang   }
25231cfee60SLei Zhang   case Kind::Result: {
25331cfee60SLei Zhang     // If `index` is greater than zero, then we are referencing a specific
25431cfee60SLei Zhang     // result of a multi-result op. The result can still be variadic.
25531cfee60SLei Zhang     if (index >= 0) {
25631cfee60SLei Zhang       std::string v = formatv("{0}.getODSResults({1})", name, index);
25731cfee60SLei Zhang       if (!op->getResult(index).isVariadic())
25831cfee60SLei Zhang         v = formatv("(*{0}.begin())", v);
2591358df19SLei Zhang       auto repl = formatv(fmt, v);
2601358df19SLei Zhang       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
2611358df19SLei Zhang       return repl;
26231cfee60SLei Zhang     }
26331cfee60SLei Zhang 
26423d21af6SLei Zhang     // If this op has no result at all but still we bind a symbol to it, it
26523d21af6SLei Zhang     // means we want to capture the op itself.
26623d21af6SLei Zhang     if (op->getNumResults() == 0) {
26723d21af6SLei Zhang       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
26823d21af6SLei Zhang       return name;
26923d21af6SLei Zhang     }
27023d21af6SLei Zhang 
27131cfee60SLei Zhang     // We are referencing all results of the multi-result op. A specific result
27231cfee60SLei Zhang     // can either be a value or a range. Then join them with `separator`.
27331cfee60SLei Zhang     SmallVector<std::string, 4> values;
27431cfee60SLei Zhang     values.reserve(op->getNumResults());
27531cfee60SLei Zhang 
27631cfee60SLei Zhang     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
27731cfee60SLei Zhang       std::string v = formatv("{0}.getODSResults({1})", name, i);
27831cfee60SLei Zhang       if (!op->getResult(i).isVariadic()) {
27931cfee60SLei Zhang         v = formatv("(*{0}.begin())", v);
28031cfee60SLei Zhang       }
28131cfee60SLei Zhang       values.push_back(formatv(fmt, v));
28231cfee60SLei Zhang     }
2831358df19SLei Zhang     auto repl = llvm::join(values, separator);
2841358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
2851358df19SLei Zhang     return repl;
28631cfee60SLei Zhang   }
28731cfee60SLei Zhang   case Kind::Value: {
28831cfee60SLei Zhang     assert(index < 0);
28931cfee60SLei Zhang     assert(op == nullptr);
2901358df19SLei Zhang     auto repl = formatv(fmt, name);
2911358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
2921358df19SLei Zhang     return repl;
29331cfee60SLei Zhang   }
29431cfee60SLei Zhang   }
29594298ceaSLei Zhang   llvm_unreachable("unknown kind");
29631cfee60SLei Zhang }
29731cfee60SLei Zhang 
29831cfee60SLei Zhang std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
29931cfee60SLei Zhang     StringRef name, int index, const char *fmt, const char *separator) const {
3001358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
301ac68637bSLei Zhang   switch (kind) {
302ac68637bSLei Zhang   case Kind::Attr:
303ac68637bSLei Zhang   case Kind::Operand: {
304ac68637bSLei Zhang     assert(index < 0 && "only allowed for symbol bound to result");
3051358df19SLei Zhang     auto repl = formatv(fmt, name);
3061358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
3071358df19SLei Zhang     return repl;
308ac68637bSLei Zhang   }
309ac68637bSLei Zhang   case Kind::Result: {
310ac68637bSLei Zhang     if (index >= 0) {
3111358df19SLei Zhang       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
3121358df19SLei Zhang       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
3131358df19SLei Zhang       return repl;
314ac68637bSLei Zhang     }
315ac68637bSLei Zhang 
31631cfee60SLei Zhang     // We are referencing all results of the multi-result op. Each result should
31731cfee60SLei Zhang     // have a value range, and then join them with `separator`.
318ac68637bSLei Zhang     SmallVector<std::string, 4> values;
31931cfee60SLei Zhang     values.reserve(op->getNumResults());
32031cfee60SLei Zhang 
321ac68637bSLei Zhang     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
32231cfee60SLei Zhang       values.push_back(
32331cfee60SLei Zhang           formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
324ac68637bSLei Zhang     }
3251358df19SLei Zhang     auto repl = llvm::join(values, separator);
3261358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
3271358df19SLei Zhang     return repl;
328ac68637bSLei Zhang   }
329ac68637bSLei Zhang   case Kind::Value: {
330ac68637bSLei Zhang     assert(index < 0 && "only allowed for symbol bound to result");
331ac68637bSLei Zhang     assert(op == nullptr);
3321358df19SLei Zhang     auto repl = formatv(fmt, formatv("{{{0}}", name));
3331358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
3341358df19SLei Zhang     return repl;
335ac68637bSLei Zhang   }
336ac68637bSLei Zhang   }
33712ff145eSjpienaar   llvm_unreachable("unknown kind");
338ac68637bSLei Zhang }
339ac68637bSLei Zhang 
340ac68637bSLei Zhang bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
341ac68637bSLei Zhang                                            int argIndex) {
342ac68637bSLei Zhang   StringRef name = getValuePackName(symbol);
343ac68637bSLei Zhang   if (name != symbol) {
344ac68637bSLei Zhang     auto error = formatv(
345ac68637bSLei Zhang         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
346ac68637bSLei Zhang     PrintFatalError(loc, error);
347ac68637bSLei Zhang   }
348ac68637bSLei Zhang 
349ac68637bSLei Zhang   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
350ac68637bSLei Zhang                      ? SymbolInfo::getAttr(&op, argIndex)
351ac68637bSLei Zhang                      : SymbolInfo::getOperand(&op, argIndex);
352ac68637bSLei Zhang 
353ac68637bSLei Zhang   return symbolInfoMap.insert({symbol, symInfo}).second;
354ac68637bSLei Zhang }
355ac68637bSLei Zhang 
356ac68637bSLei Zhang bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
357ac68637bSLei Zhang   StringRef name = getValuePackName(symbol);
358ac68637bSLei Zhang   return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
359ac68637bSLei Zhang }
360ac68637bSLei Zhang 
361ac68637bSLei Zhang bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
362ac68637bSLei Zhang   return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
363ac68637bSLei Zhang }
364ac68637bSLei Zhang 
365ac68637bSLei Zhang bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
366ac68637bSLei Zhang   return find(symbol) != symbolInfoMap.end();
367ac68637bSLei Zhang }
368ac68637bSLei Zhang 
369ac68637bSLei Zhang tblgen::SymbolInfoMap::const_iterator
370ac68637bSLei Zhang tblgen::SymbolInfoMap::find(StringRef key) const {
371ac68637bSLei Zhang   StringRef name = getValuePackName(key);
372ac68637bSLei Zhang   return symbolInfoMap.find(name);
373ac68637bSLei Zhang }
374ac68637bSLei Zhang 
375ac68637bSLei Zhang int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
376ac68637bSLei Zhang   StringRef name = getValuePackName(symbol);
377ac68637bSLei Zhang   if (name != symbol) {
378ac68637bSLei Zhang     // If there is a trailing index inside symbol, it references just one
379ac68637bSLei Zhang     // static value.
380ac68637bSLei Zhang     return 1;
381ac68637bSLei Zhang   }
382ac68637bSLei Zhang   // Otherwise, find how many it represents by querying the symbol's info.
383ac68637bSLei Zhang   return find(name)->getValue().getStaticValueCount();
384ac68637bSLei Zhang }
385ac68637bSLei Zhang 
38631cfee60SLei Zhang std::string
38731cfee60SLei Zhang tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
38831cfee60SLei Zhang                                            const char *separator) const {
389ac68637bSLei Zhang   int index = -1;
390ac68637bSLei Zhang   StringRef name = getValuePackName(symbol, &index);
391ac68637bSLei Zhang 
392ac68637bSLei Zhang   auto it = symbolInfoMap.find(name);
393ac68637bSLei Zhang   if (it == symbolInfoMap.end()) {
394ac68637bSLei Zhang     auto error = formatv("referencing unbound symbol '{0}'", symbol);
395ac68637bSLei Zhang     PrintFatalError(loc, error);
396ac68637bSLei Zhang   }
397ac68637bSLei Zhang 
39831cfee60SLei Zhang   return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
39931cfee60SLei Zhang }
40031cfee60SLei Zhang 
40131cfee60SLei Zhang std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
40231cfee60SLei Zhang                                                   const char *fmt,
40331cfee60SLei Zhang                                                   const char *separator) const {
40431cfee60SLei Zhang   int index = -1;
40531cfee60SLei Zhang   StringRef name = getValuePackName(symbol, &index);
40631cfee60SLei Zhang 
40731cfee60SLei Zhang   auto it = symbolInfoMap.find(name);
40831cfee60SLei Zhang   if (it == symbolInfoMap.end()) {
40931cfee60SLei Zhang     auto error = formatv("referencing unbound symbol '{0}'", symbol);
41031cfee60SLei Zhang     PrintFatalError(loc, error);
41131cfee60SLei Zhang   }
41231cfee60SLei Zhang 
41331cfee60SLei Zhang   return it->getValue().getAllRangeUse(name, index, fmt, separator);
414ac68637bSLei Zhang }
415ac68637bSLei Zhang 
416ac68637bSLei Zhang //===----------------------------------------------------------------------===//
417ac68637bSLei Zhang // Pattern
//===----------------------------------------------------------------------===//
419ac68637bSLei Zhang 
420ac68637bSLei Zhang tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
421ac68637bSLei Zhang     : def(*def), recordOpMap(mapper) {}
422eb753f4aSLei Zhang 
423eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
4248f5fa566SLei Zhang   return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
425eb753f4aSLei Zhang }
426eb753f4aSLei Zhang 
427e032d0dcSLei Zhang int tblgen::Pattern::getNumResultPatterns() const {
4288f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
429eb753f4aSLei Zhang   return results->size();
430eb753f4aSLei Zhang }
431eb753f4aSLei Zhang 
432eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
4338f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
434eb753f4aSLei Zhang   return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
435eb753f4aSLei Zhang }
436eb753f4aSLei Zhang 
437ac68637bSLei Zhang void tblgen::Pattern::collectSourcePatternBoundSymbols(
438ac68637bSLei Zhang     tblgen::SymbolInfoMap &infoMap) {
4391358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
440ac68637bSLei Zhang   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
4411358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
442eb753f4aSLei Zhang }
443eb753f4aSLei Zhang 
444ac68637bSLei Zhang void tblgen::Pattern::collectResultPatternBoundSymbols(
445ac68637bSLei Zhang     tblgen::SymbolInfoMap &infoMap) {
4461358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
447ac68637bSLei Zhang   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
448ac68637bSLei Zhang     auto pattern = getResultPattern(i);
449ac68637bSLei Zhang     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
450eb753f4aSLei Zhang   }
4511358df19SLei Zhang   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
452388fb375SJacques Pienaar }
453388fb375SJacques Pienaar 
454eb753f4aSLei Zhang const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
455eb753f4aSLei Zhang   return getSourcePattern().getDialectOp(recordOpMap);
456eb753f4aSLei Zhang }
457eb753f4aSLei Zhang 
458eb753f4aSLei Zhang tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
459eb753f4aSLei Zhang   return node.getDialectOp(recordOpMap);
460eb753f4aSLei Zhang }
461388fb375SJacques Pienaar 
4628f5fa566SLei Zhang std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
463388fb375SJacques Pienaar   auto *listInit = def.getValueAsListInit("constraints");
4648f5fa566SLei Zhang   std::vector<tblgen::AppliedConstraint> ret;
465388fb375SJacques Pienaar   ret.reserve(listInit->size());
4668f5fa566SLei Zhang 
467388fb375SJacques Pienaar   for (auto it : *listInit) {
4688f5fa566SLei Zhang     auto *dagInit = dyn_cast<llvm::DagInit>(it);
4698f5fa566SLei Zhang     if (!dagInit)
4708bfedb3cSKazuaki Ishizaki       PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
4718f5fa566SLei Zhang                                     "constraints should be DAG nodes");
4728f5fa566SLei Zhang 
4738f5fa566SLei Zhang     std::vector<std::string> entities;
4748f5fa566SLei Zhang     entities.reserve(dagInit->arg_size());
475cb40e36dSLei Zhang     for (auto *argName : dagInit->getArgNames()) {
476cb40e36dSLei Zhang       if (!argName) {
477cb40e36dSLei Zhang         PrintFatalError(
478cb40e36dSLei Zhang             def.getLoc(),
479cb40e36dSLei Zhang             "operands to additional constraints can only be symbol references");
480cb40e36dSLei Zhang       }
4818f5fa566SLei Zhang       entities.push_back(argName->getValue());
482cb40e36dSLei Zhang     }
4838f5fa566SLei Zhang 
4848f5fa566SLei Zhang     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
485c72d849eSLei Zhang                      dagInit->getNameStr(), std::move(entities));
486388fb375SJacques Pienaar   }
487388fb375SJacques Pienaar   return ret;
488388fb375SJacques Pienaar }
48953035874SFeng Liu 
49053035874SFeng Liu int tblgen::Pattern::getBenefit() const {
491a0606ca7SFeng Liu   // The initial benefit value is a heuristic with number of ops in the source
49253035874SFeng Liu   // pattern.
493a0606ca7SFeng Liu   int initBenefit = getSourcePattern().getNumOps();
49453035874SFeng Liu   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
495a0606ca7SFeng Liu   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
496a0606ca7SFeng Liu     PrintFatalError(def.getLoc(),
497a0606ca7SFeng Liu                     "The 'addBenefit' takes and only takes one integer value");
498a0606ca7SFeng Liu   }
499a0606ca7SFeng Liu   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
50053035874SFeng Liu }
50104b6d2f3SLei Zhang 
5024165885aSJacques Pienaar std::vector<tblgen::Pattern::IdentifierLine>
5034165885aSJacques Pienaar tblgen::Pattern::getLocation() const {
5044165885aSJacques Pienaar   std::vector<std::pair<StringRef, unsigned>> result;
5054165885aSJacques Pienaar   result.reserve(def.getLoc().size());
5064165885aSJacques Pienaar   for (auto loc : def.getLoc()) {
5074165885aSJacques Pienaar     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
5084165885aSJacques Pienaar     assert(buf && "invalid source location");
5094165885aSJacques Pienaar     result.emplace_back(
5104165885aSJacques Pienaar         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
5114165885aSJacques Pienaar         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
5124165885aSJacques Pienaar   }
5134165885aSJacques Pienaar   return result;
5144165885aSJacques Pienaar }
5154165885aSJacques Pienaar 
516ac68637bSLei Zhang void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
517e032d0dcSLei Zhang                                           bool isSrcPattern) {
518e032d0dcSLei Zhang   auto treeName = tree.getSymbol();
519e032d0dcSLei Zhang   if (!tree.isOperation()) {
520e032d0dcSLei Zhang     if (!treeName.empty()) {
521e032d0dcSLei Zhang       PrintFatalError(
522e032d0dcSLei Zhang           def.getLoc(),
523e032d0dcSLei Zhang           formatv("binding symbol '{0}' to non-operation unsupported right now",
524e032d0dcSLei Zhang                   treeName));
525e032d0dcSLei Zhang     }
526e032d0dcSLei Zhang     return;
527e032d0dcSLei Zhang   }
528e032d0dcSLei Zhang 
52904b6d2f3SLei Zhang   auto &op = getDialectOp(tree);
53004b6d2f3SLei Zhang   auto numOpArgs = op.getNumArgs();
53104b6d2f3SLei Zhang   auto numTreeArgs = tree.getNumArgs();
53204b6d2f3SLei Zhang 
53304b6d2f3SLei Zhang   if (numOpArgs != numTreeArgs) {
534ac68637bSLei Zhang     auto err = formatv("op '{0}' argument number mismatch: "
53504b6d2f3SLei Zhang                        "{1} in pattern vs. {2} in definition",
536ac68637bSLei Zhang                        op.getOperationName(), numTreeArgs, numOpArgs);
537ac68637bSLei Zhang     PrintFatalError(def.getLoc(), err);
53804b6d2f3SLei Zhang   }
53904b6d2f3SLei Zhang 
54004b6d2f3SLei Zhang   // The name attached to the DAG node's operator is for representing the
54104b6d2f3SLei Zhang   // results generated from this op. It should be remembered as bound results.
542ac68637bSLei Zhang   if (!treeName.empty()) {
5431358df19SLei Zhang     LLVM_DEBUG(llvm::dbgs()
5441358df19SLei Zhang                << "found symbol bound to op result: " << treeName << '\n');
545ac68637bSLei Zhang     if (!infoMap.bindOpResult(treeName, op))
546ac68637bSLei Zhang       PrintFatalError(def.getLoc(),
547ac68637bSLei Zhang                       formatv("symbol '{0}' bound more than once", treeName));
548ac68637bSLei Zhang   }
54904b6d2f3SLei Zhang 
5502fe8ae4fSJacques Pienaar   for (int i = 0; i != numTreeArgs; ++i) {
55104b6d2f3SLei Zhang     if (auto treeArg = tree.getArgAsNestedDag(i)) {
55204b6d2f3SLei Zhang       // This DAG node argument is a DAG node itself. Go inside recursively.
553ac68637bSLei Zhang       collectBoundSymbols(treeArg, infoMap, isSrcPattern);
554e032d0dcSLei Zhang     } else if (isSrcPattern) {
555e032d0dcSLei Zhang       // We can only bind symbols to op arguments in source pattern. Those
556e032d0dcSLei Zhang       // symbols are referenced in result patterns.
55704b6d2f3SLei Zhang       auto treeArgName = tree.getArgName(i);
5584982eaf8SLei Zhang       // `$_` is a special symbol meaning ignore the current argument.
5594982eaf8SLei Zhang       if (!treeArgName.empty() && treeArgName != "_") {
5601358df19SLei Zhang         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
5611358df19SLei Zhang                                 << treeArgName << '\n');
562ac68637bSLei Zhang         if (!infoMap.bindOpArgument(treeArgName, op, i)) {
563ac68637bSLei Zhang           auto err = formatv("symbol '{0}' bound more than once", treeArgName);
564ac68637bSLei Zhang           PrintFatalError(def.getLoc(), err);
565ac68637bSLei Zhang         }
566ac68637bSLei Zhang       }
56704b6d2f3SLei Zhang     }
56804b6d2f3SLei Zhang   }
56904b6d2f3SLei Zhang }
570