xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision be185b6a7355fdfeb1c31df2e1272366fe58b01f)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/TableGen/Pattern.h"
18 #include "mlir/TableGen/Predicate.h"
19 #include "mlir/TableGen/Type.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatAdapters.h"
25 #include "llvm/Support/PrettyStackTrace.h"
26 #include "llvm/Support/Signals.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Main.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31 
32 using namespace mlir;
33 using namespace mlir::tblgen;
34 
35 using llvm::formatv;
36 using llvm::Record;
37 using llvm::RecordKeeper;
38 
39 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
40 
41 namespace llvm {
42 template <>
43 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
44   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
45                      raw_ostream &os, StringRef style) {
46     os << v.first << ":" << v.second;
47   }
48 };
49 } // end namespace llvm
50 
51 //===----------------------------------------------------------------------===//
52 // PatternEmitter
53 //===----------------------------------------------------------------------===//
54 
namespace {
// Emits the C++ `::mlir::RewritePattern` subclass for a single TableGen
// pattern record: a constructor that registers the root op name and the names
// of ops created by the result patterns, plus a `matchAndRewrite()` override
// containing the generated match and rewrite logic.
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern
  // (0 for the root op).
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand.
  void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic. The failure is reported against `op{depth}`.
  void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Returns a pair of (whether `tree` carries a location directive, the C++
  // expression for the location to use when building ops from `tree`).
  // NOTE(review): callers drop the directive from the argument count, which
  // suggests it is expected as the trailing argument — confirm against the
  // definition.
  std::pair<bool, std::string> getLocation(DagNode tree);

  // Returns the C++ expression for the location described by the location
  // directive `tree`: a `NameLoc` for a single string argument, the location
  // of a single bound value, or a fused location of several bound values
  // (optionally tagged with one string attribute).
  std::string handleLocationDirective(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the symbol that binds the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values; consumed by
  // getUniqueSymbol().
  unsigned nextValueId;

  // Stream receiving the generated C++ code.
  raw_ostream &os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;

  // Number of nested ops matched so far; used as the index into the generated
  // `tblgen_ops` array (slot 0 holds the root op).
  int opCounter = 0;
};
} // end anonymous namespace
196 
197 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
198                                raw_ostream &os)
199     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
200       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
201   fmtCtx.withBuilder("rewriter");
202 }
203 
204 std::string PatternEmitter::handleConstantAttr(Attribute attr,
205                                                StringRef value) {
206   if (!attr.isConstBuildable())
207     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
208                              " does not have the 'constBuilderCall' field");
209 
210   // TODO: Verify the constants here
211   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
212 }
213 
214 // Helper function to match patterns.
215 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
216   Operator &op = tree.getDialectOp(opMap);
217   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
218                           << op.getOperationName() << "' at depth " << depth
219                           << '\n');
220 
221   int indent = 4 + 2 * depth;
222   os.indent(indent) << formatv(
223       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
224       depth, op.getQualCppClassName());
225   // Skip the operand matching at depth 0 as the pattern rewriter already does.
226   if (depth != 0) {
227     // Skip if there is no defining operation (e.g., arguments to function).
228     os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
229                                  depth);
230   }
231   if (tree.getNumArgs() != op.getNumArgs()) {
232     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
233                                  "pattern vs. {2} in definition",
234                                  op.getOperationName(), tree.getNumArgs(),
235                                  op.getNumArgs()));
236   }
237 
238   // If the operand's name is set, set to that variable.
239   auto name = tree.getSymbol();
240   if (!name.empty())
241     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
242 
243   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
244     auto opArg = op.getArg(i);
245 
246     // Handle nested DAG construct first
247     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
248       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
249         if (operand->isVariableLength()) {
250           auto error = formatv("use nested DAG construct to match op {0}'s "
251                                "variadic operand #{1} unsupported now",
252                                op.getOperationName(), i);
253           PrintFatalError(loc, error);
254         }
255       }
256       os.indent(indent) << "{\n";
257 
258       os.indent(indent + 2) << formatv(
259           "auto *op{0} = "
260           "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
261           depth + 1, depth, i);
262       emitOpMatch(argTree, depth + 1);
263       os.indent(indent + 2)
264           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
265       os.indent(indent) << "}\n";
266       continue;
267     }
268 
269     // Next handle DAG leaf: operand or attribute
270     if (opArg.is<NamedTypeConstraint *>()) {
271       emitOperandMatch(tree, i, depth, indent);
272     } else if (opArg.is<NamedAttribute *>()) {
273       emitAttributeMatch(tree, i, depth, indent);
274     } else {
275       PrintFatalError(loc, "unhandled case when matching op");
276     }
277   }
278   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
279                           << op.getOperationName() << "' at depth " << depth
280                           << '\n');
281 }
282 
283 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
284                                       int indent) {
285   Operator &op = tree.getDialectOp(opMap);
286   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
287   auto matcher = tree.getArgAsLeaf(argIndex);
288 
289   // If a constraint is specified, we need to generate C++ statements to
290   // check the constraint.
291   if (!matcher.isUnspecified()) {
292     if (!matcher.isOperandMatcher()) {
293       PrintFatalError(
294           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
295                        op.getOperationName(), argIndex + 1));
296     }
297 
298     // Only need to verify if the matcher's type is different from the one
299     // of op definition.
300     Constraint constraint = matcher.getAsConstraint();
301     if (operand->constraint != constraint) {
302       if (operand->isVariableLength()) {
303         auto error = formatv(
304             "further constrain op {0}'s variadic operand #{1} unsupported now",
305             op.getOperationName(), argIndex);
306         PrintFatalError(loc, error);
307       }
308       auto self =
309           formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
310                   argIndex);
311       emitMatchCheck(
312           depth,
313           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
314           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
315                   "'{2}'\"",
316                   operand - op.operand_begin(), op.getOperationName(),
317                   constraint.getDescription()));
318     }
319   }
320 
321   // Capture the value
322   auto name = tree.getArgName(argIndex);
323   // `$_` is a special symbol to ignore op argument matching.
324   if (!name.empty() && name != "_") {
325     // We need to subtract the number of attributes before this operand to get
326     // the index in the operand list.
327     auto numPrevAttrs = std::count_if(
328         op.arg_begin(), op.arg_begin() + argIndex,
329         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
330 
331     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
332                                  name, depth, argIndex - numPrevAttrs);
333   }
334 }
335 
336 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
337                                         int indent) {
338   Operator &op = tree.getDialectOp(opMap);
339   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
340   const auto &attr = namedAttr->attr;
341 
342   os.indent(indent) << "{\n";
343   indent += 2;
344   os.indent(indent) << formatv(
345       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
346       "(void)tblgen_attr;\n",
347       depth, attr.getStorageType(), namedAttr->name);
348 
349   // TODO: This should use getter method to avoid duplication.
350   if (attr.hasDefaultValue()) {
351     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
352                       << std::string(tgfmt(attr.getConstBuilderTemplate(),
353                                            &fmtCtx, attr.getDefaultValue()))
354                       << ";\n";
355   } else if (attr.isOptional()) {
356     // For a missing attribute that is optional according to definition, we
357     // should just capture a mlir::Attribute() to signal the missing state.
358     // That is precisely what getAttr() returns on missing attributes.
359   } else {
360     emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
361                    formatv("\"expected op '{0}' to have attribute '{1}' "
362                            "of type '{2}'\"",
363                            op.getOperationName(), namedAttr->name,
364                            attr.getStorageType()));
365   }
366 
367   auto matcher = tree.getArgAsLeaf(argIndex);
368   if (!matcher.isUnspecified()) {
369     if (!matcher.isAttrMatcher()) {
370       PrintFatalError(
371           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
372                        op.getOperationName(), argIndex + 1));
373     }
374 
375     // If a constraint is specified, we need to generate C++ statements to
376     // check the constraint.
377     emitMatchCheck(
378         depth,
379         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
380         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
381                 "{2}\"",
382                 op.getOperationName(), namedAttr->name,
383                 matcher.getAsConstraint().getDescription()));
384   }
385 
386   // Capture the value
387   auto name = tree.getArgName(argIndex);
388   // `$_` is a special symbol to ignore op argument matching.
389   if (!name.empty() && name != "_") {
390     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
391   }
392 
393   indent -= 2;
394   os.indent(indent) << "}\n";
395 }
396 
397 void PatternEmitter::emitMatchCheck(
398     int depth, const FmtObjectBase &matchFmt,
399     const llvm::formatv_object_base &failureFmt) {
400   // {0} The match depth (used to get the operation that failed to match).
401   // {1} The format for the match string.
402   // {2} The format for the failure string.
403   const char *matchStr = R"(
404     if (!({1})) {
405       return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
406         diag << {2};
407       });
408     })";
409   os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
410      << "\n";
411 }
412 
413 void PatternEmitter::emitMatchLogic(DagNode tree) {
414   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
415   int depth = 0;
416   emitOpMatch(tree, depth);
417 
418   for (auto &appliedConstraint : pattern.getConstraints()) {
419     auto &constraint = appliedConstraint.constraint;
420     auto &entities = appliedConstraint.entities;
421 
422     auto condition = constraint.getConditionTemplate();
423     if (isa<TypeConstraint>(constraint)) {
424       auto self = formatv("({0}.getType())",
425                           symbolInfoMap.getValueAndRangeUse(entities.front()));
426       emitMatchCheck(
427           depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
428           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
429                   entities.front(), constraint.getDescription()));
430 
431     } else if (isa<AttrConstraint>(constraint)) {
432       PrintFatalError(
433           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
434     } else {
435       // TODO: replace formatv arguments with the exact specified
436       // args.
437       if (entities.size() > 4) {
438         PrintFatalError(loc, "only support up to 4-entity constraints now");
439       }
440       SmallVector<std::string, 4> names;
441       int i = 0;
442       for (int e = entities.size(); i < e; ++i)
443         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
444       std::string self = appliedConstraint.self;
445       if (!self.empty())
446         self = symbolInfoMap.getValueAndRangeUse(self);
447       for (; i < 4; ++i)
448         names.push_back("<unused>");
449       emitMatchCheck(depth,
450                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
451                            names[1], names[2], names[3]),
452                      formatv("\"entities '{0}' failed to satisfy constraint: "
453                              "{1}\"",
454                              llvm::join(entities, ", "),
455                              constraint.getDescription()));
456     }
457   }
458   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
459 }
460 
461 void PatternEmitter::collectOps(DagNode tree,
462                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
463   // Check if this tree is an operation.
464   if (tree.isOperation()) {
465     const Operator &op = tree.getDialectOp(opMap);
466     LLVM_DEBUG(llvm::dbgs()
467                << "found operation " << op.getOperationName() << '\n');
468     ops.insert(&op);
469   }
470 
471   // Recurse the arguments of the tree.
472   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
473     if (auto child = tree.getArgAsNestedDag(i))
474       collectOps(child, ops);
475 }
476 
// Emits the complete `::mlir::RewritePattern` subclass named `rewriteName`:
// a provenance comment, a constructor registering the root op name, the sorted
// names of ops created by result patterns and the pattern benefit, and a
// `matchAndRewrite()` whose body holds the capture variables, the match logic
// and the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The provenance comment lists the pattern's locations in reverse order.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {{)",
                rewriteName, rootName);
  // Sort result operators by name.
  // (SmallPtrSet iteration order is pointer-based; sorting keeps the generated
  // output deterministic.)
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

  // Emit matchAndRewrite() function.
  os << R"(
  ::mlir::LogicalResult
  matchAndRewrite(::mlir::Operation *op0,
                  ::mlir::PatternRewriter &rewriter) const override {
)";

  // Register all symbols bound in the source pattern.
  pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

  LLVM_DEBUG(
      llvm::dbgs() << "start creating local variables for capturing matches\n");
  os.indent(4) << "// Variables for capturing values and attributes used for "
                  "creating ops\n";
  // Create local variables for storing the arguments and results bound
  // to symbols.
  for (const auto &symbolInfoPair : symbolInfoMap) {
    StringRef symbol = symbolInfoPair.getKey();
    auto &info = symbolInfoPair.getValue();
    os.indent(4) << info.getVarDecl(symbol);
  }
  // TODO: capture ops with consistent numbering so that it can be
  // reused for fused loc.
  // One `tblgen_ops` slot per op in the source pattern; slot 0 is the root.
  os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
                          pattern.getSourcePattern().getNumOps());
  LLVM_DEBUG(
      llvm::dbgs() << "done creating local variables for capturing matches\n");

  os.indent(4) << "// Match\n";
  os.indent(4) << "tblgen_ops[0] = op0;\n";
  emitMatchLogic(sourceTree);
  os << "\n";

  os.indent(4) << "// Rewrite\n";
  emitRewriteLogic();

  os.indent(4) << "return success();\n";
  os << "  };\n";
  os << "};\n";
}
551 
// Emits the rewrite half of the generated matchAndRewrite(). It binds `odsLoc`
// to the fused location of all matched ops and emits the auxiliary result
// patterns. It then either erases the matched root op (when it has no results)
// or replaces its results with the values produced by the trailing
// replacement result patterns.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  // offsets[i] is that starting index for result pattern i; offsets[N] is the
  // root op's result count. A non-positive offset means the pattern's values
  // lie (partly) before the replacement window.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  // Walk from the last result pattern backwards; the first pattern (in this
  // walk) whose offset reaches exactly 0 starts the replacement window.
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // This pattern's values straddle the start of the replacement window.
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  if (offsets.front() > 0) {
    const char error[] = "no enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  // Fuse the locations of every matched op (captured in `tblgen_ops`).
  os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall())
      os.indent(4) << val << ";\n";
  }

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    // No result to replace. Just erase the op.
    os.indent(4) << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os.indent(4)
        << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os.indent(4) << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
          "tblgen_repl_values.push_back(v); }",
          "\n");
    }
    os.indent(4) << "\n";
    os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
