xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision f02e61a8b957871292e092aa440964c0f4e2bb21)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
45   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46                      raw_ostream &os, StringRef style) {
47     os << v.first << ":" << v.second;
48   }
49 };
50 } // end namespace llvm
51 
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55 
namespace {
// Generates the C++ `::mlir::RewritePattern` subclass for one TableGen
// `Pattern` record. The match logic comes from the source DAG and the rewrite
// logic from the result DAGs; all generated text is written to `os`.
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops: the structural match of the source DAG,
  // the pattern-level constraints, and equality checks for symbols bound more
  // than once.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops: building the result patterns and
  // replacing (or erasing) the matched root op.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure. Dispatches to op or
  // NativeCodeCall matching; `depth` is the nesting level (0 for the root).
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree` against the generated `::mlir::Operation *` variable `opName`.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `operandIndex` is the index among the op's
  // operands only, i.e. the preceding attributes are not counted.
  void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
                        int operandIndex, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
                          int depth);

  // Emits C++ that returns a match failure, reported against `opName` with
  // the diagnostic `failureFmt`, when the condition `matchFmt` is false.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Same as above, taking the already-formatted condition and diagnostic
  // strings.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression that replaces the matched DAG with a value
  // built via calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Returns whether `tree` carries a location directive, paired with the C++
  // expression for the location to use.
  std::pair<bool, std::string> getLocation(DagNode tree);

  // Returns the C++ expression for the location described by the location
  // directive `tree` (presumably a `location` DAG node; confirm against the
  // definition, which is outside this view).
  std::string handleLocationDirective(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Maps the index of a nested-DAG argument to the name of the local variable
  // holding its value.
  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`. Aborts with a TableGen error
  // if the attribute kind has no constant builder.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values.
  unsigned nextValueId;

  // Destination of all generated code; supports scoped indentation.
  raw_indented_ostream os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;

