xref: /llvm-project/mlir/include/mlir/TableGen/Pattern.h (revision 08d7377b67358496a409080fac22f3f7c077fb63)
1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"

#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
27 
28 namespace llvm {
29 class DagInit;
30 class Init;
31 class Record;
32 } // namespace llvm
33 
34 namespace mlir {
35 namespace tblgen {
36 
37 // Mapping from TableGen Record to Operator wrapper object.
38 //
39 // We allocate each wrapper object in heap to make sure the pointer to it is
40 // valid throughout the lifetime of this map. This is important because this map
41 // is shared among multiple patterns to avoid creating the wrapper object for
42 // the same op again and again. But this map will continuously grow.
43 using RecordOperatorMap =
44     DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
45 
46 class Pattern;
47 
48 // Wrapper class providing helper methods for accessing TableGen DAG leaves
49 // used inside Patterns. This class is lightweight and designed to be used like
50 // values.
51 //
52 // A TableGen DAG construct is of the syntax
53 //   `(operator, arg0, arg1, ...)`.
54 //
55 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
56 // for handy helper methods. It only works on `arg*`s that are not nested DAG
57 // constructs.
58 class DagLeaf {
59 public:
DagLeaf(const llvm::Init * def)60   explicit DagLeaf(const llvm::Init *def) : def(def) {}
61 
62   // Returns true if this DAG leaf is not specified in the pattern. That is, it
63   // places no further constraints/transforms and just carries over the original
64   // value.
65   bool isUnspecified() const;
66 
67   // Returns true if this DAG leaf is matching an operand. That is, it specifies
68   // a type constraint.
69   bool isOperandMatcher() const;
70 
71   // Returns true if this DAG leaf is matching an attribute. That is, it
72   // specifies an attribute constraint.
73   bool isAttrMatcher() const;
74 
75   // Returns true if this DAG leaf is wrapping native code call.
76   bool isNativeCodeCall() const;
77 
78   // Returns true if this DAG leaf is specifying a constant attribute.
79   bool isConstantAttr() const;
80 
81   // Returns true if this DAG leaf is specifying an enum attribute case.
82   bool isEnumAttrCase() const;
83 
84   // Returns true if this DAG leaf is specifying a string attribute.
85   bool isStringAttr() const;
86 
87   // Returns this DAG leaf as a constraint. Asserts if fails.
88   Constraint getAsConstraint() const;
89 
90   // Returns this DAG leaf as an constant attribute. Asserts if fails.
91   ConstantAttr getAsConstantAttr() const;
92 
93   // Returns this DAG leaf as an enum attribute case.
94   // Precondition: isEnumAttrCase()
95   EnumAttrCase getAsEnumAttrCase() const;
96 
97   // Returns the matching condition template inside this DAG leaf. Assumes the
98   // leaf is an operand/attribute matcher and asserts otherwise.
99   std::string getConditionTemplate() const;
100 
101   // Returns the native code call template inside this DAG leaf.
102   // Precondition: isNativeCodeCall()
103   StringRef getNativeCodeTemplate() const;
104 
105   // Returns the number of values will be returned by the native helper
106   // function.
107   // Precondition: isNativeCodeCall()
108   int getNumReturnsOfNativeCode() const;
109 
110   // Returns the string associated with the leaf.
111   // Precondition: isStringAttr()
112   std::string getStringAttr() const;
113 
114   void print(raw_ostream &os) const;
115 
116 private:
117   friend llvm::DenseMapInfo<DagLeaf>;
getAsOpaquePointer()118   const void *getAsOpaquePointer() const { return def; }
119 
120   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
121   // also a subclass of the given `superclass`.
122   bool isSubClassOf(StringRef superclass) const;
123 
124   const llvm::Init *def;
125 };
126 
127 // Wrapper class providing helper methods for accessing TableGen DAG constructs
128 // used inside Patterns. This class is lightweight and designed to be used like
129 // values.
130 //
131 // A TableGen DAG construct is of the syntax
132 //   `(operator, arg0, arg1, ...)`.
133 //
134 // When used inside Patterns, `operator` corresponds to some dialect op, or
135 // a known list of verbs that defines special transformation actions. This
136 // `arg*` can be a nested DAG construct. This class provides getters to
137 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
138 // methods.
139 //
140 // A null DagNode contains a nullptr and converts to false implicitly.
141 class DagNode {
142 public:
DagNode(const llvm::DagInit * node)143   explicit DagNode(const llvm::DagInit *node) : node(node) {}
144 
145   // Implicit bool converter that returns true if this DagNode is not a null
146   // DagNode.
147   operator bool() const { return node != nullptr; }
148 
149   // Returns the symbol bound to this DAG node.
150   StringRef getSymbol() const;
151 
152   // Returns the operator wrapper object corresponding to the dialect op matched
153   // by this DAG. The operator wrapper will be queried from the given `mapper`
154   // and created in it if not existing.
155   Operator &getDialectOp(RecordOperatorMap *mapper) const;
156 
157   // Returns the number of operations recursively involved in the DAG tree
158   // rooted from this node.
159   int getNumOps() const;
160 
161   // Returns the number of immediate arguments to this DAG node.
162   int getNumArgs() const;
163 
164   // Returns true if the `index`-th argument is a nested DAG construct.
165   bool isNestedDagArg(unsigned index) const;
166 
167   // Gets the `index`-th argument as a nested DAG construct if possible. Returns
168   // null DagNode otherwise.
169   DagNode getArgAsNestedDag(unsigned index) const;
170 
171   // Gets the `index`-th argument as a DAG leaf.
172   DagLeaf getArgAsLeaf(unsigned index) const;
173 
174   // Returns the specified name of the `index`-th argument.
175   StringRef getArgName(unsigned index) const;
176 
177   // Returns true if this DAG construct means to replace with an existing SSA
178   // value.
179   bool isReplaceWithValue() const;
180 
181   // Returns whether this DAG represents the location of an op creation.
182   bool isLocationDirective() const;
183 
184   // Returns whether this DAG is a return type specifier.
185   bool isReturnTypeDirective() const;
186 
187   // Returns true if this DAG node is wrapping native code call.
188   bool isNativeCodeCall() const;
189 
190   // Returns whether this DAG is an `either` specifier.
191   bool isEither() const;
192 
193   // Returns whether this DAG is an `variadic` specifier.
194   bool isVariadic() const;
195 
196   // Returns true if this DAG node is an operation.
197   bool isOperation() const;
198 
199   // Returns the native code call template inside this DAG node.
200   // Precondition: isNativeCodeCall()
201   StringRef getNativeCodeTemplate() const;
202 
203   // Returns the number of values will be returned by the native helper
204   // function.
205   // Precondition: isNativeCodeCall()
206   int getNumReturnsOfNativeCode() const;
207 
208   void print(raw_ostream &os) const;
209 
210 private:
211   friend class SymbolInfoMap;
212   friend llvm::DenseMapInfo<DagNode>;
getAsOpaquePointer()213   const void *getAsOpaquePointer() const { return node; }
214 
215   const llvm::DagInit *node; // nullptr means null DagNode
216 };
217 
218 // A class for maintaining information for symbols bound in patterns and
219 // provides methods for resolving them according to specific use cases.
220 //
221 // Symbols can be bound to
222 //
223 // * Op arguments and op results in the source pattern and
224 // * Op results in result patterns.
225 //
226 // Symbols can be referenced in result patterns and additional constraints to
227 // the pattern.
228 //
229 // For example, in
230 //
231 // ```
232 // def : Pattern<
233 //     (SrcOp:$results1 $arg0, %arg1),
234 //     [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
235 // ```
236 //
237 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
238 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
239 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
240 //
241 // If a symbol binds to a multi-result op and it does not have the `__N`
242 // suffix, the symbol is expanded to represent all results generated by the
243 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
244 // only the N-th *static* result as declared in ODS, and that can still
245 // corresponds to multiple *dynamic* values if the N-th *static* result is
246 // variadic.
247 //
248 // This class keeps track of such symbols and resolves them into their bound
249 // values in a suitable way.
250 class SymbolInfoMap {
251 public:
SymbolInfoMap(ArrayRef<SMLoc> loc)252   explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {}
253 
254   // Class for information regarding a symbol.
255   class SymbolInfo {
256   public:
257     // Returns a type string of a variable.
258     std::string getVarTypeStr(StringRef name) const;
259 
260     // Returns a string for defining a variable named as `name` to store the
261     // value bound by this symbol.
262     std::string getVarDecl(StringRef name) const;
263 
264     // Returns a string for defining an argument which passes the reference of
265     // the variable.
266     std::string getArgDecl(StringRef name) const;
267 
268     // Returns a variable name for the symbol named as `name`.
269     std::string getVarName(StringRef name) const;
270 
271   private:
272     // Allow SymbolInfoMap to access private methods.
273     friend class SymbolInfoMap;
274 
275     // Structure to uniquely distinguish different locations of the symbols.
276     //
277     // * If a symbol is defined as an operand of an operation, `dag` specifies
278     //   the DAG of the operation, `operandIndexOrNumValues` specifies the
279     //   operand index, and `variadicSubIndex` must be set to `std::nullopt`.
280     //
281     // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
282     //   of the parent operation, `operandIndexOrNumValues` specifies the
283     //   declared operand index of the variadic operand in the parent
284     //   operation.
285     //
286     //   - If the symbol is defined as a result of `variadic` DAG, the
287     //     `variadicSubIndex` must be set to `std::nullopt`, which means that
288     //     the symbol binds to the full operand range.
289     //
290     //   - If the symbol is defined as a operand, the `variadicSubIndex` must
291     //     be set to the index within the variadic sub-operand list.
292     //
293     // * If a symbol is defined in a `either` DAG, `dag` specifies the DAG
294     //   of the parent operation, `operandIndexOrNumValues` specifies the
295     //   operand index in the parent operation (not necessary the index in the
296     //   DAG).
297     //
298     // * If a symbol is defined as a result, specifies the number of returning
299     //   value.
300     //
301     // Example 1:
302     //
303     //   def : Pat<(OpA $input0, $input1), ...>;
304     //
305     //   $input0: (OpA, 0, nullopt)
306     //   $input1: (OpA, 1, nullopt)
307     //
308     // Example 2:
309     //
310     //   def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
311     //                  (variadic:$input1 $input1a, $input1b, $input1c)),
312     //             ...>;
313     //
314     //   $input0:  (OpB, 0, nullopt)
315     //   $input0a: (OpB, 0, 0)
316     //   $input0b: (OpB, 0, 1)
317     //   $input1:  (OpB, 1, nullopt)
318     //   $input1a: (OpB, 1, 0)
319     //   $input1b: (OpB, 1, 1)
320     //   $input1c: (OpB, 1, 2)
321     //
322     // Example 3:
323     //
324     //   def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
325     //
326     //   $input0: (OpC, 0, nullopt)
327     //   $input1: (OpC, 1, nullopt)
328     //   $input2: (OpC, 2, nullopt)
329     //
330     // Example 4:
331     //
332     //   def ThreeResultOp : TEST_Op<...> {
333     //     let results = (outs
334     //       AnyType:$result1,
335     //       AnyType:$result2,
336     //       AnyType:$result3
337     //     );
338     //   }
339     //
340     //   def : Pat<...,
341     //             (ThreeResultOp:$result ...)>;
342     //
343     //   $result: (nullptr, 3, nullopt)
344     //
345     struct DagAndConstant {
346       // DagNode and DagLeaf are accessed by value which means it can't be used
347       // as identifier here. Use an opaque pointer type instead.
348       const void *dag;
349       int operandIndexOrNumValues;
350       std::optional<int> variadicSubIndex;
351 
DagAndConstantDagAndConstant352       DagAndConstant(const void *dag, int operandIndexOrNumValues,
353                      std::optional<int> variadicSubIndex)
354           : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
355             variadicSubIndex(variadicSubIndex) {}
356 
357       bool operator==(const DagAndConstant &rhs) const {
358         return dag == rhs.dag &&
359                operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
360                variadicSubIndex == rhs.variadicSubIndex;
361       }
362     };
363 
364     // What kind of entity this symbol represents:
365     // * Attr: op attribute
366     // * Operand: op operand
367     // * Result: op result
368     // * Value: a value not attached to an op (e.g., from NativeCodeCall)
369     // * MultipleValues: a pack of values not attached to an op (e.g., from
370     //   NativeCodeCall). This kind supports indexing.
371     enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
372 
373     // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
374     // and `Operand` so should be std::nullopt for `Result` and `Value` kind.
375     SymbolInfo(const Operator *op, Kind kind,
376                std::optional<DagAndConstant> dagAndConstant);
377 
378     // Static methods for creating SymbolInfo.
getAttr(const Operator * op,int index)379     static SymbolInfo getAttr(const Operator *op, int index) {
380       return SymbolInfo(op, Kind::Attr,
381                         DagAndConstant(nullptr, index, std::nullopt));
382     }
getAttr()383     static SymbolInfo getAttr() {
384       return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
385     }
386     static SymbolInfo
387     getOperand(DagNode node, const Operator *op, int operandIndex,
388                std::optional<int> variadicSubIndex = std::nullopt) {
389       return SymbolInfo(op, Kind::Operand,
390                         DagAndConstant(node.getAsOpaquePointer(), operandIndex,
391                                        variadicSubIndex));
392     }
getResult(const Operator * op)393     static SymbolInfo getResult(const Operator *op) {
394       return SymbolInfo(op, Kind::Result, std::nullopt);
395     }
getValue()396     static SymbolInfo getValue() {
397       return SymbolInfo(nullptr, Kind::Value, std::nullopt);
398     }
getMultipleValues(int numValues)399     static SymbolInfo getMultipleValues(int numValues) {
400       return SymbolInfo(nullptr, Kind::MultipleValues,
401                         DagAndConstant(nullptr, numValues, std::nullopt));
402     }
403 
404     // Returns the number of static values this symbol corresponds to.
405     // A static value is an operand/result declared in ODS. Normally a symbol
406     // only represents one static value, but symbols bound to op results can
407     // represent more than one if the op is a multi-result op.
408     int getStaticValueCount() const;
409 
410     // Returns a string containing the C++ expression for referencing this
411     // symbol as a value (if this symbol represents one static value) or a value
412     // range (if this symbol represents multiple static values). `name` is the
413     // name of the C++ variable that this symbol bounds to. `index` should only
414     // be used for indexing results.  `fmt` is used to format each value.
415     // `separator` is used to separate values if this is a value range.
416     std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
417                                     const char *separator) const;
418 
419     // Returns a string containing the C++ expression for referencing this
420     // symbol as a value range regardless of how many static values this symbol
421     // represents. `name` is the name of the C++ variable that this symbol
422     // bounds to. `index` should only be used for indexing results. `fmt` is
423     // used to format each value. `separator` is used to separate values in the
424     // range.
425     std::string getAllRangeUse(StringRef name, int index, const char *fmt,
426                                const char *separator) const;
427 
428     // The argument index (for `Attr` and `Operand` only)
getArgIndex()429     int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
430 
431     // The number of values in the MultipleValue
getSize()432     int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
433 
434     // The variadic sub-operands index (for variadic `Operand` only)
getVariadicSubIndex()435     std::optional<int> getVariadicSubIndex() const {
436       return dagAndConstant->variadicSubIndex;
437     }
438 
439     const Operator *op; // The op where the bound entity belongs
440     Kind kind;          // The kind of the bound entity
441 
442     // The tuple of DagNode pointer and two constant values (for `Attr`,
443     // `Operand` and the size of MultipleValue symbol). Note that operands may
444     // be bound to the same symbol, use the DagNode and index to distinguish
445     // them. For `Attr` and MultipleValue, the Dag part will be nullptr.
446     std::optional<DagAndConstant> dagAndConstant;
447 
448     // Alternative name for the symbol. It is used in case the name
449     // is not unique. Applicable for `Operand` only.
450     std::optional<std::string> alternativeName;
451   };
452 
453   using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
454 
455   // Iterators for accessing all symbols.
456   using iterator = BaseT::iterator;
begin()457   iterator begin() { return symbolInfoMap.begin(); }
end()458   iterator end() { return symbolInfoMap.end(); }
459 
460   // Const iterators for accessing all symbols.
461   using const_iterator = BaseT::const_iterator;
begin()462   const_iterator begin() const { return symbolInfoMap.begin(); }
end()463   const_iterator end() const { return symbolInfoMap.end(); }
464 
465   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
466   // Returns false if `symbol` is already bound and symbols are not operands.
467   bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
468                       int argIndex,
469                       std::optional<int> variadicSubIndex = std::nullopt);
470 
471   // Binds the given `symbol` to the results the given `op`. Returns false if
472   // `symbol` is already bound.
473   bool bindOpResult(StringRef symbol, const Operator &op);
474 
475   // A helper function for dispatching target value binding functions.
476   bool bindValues(StringRef symbol, int numValues = 1);
477 
478   // Registers the given `symbol` as bound to the Value(s). Returns false if
479   // `symbol` is already bound.
480   bool bindValue(StringRef symbol);
481 
482   // Registers the given `symbol` as bound to a MultipleValue. Return false if
483   // `symbol` is already bound.
484   bool bindMultipleValues(StringRef symbol, int numValues);
485 
486   // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
487   // is already bound.
488   bool bindAttr(StringRef symbol);
489 
490   // Returns true if the given `symbol` is bound.
491   bool contains(StringRef symbol) const;
492 
493   // Returns an iterator to the information of the given symbol named as `key`.
494   const_iterator find(StringRef key) const;
495 
496   // Returns an iterator to the information of the given symbol named as `key`,
497   // with index `argIndex` for operator `op`.
498   const_iterator findBoundSymbol(StringRef key, DagNode node,
499                                  const Operator &op, int argIndex,
500                                  std::optional<int> variadicSubIndex) const;
501   const_iterator findBoundSymbol(StringRef key,
502                                  const SymbolInfo &symbolInfo) const;
503 
504   // Returns the bounds of a range that includes all the elements which
505   // bind to the `key`.
506   std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
507 
508   // Returns number of times symbol named as `key` was used.
509   int count(StringRef key) const;
510 
511   // Returns the number of static values of the given `symbol` corresponds to.
512   // A static value is an operand/result declared in ODS. Normally a symbol only
513   // represents one static value, but symbols bound to op results can represent
514   // more than one if the op is a multi-result op.
515   int getStaticValueCount(StringRef symbol) const;
516 
517   // Returns a string containing the C++ expression for referencing this
518   // symbol as a value (if this symbol represents one static value) or a value
519   // range (if this symbol represents multiple static values). `fmt` is used to
520   // format each value. `separator` is used to separate values if `symbol`
521   // represents a value range.
522   std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
523                                   const char *separator = ", ") const;
524 
525   // Returns a string containing the C++ expression for referencing this
526   // symbol as a value range regardless of how many static values this symbol
527   // represents. `fmt` is used to format each value. `separator` is used to
528   // separate values in the range.
529   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
530                              const char *separator = ", ") const;
531 
532   // Assign alternative unique names to Operands that have equal names.
533   void assignUniqueAlternativeNames();
534 
535   // Splits the given `symbol` into a value pack name and an index. Returns the
536   // value pack name and writes the index to `index` on success. Returns
537   // `symbol` itself if it does not contain an index.
538   //
539   // We can use `name__N` to access the `N`-th value in the value pack bound to
540   // `name`. `name` is typically the results of an multi-result op.
541   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
542 
543 private:
544   BaseT symbolInfoMap;
545 
546   // Pattern instantiation location. This is intended to be used as parameter
547   // to PrintFatalError() to report errors.
548   ArrayRef<SMLoc> loc;
549 };
550 
551 // Wrapper class providing helper methods for accessing MLIR Pattern defined
552 // in TableGen. This class should closely reflect what is defined as class
553 // `Pattern` in TableGen. This class contains maps so it is not intended to be
554 // used as values.
555 class Pattern {
556 public:
557   explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
558 
559   // Returns the source pattern to match.
560   DagNode getSourcePattern() const;
561 
562   // Returns the number of result patterns generated by applying this rewrite
563   // rule.
564   int getNumResultPatterns() const;
565 
566   // Returns the DAG tree root node of the `index`-th result pattern.
567   DagNode getResultPattern(unsigned index) const;
568 
569   // Collects all symbols bound in the source pattern into `infoMap`.
570   void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
571 
572   // Collects all symbols bound in result patterns into `infoMap`.
573   void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
574 
575   // Returns the op that the root node of the source pattern matches.
576   const Operator &getSourceRootOp();
577 
578   // Returns the operator wrapper object corresponding to the given `node`'s DAG
579   // operator.
580   Operator &getDialectOp(DagNode node);
581 
582   // Returns the constraints.
583   std::vector<AppliedConstraint> getConstraints() const;
584 
585   // Returns the number of supplemental auxiliary patterns generated by applying
586   // this rewrite rule.
587   int getNumSupplementalPatterns() const;
588 
589   // Returns the DAG tree root node of the `index`-th supplemental result
590   // pattern.
591   DagNode getSupplementalPattern(unsigned index) const;
592 
593   // Returns the benefit score of the pattern.
594   int getBenefit() const;
595 
596   using IdentifierLine = std::pair<StringRef, unsigned>;
597 
598   // Returns the file location of the pattern (buffer identifier + line number
599   // pair).
600   std::vector<IdentifierLine> getLocation() const;
601 
602   // Recursively collects all bound symbols inside the DAG tree rooted
603   // at `tree` and updates the given `infoMap`.
604   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
605                            bool isSrcPattern);
606 
607 private:
608   // Helper function to verify variable binding.
609   void verifyBind(bool result, StringRef symbolName);
610 
611   // The TableGen definition of this pattern.
612   const llvm::Record &def;
613 
614   // All operators.
615   // TODO: we need a proper context manager, like MLIRContext, for managing the
616   // lifetime of shared entities.
617   RecordOperatorMap *recordOpMap;
618 };
619 
620 } // namespace tblgen
621 } // namespace mlir
622 
623 namespace llvm {
624 template <>
625 struct DenseMapInfo<mlir::tblgen::DagNode> {
626   static mlir::tblgen::DagNode getEmptyKey() {
627     return mlir::tblgen::DagNode(
628         llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey());
629   }
630   static mlir::tblgen::DagNode getTombstoneKey() {
631     return mlir::tblgen::DagNode(
632         llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey());
633   }
634   static unsigned getHashValue(mlir::tblgen::DagNode node) {
635     return llvm::hash_value(node.getAsOpaquePointer());
636   }
637   static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) {
638     return lhs.node == rhs.node;
639   }
640 };
641 
642 template <>
643 struct DenseMapInfo<mlir::tblgen::DagLeaf> {
644   static mlir::tblgen::DagLeaf getEmptyKey() {
645     return mlir::tblgen::DagLeaf(
646         llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
647   }
648   static mlir::tblgen::DagLeaf getTombstoneKey() {
649     return mlir::tblgen::DagLeaf(
650         llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
651   }
652   static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
653     return llvm::hash_value(leaf.getAsOpaquePointer());
654   }
655   static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
656     return lhs.def == rhs.def;
657   }
658 };
659 } // namespace llvm
660 
661 #endif // MLIR_TABLEGEN_PATTERN_H_
662