638 
639 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
640   return std::string(
641       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
642 }
643 
644 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
645                                                 int resultIndex, int depth) {
646   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
647   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
648   LLVM_DEBUG(llvm::dbgs() << '\n');
649 
650   if (resultTree.isLocationDirective()) {
651     PrintFatalError(loc,
652                     "location directive can only be used with op creation");
653   }
654 
655   if (resultTree.isNativeCodeCall()) {
656     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
657     symbolInfoMap.bindValue(symbol);
658     return symbol;
659   }
660 
661   if (resultTree.isReplaceWithValue())
662     return handleReplaceWithValue(resultTree).str();
663 
664   // Normal op creation.
665   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
666   if (resultTree.getSymbol().empty()) {
667     // This is an op not explicitly bound to a symbol in the rewrite rule.
668     // Register the auto-generated symbol for it.
669     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
670   }
671   return symbol;
672 }
673 
674 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
675   assert(tree.isReplaceWithValue());
676 
677   if (tree.getNumArgs() != 1) {
678     PrintFatalError(
679         loc, "replaceWithValue directive must take exactly one argument");
680   }
681 
682   if (!tree.getSymbol().empty()) {
683     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
684   }
685 
686   return tree.getArgName(0);
687 }
688 
689 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
690   assert(tree.isLocationDirective());
691   auto lookUpArgLoc = [this, &tree](int idx) {
692     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
693     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
694   };
695 
696   if (tree.getNumArgs() == 0)
697     llvm::PrintFatalError(
698         "At least one argument to location directive required");
699 
700   if (!tree.getSymbol().empty())
701     PrintFatalError(loc, "cannot bind symbol to location");
702 
703   if (tree.getNumArgs() == 1) {
704     DagLeaf leaf = tree.getArgAsLeaf(0);
705     if (leaf.isStringAttr())
706       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
707                      "rewriter.getContext())",
708                      leaf.getStringAttr())
709           .str();
710     return lookUpArgLoc(0);
711   }
712 
713   std::string ret;
714   llvm::raw_string_ostream os(ret);
715   std::string strAttr;
716   os << "rewriter.getFusedLoc({";
717   bool first = true;
718   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
719     DagLeaf leaf = tree.getArgAsLeaf(i);
720     // Handle the optional string value.
721     if (leaf.isStringAttr()) {
722       if (!strAttr.empty())
723         llvm::PrintFatalError("Only one string attribute may be specified");
724       strAttr = leaf.getStringAttr();
725       continue;
726     }
727     os << (first ? "" : ", ") << lookUpArgLoc(i);
728     first = false;
729   }
730   os << "}";
731   if (!strAttr.empty()) {
732     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
733   }
734   os << ")";
735   return os.str();
736 }
737 
738 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
739                                              StringRef patArgName) {
740   if (leaf.isStringAttr())
741     PrintFatalError(loc, "raw string not supported as argument");
742   if (leaf.isConstantAttr()) {
743     auto constAttr = leaf.getAsConstantAttr();
744     return handleConstantAttr(constAttr.getAttribute(),
745                               constAttr.getConstantValue());
746   }
747   if (leaf.isEnumAttrCase()) {
748     auto enumCase = leaf.getAsEnumAttrCase();
749     if (enumCase.isStrCase())
750       return handleConstantAttr(enumCase, enumCase.getSymbol());
751     // This is an enum case backed by an IntegerAttr. We need to get its value
752     // to build the constant.
753     std::string val = std::to_string(enumCase.getValue());
754     return handleConstantAttr(enumCase, val);
755   }
756 
757   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
758   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
759   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
760     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
761                             << "' (via symbol ref)\n");
762     return argName;
763   }
764   if (leaf.isNativeCodeCall()) {
765     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
766     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
767                             << "' (via NativeCodeCall)\n");
768     return std::string(repl);
769   }
770   PrintFatalError(loc, "unhandled case when rewriting op");
771 }
772 
773 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
774   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
775   LLVM_DEBUG(tree.print(llvm::dbgs()));
776   LLVM_DEBUG(llvm::dbgs() << '\n');
777 
778   auto fmt = tree.getNativeCodeTemplate();
779   // TODO: replace formatv arguments with the exact specified args.
780   SmallVector<std::string, 8> attrs(8);
781   if (tree.getNumArgs() > 8) {
782     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
783                              Twine(tree.getNumArgs()));
784   }
785   bool hasLocationDirective;
786   std::string locToUse;
787   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
788 
789   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
790     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
791     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
792                             << " replacement: " << attrs[i] << "\n");
793   }
794   return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
795                            attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
796                            attrs[6], attrs[7]));
797 }
798 
799 int PatternEmitter::getNodeValueCount(DagNode node) {
800   if (node.isOperation()) {
801     // If the op is bound to a symbol in the rewrite rule, query its result
802     // count from the symbol info map.
803     auto symbol = node.getSymbol();
804     if (!symbol.empty()) {
805       return symbolInfoMap.getStaticValueCount(symbol);
806     }
807     // Otherwise this is an unbound op; we will use all its results.
808     return pattern.getDialectOp(node).getNumResults();
809   }
810   // TODO: This considers all NativeCodeCall as returning one
811   // value. Enhance if multi-value ones are needed.
812   return 1;
813 }
814 
815 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
816   auto numPatArgs = tree.getNumArgs();
817 
818   if (numPatArgs != 0) {
819     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
820       if (lastArg.isLocationDirective()) {
821         return std::make_pair(true, handleLocationDirective(lastArg));
822       }
823   }
824 
825   // If no explicit location is given, use the default, all fused, location.
826   return std::make_pair(false, "odsLoc");
827 }
828 
829 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
830                                              int depth) {
831   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
832   LLVM_DEBUG(tree.print(llvm::dbgs()));
833   LLVM_DEBUG(llvm::dbgs() << '\n');
834 
835   Operator &resultOp = tree.getDialectOp(opMap);
836   auto numOpArgs = resultOp.getNumArgs();
837   auto numPatArgs = tree.getNumArgs();
838 
839   bool hasLocationDirective;
840   std::string locToUse;
841   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
842 
843   auto inPattern = numPatArgs - hasLocationDirective;
844   if (numOpArgs != inPattern) {
845     PrintFatalError(loc,
846                     formatv("resultant op '{0}' argument number mismatch: "
847                             "{1} in pattern vs. {2} in definition",
848                             resultOp.getOperationName(), inPattern, numOpArgs));
849   }
850 
851   // A map to collect all nested DAG child nodes' names, with operand index as
852   // the key. This includes both bound and unbound child nodes.
853   ChildNodeIndexNameMap childNodeNames;
854 
855   // First go through all the child nodes who are nested DAG constructs to
856   // create ops for them and remember the symbol names for them, so that we can
857   // use the results in the current node. This happens in a recursive manner.
858   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
859     if (auto child = tree.getArgAsNestedDag(i))
860       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
861   }
862 
863   // The name of the local variable holding this op.
864   std::string valuePackName;
865   // The symbol for holding the result of this pattern. Note that the result of
866   // this pattern is not necessarily the same as the variable created by this
867   // pattern because we can use `__N` suffix to refer only a specific result if
868   // the generated op is a multi-result op.
869   std::string resultValue;
870   if (tree.getSymbol().empty()) {
871     // No symbol is explicitly bound to this op in the pattern. Generate a
872     // unique name.
873     valuePackName = resultValue = getUniqueSymbol(&resultOp);
874   } else {
875     resultValue = std::string(tree.getSymbol());
876     // Strip the index to get the name for the value pack and use it to name the
877     // local variable for the op.
878     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
879   }
880 
881   // Create the local variable for this op.
882   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
883                           valuePackName);
884   os.indent(4) << "{\n";
885 
886   // Right now ODS don't have general type inference support. Except a few
887   // special cases listed below, DRR needs to supply types for all results
888   // when building an op.
889   bool isSameOperandsAndResultType =
890       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
891   bool useFirstAttr =
892       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
893 
894   if (isSameOperandsAndResultType || useFirstAttr) {
895     // We know how to deduce the result type for ops with these traits and we've
896     // generated builders taking aggregate parameters. Use those builders to
897     // create the ops.
898 
899     // First prepare local variables for op arguments used in builder call.
900     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
901 
902     // Then create the op.
903     os.indent(6) << formatv(
904         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
905         valuePackName, resultOp.getQualCppClassName(), locToUse);
906     os.indent(4) << "}\n";
907     return resultValue;
908   }
909 
910   bool usePartialResults = valuePackName != resultValue;
911 
912   if (usePartialResults || depth > 0 || resultIndex < 0) {
913     // For these cases (broadcastable ops, op results used both as auxiliary
914     // values and replacement values, ops in nested patterns, auxiliary ops), we
915     // still need to supply the result types when building the op. But because
916     // we don't generate a builder automatically with ODS for them, it's the
917     // developer's responsibility to make sure such a builder (with result type
918     // deduction ability) exists. We go through the separate-parameter builder
919     // here given that it's easier for developers to write compared to
920     // aggregate-parameter builders.
921     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
922 
923     os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
924                             resultOp.getQualCppClassName(), locToUse);
925     supplyValuesForOpArgs(tree, childNodeNames);
926     os << "\n      );\n";
927     os.indent(4) << "}\n";
928     return resultValue;
929   }
930 
931   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
932   // generated from the source pattern root op. Then we can use the source
933   // pattern's value types to determine the value type of the generated op
934   // here.
935 
936   // First prepare local variables for op arguments used in builder call.
937   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
938 
939   // Then prepare the result types. We need to specify the types for all
940   // results.
941   os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
942                           "(void)tblgen_types;\n");
943   int numResults = resultOp.getNumResults();
944   if (numResults != 0) {
945     for (int i = 0; i < numResults; ++i)
946       os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
947                               "tblgen_types.push_back(v.getType()); }\n",
948                               resultIndex + i);
949   }
950   os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
951                           "tblgen_values, tblgen_attrs);\n",
952                           valuePackName, resultOp.getQualCppClassName(),
953                           locToUse);
954   os.indent(4) << "}\n";
955   return resultValue;
956 }
957 
958 void PatternEmitter::createSeparateLocalVarsForOpArgs(
959     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
960   Operator &resultOp = node.getDialectOp(opMap);
961 
962   // Now prepare operands used for building this op:
963   // * If the operand is non-variadic, we create a `Value` local variable.
964   // * If the operand is variadic, we create a `SmallVector<Value>` local
965   //   variable.
966 
967   int valueIndex = 0; // An index for uniquing local variable names.
968   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
969     const auto *operand =
970         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
971     if (!operand) {
972       // We do not need special handling for attributes.
973       continue;
974     }
975 
976     std::string varName;
977     if (operand->isVariadic()) {
978       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
979       os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
980                               varName);
981       std::string range;
982       if (node.isNestedDagArg(argIndex)) {
983         range = childNodeNames[argIndex];
984       } else {
985         range = std::string(node.getArgName(argIndex));
986       }
987       // Resolve the symbol for all range use so that we have a uniform way of
988       // capturing the values.
989       range = symbolInfoMap.getValueAndRangeUse(range);
990       os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
991                               varName);
992     } else {
993       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
994       os.indent(6) << formatv("::mlir::Value {0} = ", varName);
995       if (node.isNestedDagArg(argIndex)) {
996         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
997       } else {
998         DagLeaf leaf = node.getArgAsLeaf(argIndex);
999         auto symbol =
1000             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1001         if (leaf.isNativeCodeCall()) {
1002           os << std::string(
1003               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1004         } else {
1005           os << symbol;
1006         }
1007       }
1008       os << ";\n";
1009     }
1010 
1011     // Update to use the newly created local variable for building the op later.
1012     childNodeNames[argIndex] = varName;
1013   }
1014 }
1015 
1016 void PatternEmitter::supplyValuesForOpArgs(
1017     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1018   Operator &resultOp = node.getDialectOp(opMap);
1019   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1020        argIndex != numOpArgs; ++argIndex) {
1021     // Start each argument on its own line.
1022     (os << ",\n").indent(8);
1023 
1024     Argument opArg = resultOp.getArg(argIndex);
1025     // Handle the case of operand first.
1026     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1027       if (!operand->name.empty())
1028         os << "/*" << operand->name << "=*/";
1029       os << childNodeNames.lookup(argIndex);
1030       continue;
1031     }
1032 
1033     // The argument in the op definition.
1034     auto opArgName = resultOp.getArgName(argIndex);
1035     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1036       if (!subTree.isNativeCodeCall())
1037         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1038                              "for creating attribute");
1039       os << formatv("/*{0}=*/{1}", opArgName,
1040                     handleReplaceWithNativeCodeCall(subTree));
1041     } else {
1042       auto leaf = node.getArgAsLeaf(argIndex);
1043       // The argument in the result DAG pattern.
1044       auto patArgName = node.getArgName(argIndex);
1045       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1046         // TODO: Refactor out into map to avoid recomputing these.
1047         if (!opArg.is<NamedAttribute *>())
1048           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1049         if (!patArgName.empty())
1050           os << "/*" << patArgName << "=*/";
1051       } else {
1052         os << "/*" << opArgName << "=*/";
1053       }
1054       os << handleOpArgument(leaf, patArgName);
1055     }
1056   }
1057 }
1058 
1059 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1060     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1061   Operator &resultOp = node.getDialectOp(opMap);
1062 
1063   os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1064                           "tblgen_values; (void)tblgen_values;\n");
1065   os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1066                           "tblgen_attrs; (void)tblgen_attrs;\n");
1067 
1068   const char *addAttrCmd =
1069       "if (auto tmpAttr = {1}) "
1070       "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
1071   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1072     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1073       // The argument in the op definition.
1074       auto opArgName = resultOp.getArgName(argIndex);
1075       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1076         if (!subTree.isNativeCodeCall())
1077           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1078                                "for creating attribute");
1079         os.indent(6) << formatv(addAttrCmd, opArgName,
1080                                 handleReplaceWithNativeCodeCall(subTree));
1081       } else {
1082         auto leaf = node.getArgAsLeaf(argIndex);
1083         // The argument in the result DAG pattern.
1084         auto patArgName = node.getArgName(argIndex);
1085         os.indent(6) << formatv(addAttrCmd, opArgName,
1086                                 handleOpArgument(leaf, patArgName));
1087       }
1088       continue;
1089     }
1090 
1091     const auto *operand =
1092         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1093     std::string varName;
1094     if (operand->isVariadic()) {
1095       std::string range;
1096       if (node.isNestedDagArg(argIndex)) {
1097         range = childNodeNames.lookup(argIndex);
1098       } else {
1099         range = std::string(node.getArgName(argIndex));
1100       }
1101       // Resolve the symbol for all range use so that we have a uniform way of
1102       // capturing the values.
1103       range = symbolInfoMap.getValueAndRangeUse(range);
1104       os.indent(6) << formatv(
1105           "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
1106     } else {
1107       os.indent(6) << formatv("tblgen_values.push_back(", varName);
1108       if (node.isNestedDagArg(argIndex)) {
1109         os << symbolInfoMap.getValueAndRangeUse(
1110             childNodeNames.lookup(argIndex));
1111       } else {
1112         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1113         auto symbol =
1114             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1115         if (leaf.isNativeCodeCall()) {
1116           os << std::string(
1117               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1118         } else {
1119           os << symbol;
1120         }
1121       }
1122       os << ");\n";
1123     }
1124   }
1125 }
1126 
1127 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1128   emitSourceFileHeader("Rewriters", os);
1129 
1130   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1131   auto numPatterns = patterns.size();
1132 
1133   // We put the map here because it can be shared among multiple patterns.
1134   RecordOperatorMap recordOpMap;
1135 
1136   std::vector<std::string> rewriterNames;
1137   rewriterNames.reserve(numPatterns);
1138 
1139   std::string baseRewriterName = "GeneratedConvert";
1140   int rewriterIndex = 0;
1141 
1142   for (Record *p : patterns) {
1143     std::string name;
1144     if (p->isAnonymous()) {
1145       // If no name is provided, ensure unique rewriter names simply by
1146       // appending unique suffix.
1147       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1148     } else {
1149       name = std::string(p->getName());
1150     }
1151     LLVM_DEBUG(llvm::dbgs()
1152                << "=== start generating pattern '" << name << "' ===\n");
1153     PatternEmitter(p, &recordOpMap, os).emit(name);
1154     LLVM_DEBUG(llvm::dbgs()
1155                << "=== done generating pattern '" << name << "' ===\n");
1156     rewriterNames.push_back(std::move(name));
1157   }
1158 
1159   // Emit function to add the generated matchers to the pattern list.
1160   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(MLIRContext "
1161         "*context, OwningRewritePatternList *patterns) {\n";
1162   for (const auto &name : rewriterNames) {
1163     os << "  patterns->insert<" << name << ">(context);\n";
1164   }
1165   os << "}\n";
1166 }
1167 
1168 static mlir::GenRegistration
1169     genRewriters("gen-rewriters", "Generate pattern rewriters",
1170                  [](const RecordKeeper &records, raw_ostream &os) {
1171                    emitRewriters(records, os);
1172                    return false;
1173                  });
1174