  // Index of the last `tblgen_ops` slot assigned while matching. Slot 0 holds
  // the root op, and each nested matched op takes the next slot.
  int opCounter = 0;
};
} // end anonymous namespace
212 
213 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
214                                raw_ostream &os)
215     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
216       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
217   fmtCtx.withBuilder("rewriter");
218 }
219 
220 std::string PatternEmitter::handleConstantAttr(Attribute attr,
221                                                StringRef value) {
222   if (!attr.isConstBuildable())
223     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
224                              " does not have the 'constBuilderCall' field");
225 
226   // TODO: Verify the constants here
227   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
228 }
229 
230 // Helper function to match patterns.
231 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
232   if (tree.isNativeCodeCall()) {
233     emitNativeCodeMatch(tree, name, depth);
234     return;
235   }
236 
237   if (tree.isOperation()) {
238     emitOpMatch(tree, name, depth);
239     return;
240   }
241 
242   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
243 }
244 
245 // Helper function to match patterns.
246 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
247                                          int depth) {
248   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
249   LLVM_DEBUG(tree.print(llvm::dbgs()));
250   LLVM_DEBUG(llvm::dbgs() << '\n');
251 
252   // TODO(suderman): iterate through arguments, determine their types, output
253   // names.
254   SmallVector<std::string, 8> capture(8);
255   if (tree.getNumArgs() > 8) {
256     PrintFatalError(loc,
257                     "unsupported NativeCodeCall matcher argument numbers: " +
258                         Twine(tree.getNumArgs()));
259   }
260 
261   raw_indented_ostream::DelimitedScope scope(os);
262 
263   os << "if(!" << opName << ") return failure();\n";
264   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
265     std::string argName = formatv("arg{0}_{1}", depth, i);
266     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
267       os << "Value " << argName << ";\n";
268     } else {
269       auto leaf = tree.getArgAsLeaf(i);
270       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
271         os << "Attribute " << argName << ";\n";
272       } else if (leaf.isOperandMatcher()) {
273         os << "Operation " << argName << ";\n";
274       }
275     }
276 
277     capture[i] = std::move(argName);
278   }
279 
280   bool hasLocationDirective;
281   std::string locToUse;
282   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
283 
284   auto fmt = tree.getNativeCodeTemplate();
285   auto nativeCodeCall = std::string(tgfmt(
286       fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
287       capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
288 
289   os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
290 
291   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
292     auto name = tree.getArgName(i);
293     if (!name.empty() && name != "_") {
294       os << formatv("{0} = {1};\n", name, capture[i]);
295     }
296   }
297 
298   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
299     std::string argName = capture[i];
300 
301     // Handle nested DAG construct first
302     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
303       PrintFatalError(
304           loc, formatv("Matching nested tree in NativeCodecall not support for "
305                        "{0} as arg {1}",
306                        argName, i));
307     }
308 
309     DagLeaf leaf = tree.getArgAsLeaf(i);
310     auto constraint = leaf.getAsConstraint();
311 
312     auto self = formatv("{0}", argName);
313     emitMatchCheck(
314         opName,
315         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
316         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
317                 "constraint: "
318                 "'{2}'\"",
319                 i, tree.getNativeCodeTemplate(), constraint.getSummary()));
320   }
321 
322   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
323 }
324 
325 // Helper function to match patterns.
326 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
327   Operator &op = tree.getDialectOp(opMap);
328   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
329                           << op.getOperationName() << "' at depth " << depth
330                           << '\n');
331 
332   std::string castedName = formatv("castedOp{0}", depth);
333   os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
334                 "(void){0};\n",
335                 castedName, opName, op.getQualCppClassName());
336   // Skip the operand matching at depth 0 as the pattern rewriter already does.
337   if (depth != 0) {
338     // Skip if there is no defining operation (e.g., arguments to function).
339     os << formatv("if (!{0}) return failure();\n", castedName);
340   }
341   if (tree.getNumArgs() != op.getNumArgs()) {
342     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
343                                  "pattern vs. {2} in definition",
344                                  op.getOperationName(), tree.getNumArgs(),
345                                  op.getNumArgs()));
346   }
347 
348   // If the operand's name is set, set to that variable.
349   auto name = tree.getSymbol();
350   if (!name.empty())
351     os << formatv("{0} = {1};\n", name, castedName);
352 
353   for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
354     auto opArg = op.getArg(i);
355     std::string argName = formatv("op{0}", depth + 1);
356 
357     // Handle nested DAG construct first
358     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
359       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
360         if (operand->isVariableLength()) {
361           auto error = formatv("use nested DAG construct to match op {0}'s "
362                                "variadic operand #{1} unsupported now",
363                                op.getOperationName(), i);
364           PrintFatalError(loc, error);
365         }
366       }
367       os << "{\n";
368 
369       // Attributes don't count for getODSOperands.
370       os.indent() << formatv(
371           "auto *{0} = "
372           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
373           argName, castedName, nextOperand++);
374       emitMatch(argTree, argName, depth + 1);
375       os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
376       os.unindent() << "}\n";
377       continue;
378     }
379 
380     // Next handle DAG leaf: operand or attribute
381     if (opArg.is<NamedTypeConstraint *>()) {
382       // emitOperandMatch's argument indexing counts attributes.
383       emitOperandMatch(tree, castedName, i, nextOperand, depth);
384       ++nextOperand;
385     } else if (opArg.is<NamedAttribute *>()) {
386       emitAttributeMatch(tree, opName, i, depth);
387     } else {
388       PrintFatalError(loc, "unhandled case when matching op");
389     }
390   }
391   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
392                           << op.getOperationName() << "' at depth " << depth
393                           << '\n');
394 }
395 
396 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
397                                       int argIndex, int operandIndex,
398                                       int depth) {
399   Operator &op = tree.getDialectOp(opMap);
400   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
401   auto matcher = tree.getArgAsLeaf(argIndex);
402 
403   // If a constraint is specified, we need to generate C++ statements to
404   // check the constraint.
405   if (!matcher.isUnspecified()) {
406     if (!matcher.isOperandMatcher()) {
407       PrintFatalError(
408           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
409                        op.getOperationName(), argIndex + 1));
410     }
411 
412     // Only need to verify if the matcher's type is different from the one
413     // of op definition.
414     Constraint constraint = matcher.getAsConstraint();
415     if (operand->constraint != constraint) {
416       if (operand->isVariableLength()) {
417         auto error = formatv(
418             "further constrain op {0}'s variadic operand #{1} unsupported now",
419             op.getOperationName(), argIndex);
420         PrintFatalError(loc, error);
421       }
422       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
423                           opName, operandIndex);
424       emitMatchCheck(
425           opName,
426           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
427           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
428                   "'{2}'\"",
429                   operand - op.operand_begin(), op.getOperationName(),
430                   constraint.getSummary()));
431     }
432   }
433 
434   // Capture the value
435   auto name = tree.getArgName(argIndex);
436   // `$_` is a special symbol to ignore op argument matching.
437   if (!name.empty() && name != "_") {
438     // We need to subtract the number of attributes before this operand to get
439     // the index in the operand list.
440     auto numPrevAttrs = std::count_if(
441         op.arg_begin(), op.arg_begin() + argIndex,
442         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
443 
444     auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
445     os << formatv("{0} = {1}.getODSOperands({2});\n",
446                   res->second.getVarName(name), opName,
447                   argIndex - numPrevAttrs);
448   }
449 }
450 
451 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
452                                         int argIndex, int depth) {
453   Operator &op = tree.getDialectOp(opMap);
454   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
455   const auto &attr = namedAttr->attr;
456 
457   os << "{\n";
458   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
459                          "(void)tblgen_attr;\n",
460                          opName, attr.getStorageType(), namedAttr->name);
461 
462   // TODO: This should use getter method to avoid duplication.
463   if (attr.hasDefaultValue()) {
464     os << "if (!tblgen_attr) tblgen_attr = "
465        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
466                             attr.getDefaultValue()))
467        << ";\n";
468   } else if (attr.isOptional()) {
469     // For a missing attribute that is optional according to definition, we
470     // should just capture a mlir::Attribute() to signal the missing state.
471     // That is precisely what getAttr() returns on missing attributes.
472   } else {
473     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
474                    formatv("\"expected op '{0}' to have attribute '{1}' "
475                            "of type '{2}'\"",
476                            op.getOperationName(), namedAttr->name,
477                            attr.getStorageType()));
478   }
479 
480   auto matcher = tree.getArgAsLeaf(argIndex);
481   if (!matcher.isUnspecified()) {
482     if (!matcher.isAttrMatcher()) {
483       PrintFatalError(
484           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
485                        op.getOperationName(), argIndex + 1));
486     }
487 
488     // If a constraint is specified, we need to generate C++ statements to
489     // check the constraint.
490     emitMatchCheck(
491         opName,
492         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
493         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
494                 "{2}\"",
495                 op.getOperationName(), namedAttr->name,
496                 matcher.getAsConstraint().getSummary()));
497   }
498 
499   // Capture the value
500   auto name = tree.getArgName(argIndex);
501   // `$_` is a special symbol to ignore op argument matching.
502   if (!name.empty() && name != "_") {
503     os << formatv("{0} = tblgen_attr;\n", name);
504   }
505 
506   os.unindent() << "}\n";
507 }
508 
509 void PatternEmitter::emitMatchCheck(
510     StringRef opName, const FmtObjectBase &matchFmt,
511     const llvm::formatv_object_base &failureFmt) {
512   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
513 }
514 
515 void PatternEmitter::emitMatchCheck(StringRef opName,
516                                     const std::string &matchStr,
517                                     const std::string &failureStr) {
518 
519   os << "if (!(" << matchStr << "))";
520   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
521                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
522                               << failureStr << ";\n});";
523 }
524 
525 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
526   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
527   int depth = 0;
528   emitMatch(tree, opName, depth);
529 
530   for (auto &appliedConstraint : pattern.getConstraints()) {
531     auto &constraint = appliedConstraint.constraint;
532     auto &entities = appliedConstraint.entities;
533 
534     auto condition = constraint.getConditionTemplate();
535     if (isa<TypeConstraint>(constraint)) {
536       auto self = formatv("({0}.getType())",
537                           symbolInfoMap.getValueAndRangeUse(entities.front()));
538       emitMatchCheck(
539           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
540           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
541                   entities.front(), constraint.getSummary()));
542 
543     } else if (isa<AttrConstraint>(constraint)) {
544       PrintFatalError(
545           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
546     } else {
547       // TODO: replace formatv arguments with the exact specified
548       // args.
549       if (entities.size() > 4) {
550         PrintFatalError(loc, "only support up to 4-entity constraints now");
551       }
552       SmallVector<std::string, 4> names;
553       int i = 0;
554       for (int e = entities.size(); i < e; ++i)
555         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
556       std::string self = appliedConstraint.self;
557       if (!self.empty())
558         self = symbolInfoMap.getValueAndRangeUse(self);
559       for (; i < 4; ++i)
560         names.push_back("<unused>");
561       emitMatchCheck(opName,
562                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
563                            names[1], names[2], names[3]),
564                      formatv("\"entities '{0}' failed to satisfy constraint: "
565                              "{1}\"",
566                              llvm::join(entities, ", "),
567                              constraint.getSummary()));
568     }
569   }
570 
571   // Some of the operands could be bound to the same symbol name, we need
572   // to enforce equality constraint on those.
573   // TODO: we should be able to emit equality checks early
574   // and short circuit unnecessary work if vars are not equal.
575   for (auto symbolInfoIt = symbolInfoMap.begin();
576        symbolInfoIt != symbolInfoMap.end();) {
577     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
578     auto startRange = range.first;
579     auto endRange = range.second;
580 
581     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
582     for (++startRange; startRange != endRange; ++startRange) {
583       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
584       emitMatchCheck(
585           opName,
586           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
587           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
588                   secondOperand));
589     }
590 
591     symbolInfoIt = endRange;
592   }
593 
594   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
595 }
596 
597 void PatternEmitter::collectOps(DagNode tree,
598                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
599   // Check if this tree is an operation.
600   if (tree.isOperation()) {
601     const Operator &op = tree.getDialectOp(opMap);
602     LLVM_DEBUG(llvm::dbgs()
603                << "found operation " << op.getOperationName() << '\n');
604     ops.insert(&op);
605   }
606 
607   // Recurse the arguments of the tree.
608   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
609     if (auto child = tree.getArgAsNestedDag(i))
610       collectOps(child, ops);
611 }
612 
// Emits the complete `::mlir::RewritePattern` subclass named `rewriteName`
// for this pattern, in two parts:
//   - a constructor that registers the source root op name, the names of the
//     ops created by the result patterns (sorted by name) and the pattern
//     benefit;
//   - a `matchAndRewrite` override containing the match logic followed by the
//     rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations. They are listed as the generated
  // op names in the RewritePattern constructor below.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The banner lists the pattern's locations in reverse order of
  // pattern.getLocation().
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {{)",
                rewriteName, rootName);
  // Sort result operators by name. SmallPtrSet iteration order depends on
  // pointer values, so sorting keeps the generated list deterministic.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

  // Emit matchAndRewrite() function.
  {
    // Indents everything emitted inside the struct body.
    auto classScope = os.scope();
    os.reindent(R"(
    ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      // Indents the body of matchAndRewrite().
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // One slot per op in the source pattern. Slot 0 is the root op; the
      // nested matched ops fill the following slots during matching.
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
                    pattern.getSourcePattern().getNumOps());
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      os << "// Match\n";
      os << "tblgen_ops[0] = op0;\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    // Closes matchAndRewrite(); the trailing ';' is a harmless empty member
    // declaration.
    os << "};\n";
  }
  // Closes the generated struct.
  os << "};\n\n";
}
692 
693 void PatternEmitter::emitRewriteLogic() {
694   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
695   const Operator &rootOp = pattern.getSourceRootOp();
696   int numExpectedResults = rootOp.getNumResults();
697   int numResultPatterns = pattern.getNumResultPatterns();
698 
699   // First register all symbols bound to ops generated in result patterns.
700   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
701 
702   // Only the last N static values generated are used to replace the matched
703   // root N-result op. We need to calculate the starting index (of the results
704   // of the matched op) each result pattern is to replace.
705   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
706   // If we don't need to replace any value at all, set the replacement starting
707   // index as the number of result patterns so we skip all of them when trying
708   // to replace the matched op's results.
709   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
710   for (int i = numResultPatterns - 1; i >= 0; --i) {
711     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
712     offsets[i] = offsets[i + 1] - numValues;
713     if (offsets[i] == 0) {
714       if (replStartIndex == -1)
715         replStartIndex = i;
716     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
717       auto error = formatv(
718           "cannot use the same multi-result op '{0}' to generate both "
719           "auxiliary values and values to be used for replacing the matched op",
720           pattern.getResultPattern(i).getSymbol());
721       PrintFatalError(loc, error);
722     }
723   }
724 
725   if (offsets.front() > 0) {
726     const char error[] = "no enough values generated to replace the matched op";
727     PrintFatalError(loc, error);
728   }
729 
730   os << "auto odsLoc = rewriter.getFusedLoc({";
731   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
732     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
733   }
734   os << "}); (void)odsLoc;\n";
735 
736   // Process auxiliary result patterns.
737   for (int i = 0; i < replStartIndex; ++i) {
738     DagNode resultTree = pattern.getResultPattern(i);
739     auto val = handleResultPattern(resultTree, offsets[i], 0);
740     // Normal op creation will be streamed to `os` by the above call; but
741     // NativeCodeCall will only be materialized to `os` if it is used. Here
742     // we are handling auxiliary patterns so we want the side effect even if
743     // NativeCodeCall is not replacing matched root op's results.
744     if (resultTree.isNativeCodeCall())
745       os << val << ";\n";
746   }
747 
748   if (numExpectedResults == 0) {
749     assert(replStartIndex >= numResultPatterns &&
750            "invalid auxiliary vs. replacement pattern division!");
751     // No result to replace. Just erase the op.
752     os << "rewriter.eraseOp(op0);\n";
753   } else {
754     // Process replacement result patterns.
755     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
756     for (int i = replStartIndex; i < numResultPatterns; ++i) {
757       DagNode resultTree = pattern.getResultPattern(i);
758       auto val = handleResultPattern(resultTree, offsets[i], 0);
759       os << "\n";
760       // Resolve each symbol for all range use so that we can loop over them.
761       // We need an explicit cast to `SmallVector` to capture the cases where
762       // `{0}` resolves to an `Operation::result_range` as well as cases that
763       // are not iterable (e.g. vector that gets wrapped in additional braces by
764       // RewriterGen).
765       // TODO: Revisit the need for materializing a vector.
766       os << symbolInfoMap.getAllRangeUse(
767           val,
768           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
769           "  tblgen_repl_values.push_back(v);\n}\n",
770           "\n");
771     }
772     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
773   }
774 
775   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
776 }
777 
778 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
779   return std::string(
780       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
781 }
782 
783 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
784                                                 int resultIndex, int depth) {
785   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
786   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
787   LLVM_DEBUG(llvm::dbgs() << '\n');
788 
789   if (resultTree.isLocationDirective()) {
790     PrintFatalError(loc,
791                     "location directive can only be used with op creation");
792   }
793 
794   if (resultTree.isNativeCodeCall()) {
795     auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
796     symbolInfoMap.bindValue(symbol);
797     return symbol;
798   }
799 
800   if (resultTree.isReplaceWithValue())
801     return handleReplaceWithValue(resultTree).str();
802 
803   // Normal op creation.
804   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
805   if (resultTree.getSymbol().empty()) {
806     // This is an op not explicitly bound to a symbol in the rewrite rule.
807     // Register the auto-generated symbol for it.
808     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
809   }
810   return symbol;
811 }
812 
813 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
814   assert(tree.isReplaceWithValue());
815 
816   if (tree.getNumArgs() != 1) {
817     PrintFatalError(
818         loc, "replaceWithValue directive must take exactly one argument");
819   }
820 
821   if (!tree.getSymbol().empty()) {
822     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
823   }
824 
825   return tree.getArgName(0);
826 }
827 
828 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
829   assert(tree.isLocationDirective());
830   auto lookUpArgLoc = [this, &tree](int idx) {
831     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
832     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
833   };
834 
835   if (tree.getNumArgs() == 0)
836     llvm::PrintFatalError(
837         "At least one argument to location directive required");
838 
839   if (!tree.getSymbol().empty())
840     PrintFatalError(loc, "cannot bind symbol to location");
841 
842   if (tree.getNumArgs() == 1) {
843     DagLeaf leaf = tree.getArgAsLeaf(0);
844     if (leaf.isStringAttr())
845       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
846                      "rewriter.getContext())",
847                      leaf.getStringAttr())
848           .str();
849     return lookUpArgLoc(0);
850   }
851 
852   std::string ret;
853   llvm::raw_string_ostream os(ret);
854   std::string strAttr;
855   os << "rewriter.getFusedLoc({";
856   bool first = true;
857   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
858     DagLeaf leaf = tree.getArgAsLeaf(i);
859     // Handle the optional string value.
860     if (leaf.isStringAttr()) {
861       if (!strAttr.empty())
862         llvm::PrintFatalError("Only one string attribute may be specified");
863       strAttr = leaf.getStringAttr();
864       continue;
865     }
866     os << (first ? "" : ", ") << lookUpArgLoc(i);
867     first = false;
868   }
869   os << "}";
870   if (!strAttr.empty()) {
871     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
872   }
873   os << ")";
874   return os.str();
875 }
876 
877 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
878                                              StringRef patArgName) {
879   if (leaf.isStringAttr())
880     PrintFatalError(loc, "raw string not supported as argument");
881   if (leaf.isConstantAttr()) {
882     auto constAttr = leaf.getAsConstantAttr();
883     return handleConstantAttr(constAttr.getAttribute(),
884                               constAttr.getConstantValue());
885   }
886   if (leaf.isEnumAttrCase()) {
887     auto enumCase = leaf.getAsEnumAttrCase();
888     if (enumCase.isStrCase())
889       return handleConstantAttr(enumCase, enumCase.getSymbol());
890     // This is an enum case backed by an IntegerAttr. We need to get its value
891     // to build the constant.
892     std::string val = std::to_string(enumCase.getValue());
893     return handleConstantAttr(enumCase, val);
894   }
895 
896   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
897   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
898   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
899     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
900                             << "' (via symbol ref)\n");
901     return argName;
902   }
903   if (leaf.isNativeCodeCall()) {
904     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
905     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
906                             << "' (via NativeCodeCall)\n");
907     return std::string(repl);
908   }
909   PrintFatalError(loc, "unhandled case when rewriting op");
910 }
911 
912 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
913                                                             int depth) {
914   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
915   LLVM_DEBUG(tree.print(llvm::dbgs()));
916   LLVM_DEBUG(llvm::dbgs() << '\n');
917 
918   auto fmt = tree.getNativeCodeTemplate();
919   // TODO: replace formatv arguments with the exact specified args.
920   SmallVector<std::string, 8> attrs(8);
921   if (tree.getNumArgs() > 8) {
922     PrintFatalError(loc,
923                     "unsupported NativeCodeCall replace argument numbers: " +
924                         Twine(tree.getNumArgs()));
925   }
926   bool hasLocationDirective;
927   std::string locToUse;
928   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
929 
930   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
931     if (tree.isNestedDagArg(i)) {
932       attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
933     } else {
934       attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
935     }
936     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
937                             << " replacement: " << attrs[i] << "\n");
938   }
939   return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
940                            attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
941                            attrs[6], attrs[7]));
942 }
943 
944 int PatternEmitter::getNodeValueCount(DagNode node) {
945   if (node.isOperation()) {
946     // If the op is bound to a symbol in the rewrite rule, query its result
947     // count from the symbol info map.
948     auto symbol = node.getSymbol();
949     if (!symbol.empty()) {
950       return symbolInfoMap.getStaticValueCount(symbol);
951     }
952     // Otherwise this is an unbound op; we will use all its results.
953     return pattern.getDialectOp(node).getNumResults();
954   }
955   // TODO: This considers all NativeCodeCall as returning one
956   // value. Enhance if multi-value ones are needed.
957   return 1;
958 }
959 
960 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
961   auto numPatArgs = tree.getNumArgs();
962 
963   if (numPatArgs != 0) {
964     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
965       if (lastArg.isLocationDirective()) {
966         return std::make_pair(true, handleLocationDirective(lastArg));
967       }
968   }
969 
970   // If no explicit location is given, use the default, all fused, location.
971   return std::make_pair(false, "odsLoc");
972 }
973 
// Emits C++ that creates the op described by the result DAG node `tree` and
// returns the symbol naming its result(s).
//
// Parameters:
//  * `resultIndex` is the index of the first result of the matched root op
//    that this node's values replace. A negative value is treated below as an
//    auxiliary op.
//  * `depth` is the nesting level of `tree` within the result pattern. Nested
//    child DAGs are handled with `depth + 1`.
//
// The emitted code declares a local variable for the op, then opens a `{`
// block in the generated code where arguments are prepared and the op is
// built. Every path below emits the matching `}`. One of three builder
// strategies is used:
//  1. Ops with SameOperandsAndResultType or FirstAttrDerivedResultType use
//     the aggregate (values + attributes) builder, since their result types
//     can be deduced.
//  2. Partial-result, nested, or auxiliary ops use the separate-parameter
//     builder. The developer must provide a builder with type deduction.
//  3. Otherwise, the aggregate builder is called with result types copied
//     from the matched root op (`castedOp0`).
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  // A trailing `location` directive, if present, is not an op argument.
  bool hasLocationDirective;
  std::string locToUse;
  std::tie(hasLocationDirective, locToUse) = getLocation(tree);

  // The pattern must supply exactly one argument per op argument (operands
  // and attributes), not counting the location directive.
  auto inPattern = numPatArgs - hasLocationDirective;
  if (numOpArgs != inPattern) {
    PrintFatalError(loc,
                    formatv("resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition",
                            resultOp.getOperationName(), inPattern, numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i))
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
  }

  // Create the local variable for this op. The escaped `{{` also opens the
  // generated block that each builder path below closes with `}`.
  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
                valuePackName);

  // Right now ODS don't have general type inference support. Except a few
  // special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
  bool useFirstAttr =
      resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");

  if (isSameOperandsAndResultType || useFirstAttr) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

    // Then create the op. The scope's close string "\n}\n" ends the generated
    // block opened above.
    os.scope("", "\n}\n").os << formatv(
        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
        valuePackName, resultOp.getQualCppClassName(), locToUse);
    return resultValue;
  }

  // True when the bound symbol carries a `__N` suffix, i.e. only part of this
  // op's results is referenced.
  bool usePartialResults = valuePackName != resultValue;

  if (usePartialResults || depth > 0 || resultIndex < 0) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(tree, childNodeNames);

    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                             resultOp.getQualCppClassName(), locToUse);
    supplyValuesForOpArgs(tree, childNodeNames, depth);
    // Close the builder call and the generated block.
    os << "\n  );\n}\n";
    return resultValue;
  }

  // If depth == 0 and resultIndex >= 0, it means we are replacing the values
  // generated from the source pattern root op. Then we can use the source
  // pattern's value types to determine the value type of the generated op
  // here.

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n");
  int numResults = resultOp.getNumResults();
  if (numResults != 0) {
    // Result i of the new op takes the types of the matched root op's ODS
    // result group at `resultIndex + i`.
    for (int i = 0; i < numResults; ++i)
      os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
                    "  tblgen_types.push_back(v.getType());\n}\n",
                    resultIndex + i);
  }
  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, tblgen_attrs);\n",
                valuePackName, resultOp.getQualCppClassName(), locToUse);
  // Undo the indent() above and close the generated block.
  os.unindent() << "}\n";
  return resultValue;
}
1098 
1099 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1100     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1101   Operator &resultOp = node.getDialectOp(opMap);
1102 
1103   // Now prepare operands used for building this op:
1104   // * If the operand is non-variadic, we create a `Value` local variable.
1105   // * If the operand is variadic, we create a `SmallVector<Value>` local
1106   //   variable.
1107 
1108   int valueIndex = 0; // An index for uniquing local variable names.
1109   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1110     const auto *operand =
1111         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1112     // We do not need special handling for attributes.
1113     if (!operand)
1114       continue;
1115 
1116     raw_indented_ostream::DelimitedScope scope(os);
1117     std::string varName;
1118     if (operand->isVariadic()) {
1119       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1120       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1121       std::string range;
1122       if (node.isNestedDagArg(argIndex)) {
1123         range = childNodeNames[argIndex];
1124       } else {
1125         range = std::string(node.getArgName(argIndex));
1126       }
1127       // Resolve the symbol for all range use so that we have a uniform way of
1128       // capturing the values.
1129       range = symbolInfoMap.getValueAndRangeUse(range);
1130       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1131                     varName);
1132     } else {
1133       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1134       os << formatv("::mlir::Value {0} = ", varName);
1135       if (node.isNestedDagArg(argIndex)) {
1136         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1137       } else {
1138         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1139         auto symbol =
1140             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1141         if (leaf.isNativeCodeCall()) {
1142           os << std::string(
1143               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1144         } else {
1145           os << symbol;
1146         }
1147       }
1148       os << ";\n";
1149     }
1150 
1151     // Update to use the newly created local variable for building the op later.
1152     childNodeNames[argIndex] = varName;
1153   }
1154 }
1155 
// Emits the arguments that follow the location in a separate-parameter builder
// call for the op created from `node`. Each argument is written on its own
// line, introduced by ",\n    ", with a `/*name=*/` comment where a name is
// available:
//  * operands use the local variable names in `childNodeNames`, which the
//    caller fills via createSeparateLocalVarsForOpArgs;
//  * attributes given as a nested DAG must be NativeCodeCalls and are expanded
//    inline with `depth`;
//  * attribute leaves are materialized via handleOpArgument.
void PatternEmitter::supplyValuesForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);
  for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
       argIndex != numOpArgs; ++argIndex) {
    // Start each argument on its own line.
    os << ",\n    ";

    Argument opArg = resultOp.getArg(argIndex);
    // Handle the case of operand first.
    if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
      if (!operand->name.empty())
        os << "/*" << operand->name << "=*/";
      os << childNodeNames.lookup(argIndex);
      continue;
    }

    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(argIndex);
    if (auto subTree = node.getArgAsNestedDag(argIndex)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      // The NativeCodeCall is expanded (and any code it needs emitted) before
      // this argument's text is written.
      os << formatv("/*{0}=*/{1}", opArgName,
                    handleReplaceWithNativeCodeCall(subTree, depth));
    } else {
      auto leaf = node.getArgAsLeaf(argIndex);
      // The argument in the result DAG pattern.
      auto patArgName = node.getArgName(argIndex);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO: Refactor out into map to avoid recomputing these.
        // NOTE(review): operands were handled above, so this check looks like
        // it always passes; kept as a safeguard.
        if (!opArg.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
        // Constants are labelled with the pattern-side name, if any.
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
}
1198 
// Emits, within `os.scope()`, the local vectors `tblgen_values` and
// `tblgen_attrs`, filled with the operands and attributes of the op created
// from `node`, for use with the op's aggregate builder:
//  * attributes are appended as named attributes only when the produced
//    attribute expression is non-null (see `addAttrCmd`);
//  * variadic operands append every value of their range;
//  * non-variadic operands append a single value, optionally transformed by a
//    NativeCodeCall leaf.
// Nested-DAG operands use the names in `childNodeNames`. Nested-DAG attributes
// must be NativeCodeCalls and are expanded with `depth + 1`.
void PatternEmitter::createAggregateLocalVarsForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);

  // Lives until the end of this function, so everything below is in scope.
  auto scope = os.scope();
  os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
                "tblgen_values; (void)tblgen_values;\n");
  os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
                "tblgen_attrs; (void)tblgen_attrs;\n");

  // {0}: attribute name in the op definition; {1}: C++ attribute expression.
  // A null attribute is skipped rather than added.
  const char *addAttrCmd =
      "if (auto tmpAttr = {1}) {\n"
      "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
      "tmpAttr);\n}\n";
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
      // The argument in the op definition.
      auto opArgName = resultOp.getArgName(argIndex);
      if (auto subTree = node.getArgAsNestedDag(argIndex)) {
        if (!subTree.isNativeCodeCall())
          PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                               "for creating attribute");
        os << formatv(addAttrCmd, opArgName,
                      handleReplaceWithNativeCodeCall(subTree, depth + 1));
      } else {
        auto leaf = node.getArgAsLeaf(argIndex);
        // The argument in the result DAG pattern.
        auto patArgName = node.getArgName(argIndex);
        os << formatv(addAttrCmd, opArgName,
                      handleOpArgument(leaf, patArgName));
      }
      continue;
    }

    // Not an attribute, so it must be an operand.
    const auto *operand =
        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
    std::string varName;
    if (operand->isVariadic()) {
      std::string range;
      if (node.isNestedDagArg(argIndex)) {
        range = childNodeNames.lookup(argIndex);
      } else {
        range = std::string(node.getArgName(argIndex));
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                    range);
    } else {
      os << formatv("tblgen_values.push_back(");
      if (node.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(
            childNodeNames.lookup(argIndex));
      } else {
        DagLeaf leaf = node.getArgAsLeaf(argIndex);
        auto symbol =
            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
        // A NativeCodeCall leaf wraps the bound value via `$_self`.
        if (leaf.isNativeCodeCall()) {
          os << std::string(
              tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
        } else {
          os << symbol;
        }
      }
      os << ");\n";
    }
  }
}
1268 
1269 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1270   emitSourceFileHeader("Rewriters", os);
1271 
1272   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1273   auto numPatterns = patterns.size();
1274 
1275   // We put the map here because it can be shared among multiple patterns.
1276   RecordOperatorMap recordOpMap;
1277 
1278   std::vector<std::string> rewriterNames;
1279   rewriterNames.reserve(numPatterns);
1280 
1281   std::string baseRewriterName = "GeneratedConvert";
1282   int rewriterIndex = 0;
1283 
1284   for (Record *p : patterns) {
1285     std::string name;
1286     if (p->isAnonymous()) {
1287       // If no name is provided, ensure unique rewriter names simply by
1288       // appending unique suffix.
1289       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1290     } else {
1291       name = std::string(p->getName());
1292     }
1293     LLVM_DEBUG(llvm::dbgs()
1294                << "=== start generating pattern '" << name << "' ===\n");
1295     PatternEmitter(p, &recordOpMap, os).emit(name);
1296     LLVM_DEBUG(llvm::dbgs()
1297                << "=== done generating pattern '" << name << "' ===\n");
1298     rewriterNames.push_back(std::move(name));
1299   }
1300 
1301   // Emit function to add the generated matchers to the pattern list.
1302   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
1303         "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
1304   for (const auto &name : rewriterNames) {
1305     os << "  patterns.insert<" << name << ">(context);\n";
1306   }
1307   os << "}\n";
1308 }
1309 
1310 static mlir::GenRegistration
1311     genRewriters("gen-rewriters", "Generate pattern rewriters",
1312                  [](const RecordKeeper &records, raw_ostream &os) {
1313                    emitRewriters(records, os);
1314                    return false;
1315                  });
1316