xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 9b851527d53345c4a5d56a909dfa1ca7f59a0c11)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
45   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46                      raw_ostream &os, StringRef style) {
47     os << v.first << ":" << v.second;
48   }
49 };
50 } // end namespace llvm
51 
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 class PatternEmitter {
58 public:
59   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60 
61   // Emits the mlir::RewritePattern struct named `rewriteName`.
62   void emit(StringRef rewriteName);
63 
64 private:
65   // Emits the code for matching ops.
66   void emitMatchLogic(DagNode tree);
67 
68   // Emits the code for rewriting ops.
69   void emitRewriteLogic();
70 
71   //===--------------------------------------------------------------------===//
72   // Match utilities
73   //===--------------------------------------------------------------------===//
74 
75   // Emits C++ statements for matching the op constrained by the given DAG
76   // `tree`.
77   void emitOpMatch(DagNode tree, int depth);
78 
79   // Emits C++ statements for matching the `argIndex`-th argument of the given
80   // DAG `tree` as an operand.
81   void emitOperandMatch(DagNode tree, int argIndex, int depth);
82 
83   // Emits C++ statements for matching the `argIndex`-th argument of the given
84   // DAG `tree` as an attribute.
85   void emitAttributeMatch(DagNode tree, int argIndex, int depth);
86 
87   // Emits C++ for checking a match with a corresponding match failure
88   // diagnostic.
89   void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
90                       const llvm::formatv_object_base &failureFmt);
91 
92   //===--------------------------------------------------------------------===//
93   // Rewrite utilities
94   //===--------------------------------------------------------------------===//
95 
96   // The entry point for handling a result pattern rooted at `resultTree`. This
97   // method dispatches to concrete handlers according to `resultTree`'s kind and
98   // returns a symbol representing the whole value pack. Callers are expected to
99   // further resolve the symbol according to the specific use case.
100   //
101   // `depth` is the nesting level of `resultTree`; 0 means top-level result
102   // pattern. For top-level result pattern, `resultIndex` indicates which result
103   // of the matched root op this pattern is intended to replace, which can be
104   // used to deduce the result type of the op generated from this result
105   // pattern.
106   std::string handleResultPattern(DagNode resultTree, int resultIndex,
107                                   int depth);
108 
109   // Emits the C++ statement to replace the matched DAG with a value built via
110   // calling native C++ code.
111   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
112 
113   // Returns the symbol of the old value serving as the replacement.
114   StringRef handleReplaceWithValue(DagNode tree);
115 
116   // Returns the location value to use.
117   std::pair<bool, std::string> getLocation(DagNode tree);
118 
119   // Returns the location value to use.
120   std::string handleLocationDirective(DagNode tree);
121 
122   // Emits the C++ statement to build a new op out of the given DAG `tree` and
123   // returns the variable name that this op is assigned to. If the root op in
124   // DAG `tree` has a specified name, the created op will be assigned to a
125   // variable of the given name. Otherwise, a unique name will be used as the
126   // result value name.
127   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
128 
129   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
130 
131   // Emits a local variable for each value and attribute to be used for creating
132   // an op.
133   void createSeparateLocalVarsForOpArgs(DagNode node,
134                                         ChildNodeIndexNameMap &childNodeNames);
135 
136   // Emits the concrete arguments used to call an op's builder.
137   void supplyValuesForOpArgs(DagNode node,
138                              const ChildNodeIndexNameMap &childNodeNames);
139 
140   // Emits the local variables for holding all values as a whole and all named
141   // attributes as a whole to be used for creating an op.
142   void createAggregateLocalVarsForOpArgs(
143       DagNode node, const ChildNodeIndexNameMap &childNodeNames);
144 
145   // Returns the C++ expression to construct a constant attribute of the given
146   // `value` for the given attribute kind `attr`.
147   std::string handleConstantAttr(Attribute attr, StringRef value);
148 
149   // Returns the C++ expression to build an argument from the given DAG `leaf`.
150   // `patArgName` is used to bound the argument to the source pattern.
151   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
152 
153   //===--------------------------------------------------------------------===//
154   // General utilities
155   //===--------------------------------------------------------------------===//
156 
157   // Collects all of the operations within the given dag tree.
158   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
159 
160   // Returns a unique symbol for a local variable of the given `op`.
161   std::string getUniqueSymbol(const Operator *op);
162 
163   //===--------------------------------------------------------------------===//
164   // Symbol utilities
165   //===--------------------------------------------------------------------===//
166 
167   // Returns how many static values the given DAG `node` correspond to.
168   int getNodeValueCount(DagNode node);
169 
170 private:
171   // Pattern instantiation location followed by the location of multiclass
172   // prototypes used. This is intended to be used as a whole to
173   // PrintFatalError() on errors.
174   ArrayRef<llvm::SMLoc> loc;
175 
176   // Op's TableGen Record to wrapper object.
177   RecordOperatorMap *opMap;
178 
179   // Handy wrapper for pattern being emitted.
180   Pattern pattern;
181 
182   // Map for all bound symbols' info.
183   SymbolInfoMap symbolInfoMap;
184 
185   // The next unused ID for newly created values.
186   unsigned nextValueId;
187 
188   raw_indented_ostream os;
189 
190   // Format contexts containing placeholder substitutions.
191   FmtContext fmtCtx;
192 
193   // Number of op processed.
194   int opCounter = 0;
195 };
196 } // end anonymous namespace
197 
198 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
199                                raw_ostream &os)
200     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
201       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
202   fmtCtx.withBuilder("rewriter");
203 }
204 
205 std::string PatternEmitter::handleConstantAttr(Attribute attr,
206                                                StringRef value) {
207   if (!attr.isConstBuildable())
208     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
209                              " does not have the 'constBuilderCall' field");
210 
211   // TODO: Verify the constants here
212   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
213 }
214 
215 // Helper function to match patterns.
216 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
217   Operator &op = tree.getDialectOp(opMap);
218   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
219                           << op.getOperationName() << "' at depth " << depth
220                           << '\n');
221 
222   int indent = 4 + 2 * depth;
223   os.indent(indent) << formatv(
224       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
225       depth, op.getQualCppClassName());
226   // Skip the operand matching at depth 0 as the pattern rewriter already does.
227   if (depth != 0) {
228     // Skip if there is no defining operation (e.g., arguments to function).
229     os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
230   }
231   if (tree.getNumArgs() != op.getNumArgs()) {
232     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
233                                  "pattern vs. {2} in definition",
234                                  op.getOperationName(), tree.getNumArgs(),
235                                  op.getNumArgs()));
236   }
237 
238   // If the operand's name is set, set to that variable.
239   auto name = tree.getSymbol();
240   if (!name.empty())
241     os << formatv("{0} = castedOp{1};\n", name, depth);
242 
243   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
244     auto opArg = op.getArg(i);
245 
246     // Handle nested DAG construct first
247     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
248       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
249         if (operand->isVariableLength()) {
250           auto error = formatv("use nested DAG construct to match op {0}'s "
251                                "variadic operand #{1} unsupported now",
252                                op.getOperationName(), i);
253           PrintFatalError(loc, error);
254         }
255       }
256       os << "{\n";
257 
258       os.indent() << formatv(
259           "auto *op{0} = "
260           "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
261           depth + 1, depth, i);
262       emitOpMatch(argTree, depth + 1);
263       os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
264       os.unindent() << "}\n";
265       continue;
266     }
267 
268     // Next handle DAG leaf: operand or attribute
269     if (opArg.is<NamedTypeConstraint *>()) {
270       emitOperandMatch(tree, i, depth);
271     } else if (opArg.is<NamedAttribute *>()) {
272       emitAttributeMatch(tree, i, depth);
273     } else {
274       PrintFatalError(loc, "unhandled case when matching op");
275     }
276   }
277   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
278                           << op.getOperationName() << "' at depth " << depth
279                           << '\n');
280 }
281 
282 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
283   Operator &op = tree.getDialectOp(opMap);
284   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
285   auto matcher = tree.getArgAsLeaf(argIndex);
286 
287   // If a constraint is specified, we need to generate C++ statements to
288   // check the constraint.
289   if (!matcher.isUnspecified()) {
290     if (!matcher.isOperandMatcher()) {
291       PrintFatalError(
292           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
293                        op.getOperationName(), argIndex + 1));
294     }
295 
296     // Only need to verify if the matcher's type is different from the one
297     // of op definition.
298     Constraint constraint = matcher.getAsConstraint();
299     if (operand->constraint != constraint) {
300       if (operand->isVariableLength()) {
301         auto error = formatv(
302             "further constrain op {0}'s variadic operand #{1} unsupported now",
303             op.getOperationName(), argIndex);
304         PrintFatalError(loc, error);
305       }
306       auto self =
307           formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
308                   argIndex);
309       emitMatchCheck(
310           depth,
311           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
312           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
313                   "'{2}'\"",
314                   operand - op.operand_begin(), op.getOperationName(),
315                   constraint.getDescription()));
316     }
317   }
318 
319   // Capture the value
320   auto name = tree.getArgName(argIndex);
321   // `$_` is a special symbol to ignore op argument matching.
322   if (!name.empty() && name != "_") {
323     // We need to subtract the number of attributes before this operand to get
324     // the index in the operand list.
325     auto numPrevAttrs = std::count_if(
326         op.arg_begin(), op.arg_begin() + argIndex,
327         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
328 
329     os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
330                   argIndex - numPrevAttrs);
331   }
332 }
333 
334 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
335   Operator &op = tree.getDialectOp(opMap);
336   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
337   const auto &attr = namedAttr->attr;
338 
339   os << "{\n";
340   os.indent() << formatv(
341       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
342       "(void)tblgen_attr;\n",
343       depth, attr.getStorageType(), namedAttr->name);
344 
345   // TODO: This should use getter method to avoid duplication.
346   if (attr.hasDefaultValue()) {
347     os << "if (!tblgen_attr) tblgen_attr = "
348        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
349                             attr.getDefaultValue()))
350        << ";\n";
351   } else if (attr.isOptional()) {
352     // For a missing attribute that is optional according to definition, we
353     // should just capture a mlir::Attribute() to signal the missing state.
354     // That is precisely what getAttr() returns on missing attributes.
355   } else {
356     emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
357                    formatv("\"expected op '{0}' to have attribute '{1}' "
358                            "of type '{2}'\"",
359                            op.getOperationName(), namedAttr->name,
360                            attr.getStorageType()));
361   }
362 
363   auto matcher = tree.getArgAsLeaf(argIndex);
364   if (!matcher.isUnspecified()) {
365     if (!matcher.isAttrMatcher()) {
366       PrintFatalError(
367           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
368                        op.getOperationName(), argIndex + 1));
369     }
370 
371     // If a constraint is specified, we need to generate C++ statements to
372     // check the constraint.
373     emitMatchCheck(
374         depth,
375         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
376         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
377                 "{2}\"",
378                 op.getOperationName(), namedAttr->name,
379                 matcher.getAsConstraint().getDescription()));
380   }
381 
382   // Capture the value
383   auto name = tree.getArgName(argIndex);
384   // `$_` is a special symbol to ignore op argument matching.
385   if (!name.empty() && name != "_") {
386     os << formatv("{0} = tblgen_attr;\n", name);
387   }
388 
389   os.unindent() << "}\n";
390 }
391 
392 void PatternEmitter::emitMatchCheck(
393     int depth, const FmtObjectBase &matchFmt,
394     const llvm::formatv_object_base &failureFmt) {
395   os << "if (!(" << matchFmt.str() << "))";
396   os.scope("{\n", "\n}\n").os
397       << "return rewriter.notifyMatchFailure(op" << depth
398       << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
399       << ";\n});";
400 }
401 
402 void PatternEmitter::emitMatchLogic(DagNode tree) {
403   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
404   int depth = 0;
405   emitOpMatch(tree, depth);
406 
407   for (auto &appliedConstraint : pattern.getConstraints()) {
408     auto &constraint = appliedConstraint.constraint;
409     auto &entities = appliedConstraint.entities;
410 
411     auto condition = constraint.getConditionTemplate();
412     if (isa<TypeConstraint>(constraint)) {
413       auto self = formatv("({0}.getType())",
414                           symbolInfoMap.getValueAndRangeUse(entities.front()));
415       emitMatchCheck(
416           depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
417           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
418                   entities.front(), constraint.getDescription()));
419 
420     } else if (isa<AttrConstraint>(constraint)) {
421       PrintFatalError(
422           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
423     } else {
424       // TODO: replace formatv arguments with the exact specified
425       // args.
426       if (entities.size() > 4) {
427         PrintFatalError(loc, "only support up to 4-entity constraints now");
428       }
429       SmallVector<std::string, 4> names;
430       int i = 0;
431       for (int e = entities.size(); i < e; ++i)
432         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
433       std::string self = appliedConstraint.self;
434       if (!self.empty())
435         self = symbolInfoMap.getValueAndRangeUse(self);
436       for (; i < 4; ++i)
437         names.push_back("<unused>");
438       emitMatchCheck(depth,
439                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
440                            names[1], names[2], names[3]),
441                      formatv("\"entities '{0}' failed to satisfy constraint: "
442                              "{1}\"",
443                              llvm::join(entities, ", "),
444                              constraint.getDescription()));
445     }
446   }
447   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
448 }
449 
450 void PatternEmitter::collectOps(DagNode tree,
451                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
452   // Check if this tree is an operation.
453   if (tree.isOperation()) {
454     const Operator &op = tree.getDialectOp(opMap);
455     LLVM_DEBUG(llvm::dbgs()
456                << "found operation " << op.getOperationName() << '\n');
457     ops.insert(&op);
458   }
459 
460   // Recurse the arguments of the tree.
461   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
462     if (auto child = tree.getArgAsNestedDag(i))
463       collectOps(child, ops);
464 }
465 
466 void PatternEmitter::emit(StringRef rewriteName) {
467   // Get the DAG tree for the source pattern.
468   DagNode sourceTree = pattern.getSourcePattern();
469 
470   const Operator &rootOp = pattern.getSourceRootOp();
471   auto rootName = rootOp.getOperationName();
472 
473   // Collect the set of result operations.
474   llvm::SmallPtrSet<const Operator *, 4> resultOps;
475   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
476   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
477     collectOps(pattern.getResultPattern(i), resultOps);
478   }
479   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
480 
481   // Emit RewritePattern for Pattern.
482   auto locs = pattern.getLocation();
483   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
484                 make_range(locs.rbegin(), locs.rend()));
485   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
486   {0}(::mlir::MLIRContext *context)
487       : ::mlir::RewritePattern("{1}", {{)",
488                 rewriteName, rootName);
489   // Sort result operators by name.
490   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
491                                                          resultOps.end());
492   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
493     return lhs->getOperationName() < rhs->getOperationName();
494   });
495   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
496     os << '"' << op->getOperationName() << '"';
497   });
498   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
499 
500   // Emit matchAndRewrite() function.
501   {
502     auto classScope = os.scope();
503     os.reindent(R"(
504     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
505         ::mlir::PatternRewriter &rewriter) const override {)")
506         << '\n';
507     {
508       auto functionScope = os.scope();
509 
510       // Register all symbols bound in the source pattern.
511       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
512 
513       LLVM_DEBUG(llvm::dbgs()
514                  << "start creating local variables for capturing matches\n");
515       os << "// Variables for capturing values and attributes used while "
516             "creating ops\n";
517       // Create local variables for storing the arguments and results bound
518       // to symbols.
519       for (const auto &symbolInfoPair : symbolInfoMap) {
520         StringRef symbol = symbolInfoPair.getKey();
521         auto &info = symbolInfoPair.getValue();
522         os << info.getVarDecl(symbol);
523       }
524       // TODO: capture ops with consistent numbering so that it can be
525       // reused for fused loc.
526       os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
527                     pattern.getSourcePattern().getNumOps());
528       LLVM_DEBUG(llvm::dbgs()
529                  << "done creating local variables for capturing matches\n");
530 
531       os << "// Match\n";
532       os << "tblgen_ops[0] = op0;\n";
533       emitMatchLogic(sourceTree);
534 
535       os << "\n// Rewrite\n";
536       emitRewriteLogic();
537 
538       os << "return success();\n";
539     }
540     os << "};\n";
541   }
542   os << "};\n\n";
543 }
544 
545 void PatternEmitter::emitRewriteLogic() {
546   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
547   const Operator &rootOp = pattern.getSourceRootOp();
548   int numExpectedResults = rootOp.getNumResults();
549   int numResultPatterns = pattern.getNumResultPatterns();
550 
551   // First register all symbols bound to ops generated in result patterns.
552   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
553 
554   // Only the last N static values generated are used to replace the matched
555   // root N-result op. We need to calculate the starting index (of the results
556   // of the matched op) each result pattern is to replace.
557   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
558   // If we don't need to replace any value at all, set the replacement starting
559   // index as the number of result patterns so we skip all of them when trying
560   // to replace the matched op's results.
561   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
562   for (int i = numResultPatterns - 1; i >= 0; --i) {
563     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
564     offsets[i] = offsets[i + 1] - numValues;
565     if (offsets[i] == 0) {
566       if (replStartIndex == -1)
567         replStartIndex = i;
568     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
569       auto error = formatv(
570           "cannot use the same multi-result op '{0}' to generate both "
571           "auxiliary values and values to be used for replacing the matched op",
572           pattern.getResultPattern(i).getSymbol());
573       PrintFatalError(loc, error);
574     }
575   }
576 
577   if (offsets.front() > 0) {
578     const char error[] = "no enough values generated to replace the matched op";
579     PrintFatalError(loc, error);
580   }
581 
582   os << "auto odsLoc = rewriter.getFusedLoc({";
583   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
584     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
585   }
586   os << "}); (void)odsLoc;\n";
587 
588   // Process auxiliary result patterns.
589   for (int i = 0; i < replStartIndex; ++i) {
590     DagNode resultTree = pattern.getResultPattern(i);
591     auto val = handleResultPattern(resultTree, offsets[i], 0);
592     // Normal op creation will be streamed to `os` by the above call; but
593     // NativeCodeCall will only be materialized to `os` if it is used. Here
594     // we are handling auxiliary patterns so we want the side effect even if
595     // NativeCodeCall is not replacing matched root op's results.
596     if (resultTree.isNativeCodeCall())
597       os << val << ";\n";
598   }
599 
600   if (numExpectedResults == 0) {
601     assert(replStartIndex >= numResultPatterns &&
602            "invalid auxiliary vs. replacement pattern division!");
603     // No result to replace. Just erase the op.
604     os << "rewriter.eraseOp(op0);\n";
605   } else {
606     // Process replacement result patterns.
607     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
608     for (int i = replStartIndex; i < numResultPatterns; ++i) {
609       DagNode resultTree = pattern.getResultPattern(i);
610       auto val = handleResultPattern(resultTree, offsets[i], 0);
611       os << "\n";
612       // Resolve each symbol for all range use so that we can loop over them.
613       // We need an explicit cast to `SmallVector` to capture the cases where
614       // `{0}` resolves to an `Operation::result_range` as well as cases that
615       // are not iterable (e.g. vector that gets wrapped in additional braces by
616       // RewriterGen).
617       // TODO: Revisit the need for materializing a vector.
618       os << symbolInfoMap.getAllRangeUse(
619           val,
620           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
621           "  tblgen_repl_values.push_back(v);\n}\n",
622           "\n");
623     }
624     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
625   }
626 
627   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
628 }
629 
630 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
631   return std::string(
632       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
633 }
634 
635 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
636                                                 int resultIndex, int depth) {
637   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
638   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
639   LLVM_DEBUG(llvm::dbgs() << '\n');
640 
641   if (resultTree.isLocationDirective()) {
642     PrintFatalError(loc,
643                     "location directive can only be used with op creation");
644   }
645 
646   if (resultTree.isNativeCodeCall()) {
647     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
648     symbolInfoMap.bindValue(symbol);
649     return symbol;
650   }
651 
652   if (resultTree.isReplaceWithValue())
653     return handleReplaceWithValue(resultTree).str();
654 
655   // Normal op creation.
656   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
657   if (resultTree.getSymbol().empty()) {
658     // This is an op not explicitly bound to a symbol in the rewrite rule.
659     // Register the auto-generated symbol for it.
660     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
661   }
662   return symbol;
663 }
664 
665 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
666   assert(tree.isReplaceWithValue());
667 
668   if (tree.getNumArgs() != 1) {
669     PrintFatalError(
670         loc, "replaceWithValue directive must take exactly one argument");
671   }
672 
673   if (!tree.getSymbol().empty()) {
674     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
675   }
676 
677   return tree.getArgName(0);
678 }
679 
680 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
681   assert(tree.isLocationDirective());
682   auto lookUpArgLoc = [this, &tree](int idx) {
683     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
684     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
685   };
686 
687   if (tree.getNumArgs() == 0)
688     llvm::PrintFatalError(
689         "At least one argument to location directive required");
690 
691   if (!tree.getSymbol().empty())
692     PrintFatalError(loc, "cannot bind symbol to location");
693 
694   if (tree.getNumArgs() == 1) {
695     DagLeaf leaf = tree.getArgAsLeaf(0);
696     if (leaf.isStringAttr())
697       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
698                      "rewriter.getContext())",
699                      leaf.getStringAttr())
700           .str();
701     return lookUpArgLoc(0);
702   }
703 
704   std::string ret;
705   llvm::raw_string_ostream os(ret);
706   std::string strAttr;
707   os << "rewriter.getFusedLoc({";
708   bool first = true;
709   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
710     DagLeaf leaf = tree.getArgAsLeaf(i);
711     // Handle the optional string value.
712     if (leaf.isStringAttr()) {
713       if (!strAttr.empty())
714         llvm::PrintFatalError("Only one string attribute may be specified");
715       strAttr = leaf.getStringAttr();
716       continue;
717     }
718     os << (first ? "" : ", ") << lookUpArgLoc(i);
719     first = false;
720   }
721   os << "}";
722   if (!strAttr.empty()) {
723     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
724   }
725   os << ")";
726   return os.str();
727 }
728 
729 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
730                                              StringRef patArgName) {
731   if (leaf.isStringAttr())
732     PrintFatalError(loc, "raw string not supported as argument");
733   if (leaf.isConstantAttr()) {
734     auto constAttr = leaf.getAsConstantAttr();
735     return handleConstantAttr(constAttr.getAttribute(),
736                               constAttr.getConstantValue());
737   }
738   if (leaf.isEnumAttrCase()) {
739     auto enumCase = leaf.getAsEnumAttrCase();
740     if (enumCase.isStrCase())
741       return handleConstantAttr(enumCase, enumCase.getSymbol());
742     // This is an enum case backed by an IntegerAttr. We need to get its value
743     // to build the constant.
744     std::string val = std::to_string(enumCase.getValue());
745     return handleConstantAttr(enumCase, val);
746   }
747 
748   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
749   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
750   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
751     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
752                             << "' (via symbol ref)\n");
753     return argName;
754   }
755   if (leaf.isNativeCodeCall()) {
756     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
757     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
758                             << "' (via NativeCodeCall)\n");
759     return std::string(repl);
760   }
761   PrintFatalError(loc, "unhandled case when rewriting op");
762 }
763 
764 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
765   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
766   LLVM_DEBUG(tree.print(llvm::dbgs()));
767   LLVM_DEBUG(llvm::dbgs() << '\n');
768 
769   auto fmt = tree.getNativeCodeTemplate();
770   // TODO: replace formatv arguments with the exact specified args.
771   SmallVector<std::string, 8> attrs(8);
772   if (tree.getNumArgs() > 8) {
773     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
774                              Twine(tree.getNumArgs()));
775   }
776   bool hasLocationDirective;
777   std::string locToUse;
778   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
779 
780   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
781     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
782     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
783                             << " replacement: " << attrs[i] << "\n");
784   }
785   return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
786                            attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
787                            attrs[6], attrs[7]));
788 }
789 
790 int PatternEmitter::getNodeValueCount(DagNode node) {
791   if (node.isOperation()) {
792     // If the op is bound to a symbol in the rewrite rule, query its result
793     // count from the symbol info map.
794     auto symbol = node.getSymbol();
795     if (!symbol.empty()) {
796       return symbolInfoMap.getStaticValueCount(symbol);
797     }
798     // Otherwise this is an unbound op; we will use all its results.
799     return pattern.getDialectOp(node).getNumResults();
800   }
801   // TODO: This considers all NativeCodeCall as returning one
802   // value. Enhance if multi-value ones are needed.
803   return 1;
804 }
805 
806 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
807   auto numPatArgs = tree.getNumArgs();
808 
809   if (numPatArgs != 0) {
810     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
811       if (lastArg.isLocationDirective()) {
812         return std::make_pair(true, handleLocationDirective(lastArg));
813       }
814   }
815 
816   // If no explicit location is given, use the default, all fused, location.
817   return std::make_pair(false, "odsLoc");
818 }
819 
820 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
821                                              int depth) {
822   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
823   LLVM_DEBUG(tree.print(llvm::dbgs()));
824   LLVM_DEBUG(llvm::dbgs() << '\n');
825 
826   Operator &resultOp = tree.getDialectOp(opMap);
827   auto numOpArgs = resultOp.getNumArgs();
828   auto numPatArgs = tree.getNumArgs();
829 
830   bool hasLocationDirective;
831   std::string locToUse;
832   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
833 
834   auto inPattern = numPatArgs - hasLocationDirective;
835   if (numOpArgs != inPattern) {
836     PrintFatalError(loc,
837                     formatv("resultant op '{0}' argument number mismatch: "
838                             "{1} in pattern vs. {2} in definition",
839                             resultOp.getOperationName(), inPattern, numOpArgs));
840   }
841 
842   // A map to collect all nested DAG child nodes' names, with operand index as
843   // the key. This includes both bound and unbound child nodes.
844   ChildNodeIndexNameMap childNodeNames;
845 
846   // First go through all the child nodes who are nested DAG constructs to
847   // create ops for them and remember the symbol names for them, so that we can
848   // use the results in the current node. This happens in a recursive manner.
849   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
850     if (auto child = tree.getArgAsNestedDag(i))
851       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
852   }
853 
854   // The name of the local variable holding this op.
855   std::string valuePackName;
856   // The symbol for holding the result of this pattern. Note that the result of
857   // this pattern is not necessarily the same as the variable created by this
858   // pattern because we can use `__N` suffix to refer only a specific result if
859   // the generated op is a multi-result op.
860   std::string resultValue;
861   if (tree.getSymbol().empty()) {
862     // No symbol is explicitly bound to this op in the pattern. Generate a
863     // unique name.
864     valuePackName = resultValue = getUniqueSymbol(&resultOp);
865   } else {
866     resultValue = std::string(tree.getSymbol());
867     // Strip the index to get the name for the value pack and use it to name the
868     // local variable for the op.
869     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
870   }
871 
872   // Create the local variable for this op.
873   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
874                 valuePackName);
875 
876   // Right now ODS don't have general type inference support. Except a few
877   // special cases listed below, DRR needs to supply types for all results
878   // when building an op.
879   bool isSameOperandsAndResultType =
880       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
881   bool useFirstAttr =
882       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
883 
884   if (isSameOperandsAndResultType || useFirstAttr) {
885     // We know how to deduce the result type for ops with these traits and we've
886     // generated builders taking aggregate parameters. Use those builders to
887     // create the ops.
888 
889     // First prepare local variables for op arguments used in builder call.
890     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
891 
892     // Then create the op.
893     os.scope("", "\n}\n").os << formatv(
894         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
895         valuePackName, resultOp.getQualCppClassName(), locToUse);
896     return resultValue;
897   }
898 
899   bool usePartialResults = valuePackName != resultValue;
900 
901   if (usePartialResults || depth > 0 || resultIndex < 0) {
902     // For these cases (broadcastable ops, op results used both as auxiliary
903     // values and replacement values, ops in nested patterns, auxiliary ops), we
904     // still need to supply the result types when building the op. But because
905     // we don't generate a builder automatically with ODS for them, it's the
906     // developer's responsibility to make sure such a builder (with result type
907     // deduction ability) exists. We go through the separate-parameter builder
908     // here given that it's easier for developers to write compared to
909     // aggregate-parameter builders.
910     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
911 
912     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
913                              resultOp.getQualCppClassName(), locToUse);
914     supplyValuesForOpArgs(tree, childNodeNames);
915     os << "\n  );\n}\n";
916     return resultValue;
917   }
918 
919   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
920   // generated from the source pattern root op. Then we can use the source
921   // pattern's value types to determine the value type of the generated op
922   // here.
923 
924   // First prepare local variables for op arguments used in builder call.
925   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
926 
927   // Then prepare the result types. We need to specify the types for all
928   // results.
929   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
930                          "(void)tblgen_types;\n");
931   int numResults = resultOp.getNumResults();
932   if (numResults != 0) {
933     for (int i = 0; i < numResults; ++i)
934       os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
935                     "  tblgen_types.push_back(v.getType());\n}\n",
936                     resultIndex + i);
937   }
938   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
939                 "tblgen_values, tblgen_attrs);\n",
940                 valuePackName, resultOp.getQualCppClassName(), locToUse);
941   os.unindent() << "}\n";
942   return resultValue;
943 }
944 
945 void PatternEmitter::createSeparateLocalVarsForOpArgs(
946     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
947   Operator &resultOp = node.getDialectOp(opMap);
948 
949   // Now prepare operands used for building this op:
950   // * If the operand is non-variadic, we create a `Value` local variable.
951   // * If the operand is variadic, we create a `SmallVector<Value>` local
952   //   variable.
953 
954   int valueIndex = 0; // An index for uniquing local variable names.
955   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
956     const auto *operand =
957         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
958     // We do not need special handling for attributes.
959     if (!operand)
960       continue;
961 
962     raw_indented_ostream::DelimitedScope scope(os);
963     std::string varName;
964     if (operand->isVariadic()) {
965       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
966       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
967       std::string range;
968       if (node.isNestedDagArg(argIndex)) {
969         range = childNodeNames[argIndex];
970       } else {
971         range = std::string(node.getArgName(argIndex));
972       }
973       // Resolve the symbol for all range use so that we have a uniform way of
974       // capturing the values.
975       range = symbolInfoMap.getValueAndRangeUse(range);
976       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
977                     varName);
978     } else {
979       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
980       os << formatv("::mlir::Value {0} = ", varName);
981       if (node.isNestedDagArg(argIndex)) {
982         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
983       } else {
984         DagLeaf leaf = node.getArgAsLeaf(argIndex);
985         auto symbol =
986             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
987         if (leaf.isNativeCodeCall()) {
988           os << std::string(
989               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
990         } else {
991           os << symbol;
992         }
993       }
994       os << ";\n";
995     }
996 
997     // Update to use the newly created local variable for building the op later.
998     childNodeNames[argIndex] = varName;
999   }
1000 }
1001 
1002 void PatternEmitter::supplyValuesForOpArgs(
1003     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1004   Operator &resultOp = node.getDialectOp(opMap);
1005   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1006        argIndex != numOpArgs; ++argIndex) {
1007     // Start each argument on its own line.
1008     os << ",\n    ";
1009 
1010     Argument opArg = resultOp.getArg(argIndex);
1011     // Handle the case of operand first.
1012     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1013       if (!operand->name.empty())
1014         os << "/*" << operand->name << "=*/";
1015       os << childNodeNames.lookup(argIndex);
1016       continue;
1017     }
1018 
1019     // The argument in the op definition.
1020     auto opArgName = resultOp.getArgName(argIndex);
1021     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1022       if (!subTree.isNativeCodeCall())
1023         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1024                              "for creating attribute");
1025       os << formatv("/*{0}=*/{1}", opArgName,
1026                     handleReplaceWithNativeCodeCall(subTree));
1027     } else {
1028       auto leaf = node.getArgAsLeaf(argIndex);
1029       // The argument in the result DAG pattern.
1030       auto patArgName = node.getArgName(argIndex);
1031       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1032         // TODO: Refactor out into map to avoid recomputing these.
1033         if (!opArg.is<NamedAttribute *>())
1034           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1035         if (!patArgName.empty())
1036           os << "/*" << patArgName << "=*/";
1037       } else {
1038         os << "/*" << opArgName << "=*/";
1039       }
1040       os << handleOpArgument(leaf, patArgName);
1041     }
1042   }
1043 }
1044 
1045 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1046     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
1047   Operator &resultOp = node.getDialectOp(opMap);
1048 
1049   auto scope = os.scope();
1050   os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1051                 "tblgen_values; (void)tblgen_values;\n");
1052   os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1053                 "tblgen_attrs; (void)tblgen_attrs;\n");
1054 
1055   const char *addAttrCmd =
1056       "if (auto tmpAttr = {1}) {\n"
1057       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1058       "tmpAttr);\n}\n";
1059   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1060     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1061       // The argument in the op definition.
1062       auto opArgName = resultOp.getArgName(argIndex);
1063       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1064         if (!subTree.isNativeCodeCall())
1065           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1066                                "for creating attribute");
1067         os << formatv(addAttrCmd, opArgName,
1068                       handleReplaceWithNativeCodeCall(subTree));
1069       } else {
1070         auto leaf = node.getArgAsLeaf(argIndex);
1071         // The argument in the result DAG pattern.
1072         auto patArgName = node.getArgName(argIndex);
1073         os << formatv(addAttrCmd, opArgName,
1074                       handleOpArgument(leaf, patArgName));
1075       }
1076       continue;
1077     }
1078 
1079     const auto *operand =
1080         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1081     std::string varName;
1082     if (operand->isVariadic()) {
1083       std::string range;
1084       if (node.isNestedDagArg(argIndex)) {
1085         range = childNodeNames.lookup(argIndex);
1086       } else {
1087         range = std::string(node.getArgName(argIndex));
1088       }
1089       // Resolve the symbol for all range use so that we have a uniform way of
1090       // capturing the values.
1091       range = symbolInfoMap.getValueAndRangeUse(range);
1092       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1093                     range);
1094     } else {
1095       os << formatv("tblgen_values.push_back(", varName);
1096       if (node.isNestedDagArg(argIndex)) {
1097         os << symbolInfoMap.getValueAndRangeUse(
1098             childNodeNames.lookup(argIndex));
1099       } else {
1100         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1101         auto symbol =
1102             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1103         if (leaf.isNativeCodeCall()) {
1104           os << std::string(
1105               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1106         } else {
1107           os << symbol;
1108         }
1109       }
1110       os << ");\n";
1111     }
1112   }
1113 }
1114 
1115 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1116   emitSourceFileHeader("Rewriters", os);
1117 
1118   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1119   auto numPatterns = patterns.size();
1120 
1121   // We put the map here because it can be shared among multiple patterns.
1122   RecordOperatorMap recordOpMap;
1123 
1124   std::vector<std::string> rewriterNames;
1125   rewriterNames.reserve(numPatterns);
1126 
1127   std::string baseRewriterName = "GeneratedConvert";
1128   int rewriterIndex = 0;
1129 
1130   for (Record *p : patterns) {
1131     std::string name;
1132     if (p->isAnonymous()) {
1133       // If no name is provided, ensure unique rewriter names simply by
1134       // appending unique suffix.
1135       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1136     } else {
1137       name = std::string(p->getName());
1138     }
1139     LLVM_DEBUG(llvm::dbgs()
1140                << "=== start generating pattern '" << name << "' ===\n");
1141     PatternEmitter(p, &recordOpMap, os).emit(name);
1142     LLVM_DEBUG(llvm::dbgs()
1143                << "=== done generating pattern '" << name << "' ===\n");
1144     rewriterNames.push_back(std::move(name));
1145   }
1146 
1147   // Emit function to add the generated matchers to the pattern list.
1148   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(MLIRContext "
1149         "*context, OwningRewritePatternList *patterns) {\n";
1150   for (const auto &name : rewriterNames) {
1151     os << "  patterns->insert<" << name << ">(context);\n";
1152   }
1153   os << "}\n";
1154 }
1155 
1156 static mlir::GenRegistration
1157     genRewriters("gen-rewriters", "Generate pattern rewriters",
1158                  [](const RecordKeeper &records, raw_ostream &os) {
1159                    emitRewriters(records, os);
1160                    return false;
1161                  });
1162