xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 0a34492f36d77f043d371cc91f359b2d65e86475)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
45   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46                      raw_ostream &os, StringRef style) {
47     os << v.first << ":" << v.second;
48   }
49 };
50 } // end namespace llvm
51 
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
// Generates the C++ mlir::RewritePattern for a single DRR Pattern record.
57 class PatternEmitter {
58 public:
  // Creates an emitter for the pattern record `pat`; the generated C++ code is
  // written to `os`.
59   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60 
61   // Emits the mlir::RewritePattern struct named `rewriteName`.
62   void emit(StringRef rewriteName);
63 
64 private:
65   // Emits the code for matching the source DAG `tree`, followed by checks
  // for the additional constraints attached to the pattern.
66   void emitMatchLogic(DagNode tree);
67 
68   // Emits the code for rewriting ops: builds the result patterns and replaces
  // (or erases) the matched root op.
69   void emitRewriteLogic();
70 
71   //===--------------------------------------------------------------------===//
72   // Match utilities
73   //===--------------------------------------------------------------------===//
74 
75   // Emits C++ statements for matching the op constrained by the given DAG
76   // `tree`.
  // `depth` is the nesting level of `tree` in the source pattern (0 for the
  // root); the generated code refers to the op as `op<depth>` and
  // `castedOp<depth>`.
77   void emitOpMatch(DagNode tree, int depth);
78 
79   // Emits C++ statements for matching the `argIndex`-th argument of the given
80   // DAG `tree` as an operand.
  // `argIndex` indexes the op's full argument list (attributes included).
81   void emitOperandMatch(DagNode tree, int argIndex, int depth);
82 
83   // Emits C++ statements for matching the `argIndex`-th argument of the given
84   // DAG `tree` as an attribute.
85   void emitAttributeMatch(DagNode tree, int argIndex, int depth);
86 
87   // Emits C++ for checking a match with a corresponding match failure
88   // diagnostic.
  // The generated code evaluates `matchFmt` and, when it is false, returns
  // rewriter.notifyMatchFailure() for `op<depth>` with `failureFmt` as the
  // diagnostic message.
89   void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
90                       const llvm::formatv_object_base &failureFmt);
91 
92   //===--------------------------------------------------------------------===//
93   // Rewrite utilities
94   //===--------------------------------------------------------------------===//
95 
96   // The entry point for handling a result pattern rooted at `resultTree`. This
97   // method dispatches to concrete handlers according to `resultTree`'s kind and
98   // returns a symbol representing the whole value pack. Callers are expected to
99   // further resolve the symbol according to the specific use case.
100   //
101   // `depth` is the nesting level of `resultTree`; 0 means top-level result
102   // pattern. For top-level result pattern, `resultIndex` indicates which result
103   // of the matched root op this pattern is intended to replace, which can be
104   // used to deduce the result type of the op generated from this result
105   // pattern.
106   std::string handleResultPattern(DagNode resultTree, int resultIndex,
107                                   int depth);
108 
109   // Emits the C++ statement to replace the matched DAG with a value built via
110   // calling native C++ code.
111   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
112 
113   // Returns the symbol of the old value serving as the replacement.
114   StringRef handleReplaceWithValue(DagNode tree);
115 
116   // Returns a pair: whether `tree` has a location directive argument (which
  // callers exclude from its value arguments), and the C++ expression of the
  // location to use (substituted for `$_loc` in NativeCodeCall templates).
117   std::pair<bool, std::string> getLocation(DagNode tree);
118 
119   // Returns the C++ expression building the location described by the
  // location directive `tree` (a NameLoc, one value's location, or a
  // FusedLoc).
120   std::string handleLocationDirective(DagNode tree);
121 
122   // Emits the C++ statement to build a new op out of the given DAG `tree` and
123   // returns the variable name that this op is assigned to. If the root op in
124   // DAG `tree` has a specified name, the created op will be assigned to a
125   // variable of the given name. Otherwise, a unique name will be used as the
126   // result value name.
127   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
128 
129   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
130 
131   // Emits a local variable for each value and attribute to be used for creating
132   // an op.
133   void createSeparateLocalVarsForOpArgs(DagNode node,
134                                         ChildNodeIndexNameMap &childNodeNames);
135 
136   // Emits the concrete arguments used to call an op's builder.
137   void supplyValuesForOpArgs(DagNode node,
138                              const ChildNodeIndexNameMap &childNodeNames);
139 
140   // Emits the local variables for holding all values as a whole and all named
141   // attributes as a whole to be used for creating an op.
142   void createAggregateLocalVarsForOpArgs(
143       DagNode node, const ChildNodeIndexNameMap &childNodeNames);
144 
145   // Returns the C++ expression to construct a constant attribute of the given
146   // `value` for the given attribute kind `attr`.
147   std::string handleConstantAttr(Attribute attr, StringRef value);
148 
149   // Returns the C++ expression to build an argument from the given DAG `leaf`.
150   // `patArgName` is the symbol the argument is bound to in the source pattern;
  // it is resolved through the symbol map to reference the matched value.
151   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
152 
153   //===--------------------------------------------------------------------===//
154   // General utilities
155   //===--------------------------------------------------------------------===//
156 
157   // Collects all of the operations within the given dag tree.
158   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
159 
160   // Returns a unique symbol for a local variable of the given `op`.
161   std::string getUniqueSymbol(const Operator *op);
162 
163   //===--------------------------------------------------------------------===//
164   // Symbol utilities
165   //===--------------------------------------------------------------------===//
166 
167   // Returns how many static values the given DAG `node` corresponds to.
168   int getNodeValueCount(DagNode node);
169 
170 private:
171   // Pattern instantiation location followed by the location of multiclass
172   // prototypes used. This is intended to be used as a whole to
173   // PrintFatalError() on errors.
174   ArrayRef<llvm::SMLoc> loc;
175 
176   // Maps an op's TableGen Record to its Operator wrapper object.
177   RecordOperatorMap *opMap;
178 
179   // Handy wrapper for pattern being emitted.
180   Pattern pattern;
181 
182   // Map for all bound symbols' info.
183   SymbolInfoMap symbolInfoMap;
184 
185   // The next unused ID for newly created values (consumed by getUniqueSymbol).
186   unsigned nextValueId;
187 
  // Output stream for the generated C++ code, with indentation tracking
  // (indent()/unindent()/scope()).
188   raw_indented_ostream os;
189 
190   // Format contexts containing placeholder substitutions.
191   FmtContext fmtCtx;
192 
193   // Number of nested source ops matched so far; each newly matched nested op
  // is stored at tblgen_ops[++opCounter] (slot 0 holds the root op).
194   int opCounter = 0;
195 };
196 } // end anonymous namespace
197 
198 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
199                                raw_ostream &os)
200     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
201       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
202   fmtCtx.withBuilder("rewriter");
203 }
204 
205 std::string PatternEmitter::handleConstantAttr(Attribute attr,
206                                                StringRef value) {
207   if (!attr.isConstBuildable())
208     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
209                              " does not have the 'constBuilderCall' field");
210 
211   // TODO: Verify the constants here
212   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
213 }
214 
215 // Helper function to match patterns.
216 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
217   Operator &op = tree.getDialectOp(opMap);
218   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
219                           << op.getOperationName() << "' at depth " << depth
220                           << '\n');
221 
222   int indent = 4 + 2 * depth;
223   os.indent(indent) << formatv(
224       "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); "
225       "(void)castedOp{0};\n",
226       depth, op.getQualCppClassName());
227   // Skip the operand matching at depth 0 as the pattern rewriter already does.
228   if (depth != 0) {
229     // Skip if there is no defining operation (e.g., arguments to function).
230     os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
231   }
232   if (tree.getNumArgs() != op.getNumArgs()) {
233     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
234                                  "pattern vs. {2} in definition",
235                                  op.getOperationName(), tree.getNumArgs(),
236                                  op.getNumArgs()));
237   }
238 
239   // If the operand's name is set, set to that variable.
240   auto name = tree.getSymbol();
241   if (!name.empty())
242     os << formatv("{0} = castedOp{1};\n", name, depth);
243 
244   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
245     auto opArg = op.getArg(i);
246 
247     // Handle nested DAG construct first
248     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
249       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
250         if (operand->isVariableLength()) {
251           auto error = formatv("use nested DAG construct to match op {0}'s "
252                                "variadic operand #{1} unsupported now",
253                                op.getOperationName(), i);
254           PrintFatalError(loc, error);
255         }
256       }
257       os << "{\n";
258 
259       os.indent() << formatv(
260           "auto *op{0} = "
261           "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
262           depth + 1, depth, i);
263       emitOpMatch(argTree, depth + 1);
264       os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
265       os.unindent() << "}\n";
266       continue;
267     }
268 
269     // Next handle DAG leaf: operand or attribute
270     if (opArg.is<NamedTypeConstraint *>()) {
271       emitOperandMatch(tree, i, depth);
272     } else if (opArg.is<NamedAttribute *>()) {
273       emitAttributeMatch(tree, i, depth);
274     } else {
275       PrintFatalError(loc, "unhandled case when matching op");
276     }
277   }
278   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
279                           << op.getOperationName() << "' at depth " << depth
280                           << '\n');
281 }
282 
283 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
284   Operator &op = tree.getDialectOp(opMap);
285   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
286   auto matcher = tree.getArgAsLeaf(argIndex);
287 
288   // If a constraint is specified, we need to generate C++ statements to
289   // check the constraint.
290   if (!matcher.isUnspecified()) {
291     if (!matcher.isOperandMatcher()) {
292       PrintFatalError(
293           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
294                        op.getOperationName(), argIndex + 1));
295     }
296 
297     // Only need to verify if the matcher's type is different from the one
298     // of op definition.
299     Constraint constraint = matcher.getAsConstraint();
300     if (operand->constraint != constraint) {
301       if (operand->isVariableLength()) {
302         auto error = formatv(
303             "further constrain op {0}'s variadic operand #{1} unsupported now",
304             op.getOperationName(), argIndex);
305         PrintFatalError(loc, error);
306       }
307       auto self =
308           formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
309                   argIndex);
310       emitMatchCheck(
311           depth,
312           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
313           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
314                   "'{2}'\"",
315                   operand - op.operand_begin(), op.getOperationName(),
316                   constraint.getDescription()));
317     }
318   }
319 
320   // Capture the value
321   auto name = tree.getArgName(argIndex);
322   // `$_` is a special symbol to ignore op argument matching.
323   if (!name.empty() && name != "_") {
324     // We need to subtract the number of attributes before this operand to get
325     // the index in the operand list.
326     auto numPrevAttrs = std::count_if(
327         op.arg_begin(), op.arg_begin() + argIndex,
328         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
329 
330     os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
331                   argIndex - numPrevAttrs);
332   }
333 }
334 
335 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
336   Operator &op = tree.getDialectOp(opMap);
337   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
338   const auto &attr = namedAttr->attr;
339 
340   os << "{\n";
341   os.indent() << formatv(
342       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
343       "(void)tblgen_attr;\n",
344       depth, attr.getStorageType(), namedAttr->name);
345 
346   // TODO: This should use getter method to avoid duplication.
347   if (attr.hasDefaultValue()) {
348     os << "if (!tblgen_attr) tblgen_attr = "
349        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
350                             attr.getDefaultValue()))
351        << ";\n";
352   } else if (attr.isOptional()) {
353     // For a missing attribute that is optional according to definition, we
354     // should just capture a mlir::Attribute() to signal the missing state.
355     // That is precisely what getAttr() returns on missing attributes.
356   } else {
357     emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
358                    formatv("\"expected op '{0}' to have attribute '{1}' "
359                            "of type '{2}'\"",
360                            op.getOperationName(), namedAttr->name,
361                            attr.getStorageType()));
362   }
363 
364   auto matcher = tree.getArgAsLeaf(argIndex);
365   if (!matcher.isUnspecified()) {
366     if (!matcher.isAttrMatcher()) {
367       PrintFatalError(
368           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
369                        op.getOperationName(), argIndex + 1));
370     }
371 
372     // If a constraint is specified, we need to generate C++ statements to
373     // check the constraint.
374     emitMatchCheck(
375         depth,
376         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
377         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
378                 "{2}\"",
379                 op.getOperationName(), namedAttr->name,
380                 matcher.getAsConstraint().getDescription()));
381   }
382 
383   // Capture the value
384   auto name = tree.getArgName(argIndex);
385   // `$_` is a special symbol to ignore op argument matching.
386   if (!name.empty() && name != "_") {
387     os << formatv("{0} = tblgen_attr;\n", name);
388   }
389 
390   os.unindent() << "}\n";
391 }
392 
393 void PatternEmitter::emitMatchCheck(
394     int depth, const FmtObjectBase &matchFmt,
395     const llvm::formatv_object_base &failureFmt) {
396   os << "if (!(" << matchFmt.str() << "))";
397   os.scope("{\n", "\n}\n").os
398       << "return rewriter.notifyMatchFailure(op" << depth
399       << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
400       << ";\n});";
401 }
402 
403 void PatternEmitter::emitMatchLogic(DagNode tree) {
404   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
405   int depth = 0;
406   emitOpMatch(tree, depth);
407 
408   for (auto &appliedConstraint : pattern.getConstraints()) {
409     auto &constraint = appliedConstraint.constraint;
410     auto &entities = appliedConstraint.entities;
411 
412     auto condition = constraint.getConditionTemplate();
413     if (isa<TypeConstraint>(constraint)) {
414       auto self = formatv("({0}.getType())",
415                           symbolInfoMap.getValueAndRangeUse(entities.front()));
416       emitMatchCheck(
417           depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
418           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
419                   entities.front(), constraint.getDescription()));
420 
421     } else if (isa<AttrConstraint>(constraint)) {
422       PrintFatalError(
423           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
424     } else {
425       // TODO: replace formatv arguments with the exact specified
426       // args.
427       if (entities.size() > 4) {
428         PrintFatalError(loc, "only support up to 4-entity constraints now");
429       }
430       SmallVector<std::string, 4> names;
431       int i = 0;
432       for (int e = entities.size(); i < e; ++i)
433         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
434       std::string self = appliedConstraint.self;
435       if (!self.empty())
436         self = symbolInfoMap.getValueAndRangeUse(self);
437       for (; i < 4; ++i)
438         names.push_back("<unused>");
439       emitMatchCheck(depth,
440                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
441                            names[1], names[2], names[3]),
442                      formatv("\"entities '{0}' failed to satisfy constraint: "
443                              "{1}\"",
444                              llvm::join(entities, ", "),
445                              constraint.getDescription()));
446     }
447   }
448   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
449 }
450 
451 void PatternEmitter::collectOps(DagNode tree,
452                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
453   // Check if this tree is an operation.
454   if (tree.isOperation()) {
455     const Operator &op = tree.getDialectOp(opMap);
456     LLVM_DEBUG(llvm::dbgs()
457                << "found operation " << op.getOperationName() << '\n');
458     ops.insert(&op);
459   }
460 
461   // Recurse the arguments of the tree.
462   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
463     if (auto child = tree.getArgAsNestedDag(i))
464       collectOps(child, ops);
465 }
466 
// Emits the whole `struct <rewriteName> : public ::mlir::RewritePattern`
// definition for the current pattern. It contains:
// - a provenance comment;
// - a constructor that registers the root op name, the names of all ops
//   created by the result patterns, and the pattern benefit;
// - a matchAndRewrite() override holding the generated match logic followed
//   by the rewrite logic.
467 void PatternEmitter::emit(StringRef rewriteName) {
468   // Get the DAG tree for the source pattern.
469   DagNode sourceTree = pattern.getSourcePattern();
470 
471   const Operator &rootOp = pattern.getSourceRootOp();
472   auto rootName = rootOp.getOperationName();
473 
474   // Collect the set of result operations.
475   llvm::SmallPtrSet<const Operator *, 4> resultOps;
476   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
477   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
478     collectOps(pattern.getResultPattern(i), resultOps);
479   }
480   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
481 
482   // Emit RewritePattern for Pattern.
  // The provenance comment lists the pattern's locations in reverse order,
  // separated by "instantiating" lines.
483   auto locs = pattern.getLocation();
484   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
485                 make_range(locs.rbegin(), locs.rend()));
486   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
487   {0}(::mlir::MLIRContext *context)
488       : ::mlir::RewritePattern("{1}", {{)",
489                 rewriteName, rootName);
490   // Sort result operators by name.
  // SmallPtrSet iteration order depends on pointer values, so sorting keeps
  // the emitted list of generated op names deterministic.
491   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
492                                                          resultOps.end());
493   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
494     return lhs->getOperationName() < rhs->getOperationName();
495   });
496   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
497     os << '"' << op->getOperationName() << '"';
498   });
499   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
500 
501   // Emit matchAndRewrite() function.
502   {
503     auto classScope = os.scope();
504     os.reindent(R"(
505     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
506         ::mlir::PatternRewriter &rewriter) const override {)")
507         << '\n';
508     {
509       auto functionScope = os.scope();
510 
511       // Register all symbols bound in the source pattern.
512       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
513 
514       LLVM_DEBUG(llvm::dbgs()
515                  << "start creating local variables for capturing matches\n");
516       os << "// Variables for capturing values and attributes used while "
517             "creating ops\n";
518       // Create local variables for storing the arguments and results bound
519       // to symbols.
520       for (const auto &symbolInfoPair : symbolInfoMap) {
521         StringRef symbol = symbolInfoPair.getKey();
522         auto &info = symbolInfoPair.getValue();
523         os << info.getVarDecl(symbol);
524       }
525       // TODO: capture ops with consistent numbering so that it can be
526       // reused for fused loc.
      // One slot per op in the source pattern: slot 0 is the root op (set
      // below) and the match logic fills the slots of nested ops.
527       os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
528                     pattern.getSourcePattern().getNumOps());
529       LLVM_DEBUG(llvm::dbgs()
530                  << "done creating local variables for capturing matches\n");
531 
532       os << "// Match\n";
533       os << "tblgen_ops[0] = op0;\n";
534       emitMatchLogic(sourceTree);
535 
536       os << "\n// Rewrite\n";
537       emitRewriteLogic();
538 
539       os << "return ::mlir::success();\n";
540     }
    // Close matchAndRewrite(); the trailing ';' that is emitted is an empty
    // member declaration and is harmless.
541     os << "};\n";
542   }
  // Close the generated struct.
543   os << "};\n\n";
544 }
545 
// Emits the rewrite part of matchAndRewrite(), in this order:
// - decides which result patterns replace the matched root op's results and
//   which are auxiliary;
// - emits the fused location `odsLoc` of all matched ops;
// - emits the auxiliary result patterns;
// - either erases the root op (when it has no results) or collects the
//   replacement values into `tblgen_repl_values` and calls
//   rewriter.replaceOp().
546 void PatternEmitter::emitRewriteLogic() {
547   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
548   const Operator &rootOp = pattern.getSourceRootOp();
549   int numExpectedResults = rootOp.getNumResults();
550   int numResultPatterns = pattern.getNumResultPatterns();
551 
552   // First register all symbols bound to ops generated in result patterns.
553   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
554 
555   // Only the last N static values generated are used to replace the matched
556   // root N-result op. We need to calculate the starting index (of the results
557   // of the matched op) each result pattern is to replace.
  // offsets[i] is the index of the root op result that the first value of
  // result pattern i would replace. It equals numExpectedResults minus the
  // number of values produced by result patterns i..end (offsets[end] is
  // numExpectedResults itself).
558   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
559   // If we don't need to replace any value at all, set the replacement starting
560   // index as the number of result patterns so we skip all of them when trying
561   // to replace the matched op's results.
562   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
563   for (int i = numResultPatterns - 1; i >= 0; --i) {
564     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
565     offsets[i] = offsets[i + 1] - numValues;
566     if (offsets[i] == 0) {
      // The first pattern (scanning from the end) whose trailing values exactly
      // cover the root op's results starts the replacement range; earlier
      // patterns are auxiliary.
567       if (replStartIndex == -1)
568         replStartIndex = i;
569     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // Pattern i would produce values on both sides of the replacement range.
570       auto error = formatv(
571           "cannot use the same multi-result op '{0}' to generate both "
572           "auxiliary values and values to be used for replacing the matched op",
573           pattern.getResultPattern(i).getSymbol());
574       PrintFatalError(loc, error);
575     }
576   }
577 
  // A positive leading offset means all result patterns together produce fewer
  // values than the root op has results.
  // NOTE(review): the message reads "no enough"; "not enough" is presumably
  // intended. It is left as-is in case tests match the exact text.
578   if (offsets.front() > 0) {
579     const char error[] = "no enough values generated to replace the matched op";
580     PrintFatalError(loc, error);
581   }
582 
  // Emit `odsLoc`, the fused location of every matched op in tblgen_ops.
583   os << "auto odsLoc = rewriter.getFusedLoc({";
584   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
585     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
586   }
587   os << "}); (void)odsLoc;\n";
588 
589   // Process auxiliary result patterns.
590   for (int i = 0; i < replStartIndex; ++i) {
591     DagNode resultTree = pattern.getResultPattern(i);
592     auto val = handleResultPattern(resultTree, offsets[i], 0);
593     // Normal op creation will be streamed to `os` by the above call; but
594     // NativeCodeCall will only be materialized to `os` if it is used. Here
595     // we are handling auxiliary patterns so we want the side effect even if
596     // NativeCodeCall is not replacing matched root op's results.
597     if (resultTree.isNativeCodeCall())
598       os << val << ";\n";
599   }
600 
601   if (numExpectedResults == 0) {
602     assert(replStartIndex >= numResultPatterns &&
603            "invalid auxiliary vs. replacement pattern division!");
604     // No result to replace. Just erase the op.
605     os << "rewriter.eraseOp(op0);\n";
606   } else {
607     // Process replacement result patterns.
608     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
609     for (int i = replStartIndex; i < numResultPatterns; ++i) {
610       DagNode resultTree = pattern.getResultPattern(i);
611       auto val = handleResultPattern(resultTree, offsets[i], 0);
612       os << "\n";
613       // Resolve each symbol for all range use so that we can loop over them.
614       // We need an explicit cast to `SmallVector` to capture the cases where
615       // `{0}` resolves to an `Operation::result_range` as well as cases that
616       // are not iterable (e.g. vector that gets wrapped in additional braces by
617       // RewriterGen).
618       // TODO: Revisit the need for materializing a vector.
619       os << symbolInfoMap.getAllRangeUse(
620           val,
621           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
622           "  tblgen_repl_values.push_back(v);\n}\n",
623           "\n");
624     }
625     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
626   }
627 
628   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
629 }
630 
631 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
632   return std::string(
633       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
634 }
635 
636 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
637                                                 int resultIndex, int depth) {
638   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
639   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
640   LLVM_DEBUG(llvm::dbgs() << '\n');
641 
642   if (resultTree.isLocationDirective()) {
643     PrintFatalError(loc,
644                     "location directive can only be used with op creation");
645   }
646 
647   if (resultTree.isNativeCodeCall()) {
648     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
649     symbolInfoMap.bindValue(symbol);
650     return symbol;
651   }
652 
653   if (resultTree.isReplaceWithValue())
654     return handleReplaceWithValue(resultTree).str();
655 
656   // Normal op creation.
657   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
658   if (resultTree.getSymbol().empty()) {
659     // This is an op not explicitly bound to a symbol in the rewrite rule.
660     // Register the auto-generated symbol for it.
661     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
662   }
663   return symbol;
664 }
665 
666 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
667   assert(tree.isReplaceWithValue());
668 
669   if (tree.getNumArgs() != 1) {
670     PrintFatalError(
671         loc, "replaceWithValue directive must take exactly one argument");
672   }
673 
674   if (!tree.getSymbol().empty()) {
675     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
676   }
677 
678   return tree.getArgName(0);
679 }
680 
681 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
682   assert(tree.isLocationDirective());
683   auto lookUpArgLoc = [this, &tree](int idx) {
684     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
685     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
686   };
687 
688   if (tree.getNumArgs() == 0)
689     llvm::PrintFatalError(
690         "At least one argument to location directive required");
691 
692   if (!tree.getSymbol().empty())
693     PrintFatalError(loc, "cannot bind symbol to location");
694 
695   if (tree.getNumArgs() == 1) {
696     DagLeaf leaf = tree.getArgAsLeaf(0);
697     if (leaf.isStringAttr())
698       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
699                      "rewriter.getContext())",
700                      leaf.getStringAttr())
701           .str();
702     return lookUpArgLoc(0);
703   }
704 
705   std::string ret;
706   llvm::raw_string_ostream os(ret);
707   std::string strAttr;
708   os << "rewriter.getFusedLoc({";
709   bool first = true;
710   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
711     DagLeaf leaf = tree.getArgAsLeaf(i);
712     // Handle the optional string value.
713     if (leaf.isStringAttr()) {
714       if (!strAttr.empty())
715         llvm::PrintFatalError("Only one string attribute may be specified");
716       strAttr = leaf.getStringAttr();
717       continue;
718     }
719     os << (first ? "" : ", ") << lookUpArgLoc(i);
720     first = false;
721   }
722   os << "}";
723   if (!strAttr.empty()) {
724     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
725   }
726   os << ")";
727   return os.str();
728 }
729 
730 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
731                                              StringRef patArgName) {
732   if (leaf.isStringAttr())
733     PrintFatalError(loc, "raw string not supported as argument");
734   if (leaf.isConstantAttr()) {
735     auto constAttr = leaf.getAsConstantAttr();
736     return handleConstantAttr(constAttr.getAttribute(),
737                               constAttr.getConstantValue());
738   }
739   if (leaf.isEnumAttrCase()) {
740     auto enumCase = leaf.getAsEnumAttrCase();
741     if (enumCase.isStrCase())
742       return handleConstantAttr(enumCase, enumCase.getSymbol());
743     // This is an enum case backed by an IntegerAttr. We need to get its value
744     // to build the constant.
745     std::string val = std::to_string(enumCase.getValue());
746     return handleConstantAttr(enumCase, val);
747   }
748 
749   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
750   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
751   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
752     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
753                             << "' (via symbol ref)\n");
754     return argName;
755   }
756   if (leaf.isNativeCodeCall()) {
757     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
758     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
759                             << "' (via NativeCodeCall)\n");
760     return std::string(repl);
761   }
762   PrintFatalError(loc, "unhandled case when rewriting op");
763 }
764 
765 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
766   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
767   LLVM_DEBUG(tree.print(llvm::dbgs()));
768   LLVM_DEBUG(llvm::dbgs() << '\n');
769 
770   auto fmt = tree.getNativeCodeTemplate();
771   // TODO: replace formatv arguments with the exact specified args.
772   SmallVector<std::string, 8> attrs(8);
773   if (tree.getNumArgs() > 8) {
774     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
775                              Twine(tree.getNumArgs()));
776   }
777   bool hasLocationDirective;
778   std::string locToUse;
779   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
780 
781   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
782     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
783     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
784                             << " replacement: " << attrs[i] << "\n");
785   }
786   return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
787                            attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
788                            attrs[6], attrs[7]));
789 }
790 
791 int PatternEmitter::getNodeValueCount(DagNode node) {
792   if (node.isOperation()) {
793     // If the op is bound to a symbol in the rewrite rule, query its result
794     // count from the symbol info map.
795     auto symbol = node.getSymbol();
796     if (!symbol.empty()) {
797       return symbolInfoMap.getStaticValueCount(symbol);
798     }
799     // Otherwise this is an unbound op; we will use all its results.
800     return pattern.getDialectOp(node).getNumResults();
801   }
802   // TODO: This considers all NativeCodeCall as returning one
803   // value. Enhance if multi-value ones are needed.
804   return 1;
805 }
806 
807 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
808   auto numPatArgs = tree.getNumArgs();
809 
810   if (numPatArgs != 0) {
811     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
812       if (lastArg.isLocationDirective()) {
813         return std::make_pair(true, handleLocationDirective(lastArg));
814       }
815   }
816 
817   // If no explicit location is given, use the default, all fused, location.
818   return std::make_pair(false, "odsLoc");
819 }
820 
821 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
822                                              int depth) {
823   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
824   LLVM_DEBUG(tree.print(llvm::dbgs()));
825   LLVM_DEBUG(llvm::dbgs() << '\n');
826 
827   Operator &resultOp = tree.getDialectOp(opMap);
828   auto numOpArgs = resultOp.getNumArgs();
829   auto numPatArgs = tree.getNumArgs();
830 
831   bool hasLocationDirective;
832   std::string locToUse;
833   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
834 
835   auto inPattern = numPatArgs - hasLocationDirective;
836   if (numOpArgs != inPattern) {
837     PrintFatalError(loc,
838                     formatv("resultant op '{0}' argument number mismatch: "
839                             "{1} in pattern vs. {2} in definition",
840                             resultOp.getOperationName(), inPattern, numOpArgs));
841   }
842 
843   // A map to collect all nested DAG child nodes' names, with operand index as
844   // the key. This includes both bound and unbound child nodes.
845   ChildNodeIndexNameMap childNodeNames;
846 
847   // First go through all the child nodes who are nested DAG constructs to
848   // create ops for them and remember the symbol names for them, so that we can
849   // use the results in the current node. This happens in a recursive manner.
850   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
851     if (auto child = tree.getArgAsNestedDag(i))
852       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
853   }
854 
855   // The name of the local variable holding this op.
856   std::string valuePackName;
857   // The symbol for holding the result of this pattern. Note that the result of
858   // this pattern is not necessarily the same as the variable created by this
859   // pattern because we can use `__N` suffix to refer only a specific result if
860   // the generated op is a multi-result op.
861   std::string resultValue;
862   if (tree.getSymbol().empty()) {
863     // No symbol is explicitly bound to this op in the pattern. Generate a
864     // unique name.
865     valuePackName = resultValue = getUniqueSymbol(&resultOp);
866   } else {
867     resultValue = std::string(tree.getSymbol());
868     // Strip the index to get the name for the value pack and use it to name the
869     // local variable for the op.
870     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
871   }
872 
873   // Create the local variable for this op.
874   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
875                 valuePackName);
876 
877   // Right now ODS don't have general type inference support. Except a few
878   // special cases listed below, DRR needs to supply types for all results
879   // when building an op.
880   bool isSameOperandsAndResultType =
881       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
882   bool useFirstAttr =
883       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
884 
885   if (isSameOperandsAndResultType || useFirstAttr) {
886     // We know how to deduce the result type for ops with these traits and we've
887     // generated builders taking aggregate parameters. Use those builders to
888     // create the ops.
889 
890     // First prepare local variables for op arguments used in builder call.
891     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
892 
893     // Then create the op.
894     os.scope("", "\n}\n").os << formatv(
895         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
896         valuePackName, resultOp.getQualCppClassName(), locToUse);
897     return resultValue;
898   }
899 
900   bool usePartialResults = valuePackName != resultValue;
901 
902   if (usePartialResults || depth > 0 || resultIndex < 0) {
903     // For these cases (broadcastable ops, op results used both as auxiliary
904     // values and replacement values, ops in nested patterns, auxiliary ops), we
905     // still need to supply the result types when building the op. But because
906     // we don't generate a builder automatically with ODS for them, it's the
907     // developer's responsibility to make sure such a builder (with result type
908     // deduction ability) exists. We go through the separate-parameter builder
909     // here given that it's easier for developers to write compared to
910     // aggregate-parameter builders.
911     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
912 
913     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
914                              resultOp.getQualCppClassName(), locToUse);
915     supplyValuesForOpArgs(tree, childNodeNames);
916     os << "\n  );\n}\n";
917     return resultValue;
918   }
919 
920   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
921   // generated from the source pattern root op. Then we can use the source
922   // pattern's value types to determine the value type of the generated op
923   // here.
924 
925   // First prepare local variables for op arguments used in builder call.
926   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
927 
928   // Then prepare the result types. We need to specify the types for all
929   // results.
930   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
931                          "(void)tblgen_types;\n");
932   int numResults = resultOp.getNumResults();
933   if (numResults != 0) {
934     for (int i = 0; i < numResults; ++i)
935       os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
936                     "  tblgen_types.push_back(v.getType());\n}\n",
937                     resultIndex + i);
938   }
939   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
940                 "tblgen_values, tblgen_attrs);\n",
941                 valuePackName, resultOp.getQualCppClassName(), locToUse);
942   os.unindent() << "}\n";
943   return resultValue;
944 }
945 
946 void PatternEmitter::createSeparateLocalVarsForOpArgs(
947     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
948   Operator &resultOp = node.getDialectOp(opMap);
949 
950   // Now prepare operands used for building this op:
951   // * If the operand is non-variadic, we create a `Value` local variable.
952   // * If the operand is variadic, we create a `SmallVector<Value>` local
953   //   variable.
954 
955   int valueIndex = 0; // An index for uniquing local variable names.
956   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
957     const auto *operand =
958         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
959     // We do not need special handling for attributes.
960     if (!operand)
961       continue;
962 
963     raw_indented_ostream::DelimitedScope scope(os);
964     std::string varName;
965     if (operand->isVariadic()) {
966       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
967       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
968       std::string range;
969       if (node.isNestedDagArg(argIndex)) {
970         range = childNodeNames[argIndex];
971       } else {
972         range = std::string(node.getArgName(argIndex));
973       }
974       // Resolve the symbol for all range use so that we have a uniform way of
975       // capturing the values.
976       range = symbolInfoMap.getValueAndRangeUse(range);
977       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
978                     varName);
979     } else {
980       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
981       os << formatv("::mlir::Value {0} = ", varName);
982       if (node.isNestedDagArg(argIndex)) {
983         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
984       } else {
985         DagLeaf leaf = node.getArgAsLeaf(argIndex);
986         auto symbol =
987             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
988         if (leaf.isNativeCodeCall()) {
989           os << std::string(
990               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
991         } else {
992           os << symbol;
993         }
994       }
995       os << ";\n";
996     }
997 
998     // Update to use the newly created local variable for building the op later.
999     childNodeNames[argIndex] = varName;
1000   }
1001 }
1002 
1003 void PatternEmitter::supplyValuesForOpArgs(
1004     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1005   Operator &resultOp = node.getDialectOp(opMap);
1006   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1007        argIndex != numOpArgs; ++argIndex) {
1008     // Start each argument on its own line.
1009     os << ",\n    ";
1010 
1011     Argument opArg = resultOp.getArg(argIndex);
1012     // Handle the case of operand first.
1013     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1014       if (!operand->name.empty())
1015         os << "/*" << operand->name << "=*/";
1016       os << childNodeNames.lookup(argIndex);
1017       continue;
1018     }
1019 
1020     // The argument in the op definition.
1021     auto opArgName = resultOp.getArgName(argIndex);
1022     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1023       if (!subTree.isNativeCodeCall())
1024         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1025                              "for creating attribute");
1026       os << formatv("/*{0}=*/{1}", opArgName,
1027                     handleReplaceWithNativeCodeCall(subTree));
1028     } else {
1029       auto leaf = node.getArgAsLeaf(argIndex);
1030       // The argument in the result DAG pattern.
1031       auto patArgName = node.getArgName(argIndex);
1032       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1033         // TODO: Refactor out into map to avoid recomputing these.
1034         if (!opArg.is<NamedAttribute *>())
1035           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1036         if (!patArgName.empty())
1037           os << "/*" << patArgName << "=*/";
1038       } else {
1039         os << "/*" << opArgName << "=*/";
1040       }
1041       os << handleOpArgument(leaf, patArgName);
1042     }
1043   }
1044 }
1045 
1046 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1047     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1048   Operator &resultOp = node.getDialectOp(opMap);
1049 
1050   auto scope = os.scope();
1051   os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1052                 "tblgen_values; (void)tblgen_values;\n");
1053   os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1054                 "tblgen_attrs; (void)tblgen_attrs;\n");
1055 
1056   const char *addAttrCmd =
1057       "if (auto tmpAttr = {1}) {\n"
1058       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1059       "tmpAttr);\n}\n";
1060   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1061     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1062       // The argument in the op definition.
1063       auto opArgName = resultOp.getArgName(argIndex);
1064       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1065         if (!subTree.isNativeCodeCall())
1066           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1067                                "for creating attribute");
1068         os << formatv(addAttrCmd, opArgName,
1069                       handleReplaceWithNativeCodeCall(subTree));
1070       } else {
1071         auto leaf = node.getArgAsLeaf(argIndex);
1072         // The argument in the result DAG pattern.
1073         auto patArgName = node.getArgName(argIndex);
1074         os << formatv(addAttrCmd, opArgName,
1075                       handleOpArgument(leaf, patArgName));
1076       }
1077       continue;
1078     }
1079 
1080     const auto *operand =
1081         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1082     std::string varName;
1083     if (operand->isVariadic()) {
1084       std::string range;
1085       if (node.isNestedDagArg(argIndex)) {
1086         range = childNodeNames.lookup(argIndex);
1087       } else {
1088         range = std::string(node.getArgName(argIndex));
1089       }
1090       // Resolve the symbol for all range use so that we have a uniform way of
1091       // capturing the values.
1092       range = symbolInfoMap.getValueAndRangeUse(range);
1093       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1094                     range);
1095     } else {
1096       os << formatv("tblgen_values.push_back(", varName);
1097       if (node.isNestedDagArg(argIndex)) {
1098         os << symbolInfoMap.getValueAndRangeUse(
1099             childNodeNames.lookup(argIndex));
1100       } else {
1101         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1102         auto symbol =
1103             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1104         if (leaf.isNativeCodeCall()) {
1105           os << std::string(
1106               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1107         } else {
1108           os << symbol;
1109         }
1110       }
1111       os << ");\n";
1112     }
1113   }
1114 }
1115 
1116 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1117   emitSourceFileHeader("Rewriters", os);
1118 
1119   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1120   auto numPatterns = patterns.size();
1121 
1122   // We put the map here because it can be shared among multiple patterns.
1123   RecordOperatorMap recordOpMap;
1124 
1125   std::vector<std::string> rewriterNames;
1126   rewriterNames.reserve(numPatterns);
1127 
1128   std::string baseRewriterName = "GeneratedConvert";
1129   int rewriterIndex = 0;
1130 
1131   for (Record *p : patterns) {
1132     std::string name;
1133     if (p->isAnonymous()) {
1134       // If no name is provided, ensure unique rewriter names simply by
1135       // appending unique suffix.
1136       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1137     } else {
1138       name = std::string(p->getName());
1139     }
1140     LLVM_DEBUG(llvm::dbgs()
1141                << "=== start generating pattern '" << name << "' ===\n");
1142     PatternEmitter(p, &recordOpMap, os).emit(name);
1143     LLVM_DEBUG(llvm::dbgs()
1144                << "=== done generating pattern '" << name << "' ===\n");
1145     rewriterNames.push_back(std::move(name));
1146   }
1147 
1148   // Emit function to add the generated matchers to the pattern list.
1149   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
1150         "*context, ::mlir::OwningRewritePatternList *patterns) {\n";
1151   for (const auto &name : rewriterNames) {
1152     os << "  patterns->insert<" << name << ">(context);\n";
1153   }
1154   os << "}\n";
1155 }
1156 
1157 static mlir::GenRegistration
1158     genRewriters("gen-rewriters", "Generate pattern rewriters",
1159                  [](const RecordKeeper &records, raw_ostream &os) {
1160                    emitRewriters(records, os);
1161                    return false;
1162                  });
1163