xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision ac68637ba94d05f413b1b963950c300fb2f81a99)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
54 //===----------------------------------------------------------------------===//
55 // PatternEmitter
56 //===----------------------------------------------------------------------===//
57 
58 namespace {
59 class PatternEmitter {
60 public:
61   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
62 
63   // Emits the mlir::RewritePattern struct named `rewriteName`.
64   void emit(StringRef rewriteName);
65 
66 private:
67   // Emits the code for matching ops.
68   void emitMatchLogic(DagNode tree);
69 
70   // Emits the code for rewriting ops.
71   void emitRewriteLogic();
72 
73   //===--------------------------------------------------------------------===//
74   // Match utilities
75   //===--------------------------------------------------------------------===//
76 
77   // Emits C++ statements for matching the op constrained by the given DAG
78   // `tree`.
79   void emitOpMatch(DagNode tree, int depth);
80 
81   // Emits C++ statements for matching the `index`-th argument of the given DAG
82   // `tree` as an operand.
83   void emitOperandMatch(DagNode tree, int index, int depth, int indent);
84 
85   // Emits C++ statements for matching the `index`-th argument of the given DAG
86   // `tree` as an attribute.
87   void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
88 
89   //===--------------------------------------------------------------------===//
90   // Rewrite utilities
91   //===--------------------------------------------------------------------===//
92 
93   // Entry point for handling a result pattern rooted at `resultTree` and
94   // dispatches to concrete handlers. The given tree is the `resultIndex`-th
95   // argument of the enclosing DAG.
96   std::string handleResultPattern(DagNode resultTree, int resultIndex,
97                                   int depth);
98 
99   // Emits the C++ statement to replace the matched DAG with a value built via
100   // calling native C++ code.
101   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
102 
103   // Returns the C++ expression referencing the old value serving as the
104   // replacement.
105   std::string handleReplaceWithValue(DagNode tree);
106 
107   // Emits the C++ statement to build a new op out of the given DAG `tree` and
108   // returns the variable name that this op is assigned to. If the root op in
109   // DAG `tree` has a specified name, the created op will be assigned to a
110   // variable of the given name. Otherwise, a unique name will be used as the
111   // result value name.
112   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
113 
114   // Returns the C++ expression to construct a constant attribute of the given
115   // `value` for the given attribute kind `attr`.
116   std::string handleConstantAttr(Attribute attr, StringRef value);
117 
118   // Returns the C++ expression to build an argument from the given DAG `leaf`.
119   // `patArgName` is used to bound the argument to the source pattern.
120   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
121 
122   //===--------------------------------------------------------------------===//
123   // General utilities
124   //===--------------------------------------------------------------------===//
125 
126   // Collects all of the operations within the given dag tree.
127   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
128 
129   // Returns a unique symbol for a local variable of the given `op`.
130   std::string getUniqueSymbol(const Operator *op);
131 
132   //===--------------------------------------------------------------------===//
133   // Symbol utilities
134   //===--------------------------------------------------------------------===//
135 
136   // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
137   std::string resolveSymbol(StringRef symbol);
138 
139   // Returns how many static values the given DAG `node` correspond to.
140   int getNodeValueCount(DagNode node);
141 
142 private:
143   // Pattern instantiation location followed by the location of multiclass
144   // prototypes used. This is intended to be used as a whole to
145   // PrintFatalError() on errors.
146   ArrayRef<llvm::SMLoc> loc;
147 
148   // Op's TableGen Record to wrapper object.
149   RecordOperatorMap *opMap;
150 
151   // Handy wrapper for pattern being emitted.
152   Pattern pattern;
153 
154   // Map for all bound symbols' info.
155   SymbolInfoMap symbolInfoMap;
156 
157   // The next unused ID for newly created values.
158   unsigned nextValueId;
159 
160   raw_ostream &os;
161 
162   // Format contexts containing placeholder substitutations.
163   FmtContext fmtCtx;
164 
165   // Number of op processed.
166   int opCounter = 0;
167 };
168 } // end anonymous namespace
169 
170 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
171                                raw_ostream &os)
172     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
173       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
174   fmtCtx.withBuilder("rewriter");
175 }
176 
177 std::string PatternEmitter::handleConstantAttr(Attribute attr,
178                                                StringRef value) {
179   if (!attr.isConstBuildable())
180     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
181                              " does not have the 'constBuilderCall' field");
182 
183   // TODO(jpienaar): Verify the constants here
184   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
185 }
186 
187 // Helper function to match patterns.
188 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
189   Operator &op = tree.getDialectOp(opMap);
190   if (op.isVariadic()) {
191     PrintFatalError(loc, formatv("matching op '{0}' with variadic "
192                                  "operands/results is unsupported right now",
193                                  op.getOperationName()));
194   }
195 
196   int indent = 4 + 2 * depth;
197   os.indent(indent) << formatv(
198       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
199       depth, op.getQualCppClassName());
200   // Skip the operand matching at depth 0 as the pattern rewriter already does.
201   if (depth != 0) {
202     // Skip if there is no defining operation (e.g., arguments to function).
203     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
204                                  depth);
205   }
206   if (tree.getNumArgs() != op.getNumArgs()) {
207     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
208                                  "pattern vs. {2} in definition",
209                                  op.getOperationName(), tree.getNumArgs(),
210                                  op.getNumArgs()));
211   }
212 
213   // If the operand's name is set, set to that variable.
214   auto name = tree.getSymbol();
215   if (!name.empty())
216     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
217 
218   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
219     auto opArg = op.getArg(i);
220 
221     // Handle nested DAG construct first
222     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
223       os.indent(indent) << "{\n";
224       os.indent(indent + 2)
225           << formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
226                      depth + 1, depth, i);
227       emitOpMatch(argTree, depth + 1);
228       os.indent(indent + 2)
229           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
230       os.indent(indent) << "}\n";
231       continue;
232     }
233 
234     // Next handle DAG leaf: operand or attribute
235     if (opArg.is<NamedTypeConstraint *>()) {
236       emitOperandMatch(tree, i, depth, indent);
237     } else if (opArg.is<NamedAttribute *>()) {
238       emitAttributeMatch(tree, i, depth, indent);
239     } else {
240       PrintFatalError(loc, "unhandled case when matching op");
241     }
242   }
243 }
244 
245 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
246                                       int indent) {
247   Operator &op = tree.getDialectOp(opMap);
248   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
249   auto matcher = tree.getArgAsLeaf(index);
250 
251   // If a constraint is specified, we need to generate C++ statements to
252   // check the constraint.
253   if (!matcher.isUnspecified()) {
254     if (!matcher.isOperandMatcher()) {
255       PrintFatalError(
256           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
257                        op.getOperationName(), index + 1));
258     }
259 
260     // Only need to verify if the matcher's type is different from the one
261     // of op definition.
262     if (operand->constraint != matcher.getAsConstraint()) {
263       auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
264       os.indent(indent) << "if (!("
265                         << tgfmt(matcher.getConditionTemplate(),
266                                  &fmtCtx.withSelf(self))
267                         << ")) return matchFailure();\n";
268     }
269   }
270 
271   // Capture the value
272   auto name = tree.getArgName(index);
273   if (!name.empty()) {
274     os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth,
275                                  index);
276   }
277 }
278 
279 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
280                                         int indent) {
281   Operator &op = tree.getDialectOp(opMap);
282   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
283   const auto &attr = namedAttr->attr;
284 
285   os.indent(indent) << "{\n";
286   indent += 2;
287   os.indent(indent) << formatv(
288       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
289       attr.getStorageType(), namedAttr->name);
290 
291   // TODO(antiagainst): This should use getter method to avoid duplication.
292   if (attr.hasDefaultValueInitializer()) {
293     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
294                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
295                                attr.getDefaultValueInitializer())
296                       << ";\n";
297   } else if (attr.isOptional()) {
298     // For a missing attribute that is optional according to definition, we
299     // should just capature a mlir::Attribute() to signal the missing state.
300     // That is precisely what getAttr() returns on missing attributes.
301   } else {
302     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
303   }
304 
305   auto matcher = tree.getArgAsLeaf(index);
306   if (!matcher.isUnspecified()) {
307     if (!matcher.isAttrMatcher()) {
308       PrintFatalError(
309           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
310                        op.getOperationName(), index + 1));
311     }
312 
313     // If a constraint is specified, we need to generate C++ statements to
314     // check the constraint.
315     os.indent(indent) << "if (!("
316                       << tgfmt(matcher.getConditionTemplate(),
317                                &fmtCtx.withSelf("tblgen_attr"))
318                       << ")) return matchFailure();\n";
319   }
320 
321   // Capture the value
322   auto name = tree.getArgName(index);
323   if (!name.empty()) {
324     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
325   }
326 
327   indent -= 2;
328   os.indent(indent) << "}\n";
329 }
330 
331 void PatternEmitter::emitMatchLogic(DagNode tree) {
332   emitOpMatch(tree, 0);
333 
334   for (auto &appliedConstraint : pattern.getConstraints()) {
335     auto &constraint = appliedConstraint.constraint;
336     auto &entities = appliedConstraint.entities;
337 
338     auto condition = constraint.getConditionTemplate();
339     auto cmd = "if (!({0})) return matchFailure();\n";
340 
341     if (isa<TypeConstraint>(constraint)) {
342       auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
343       os.indent(4) << formatv(cmd,
344                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
345     } else if (isa<AttrConstraint>(constraint)) {
346       PrintFatalError(
347           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
348     } else {
349       // TODO(b/138794486): replace formatv arguments with the exact specified
350       // args.
351       if (entities.size() > 4) {
352         PrintFatalError(loc, "only support up to 4-entity constraints now");
353       }
354       SmallVector<std::string, 4> names;
355       int i = 0;
356       for (int e = entities.size(); i < e; ++i)
357         names.push_back(resolveSymbol(entities[i]));
358       std::string self = appliedConstraint.self;
359       if (!self.empty())
360         self = resolveSymbol(self);
361       for (; i < 4; ++i)
362         names.push_back("<unused>");
363       os.indent(4) << formatv(cmd,
364                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
365                                     names[1], names[2], names[3]));
366     }
367   }
368 }
369 
370 void PatternEmitter::collectOps(DagNode tree,
371                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
372   // Check if this tree is an operation.
373   if (tree.isOperation())
374     ops.insert(&tree.getDialectOp(opMap));
375 
376   // Recurse the arguments of the tree.
377   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
378     if (auto child = tree.getArgAsNestedDag(i))
379       collectOps(child, ops);
380 }
381 
382 void PatternEmitter::emit(StringRef rewriteName) {
383   // Get the DAG tree for the source pattern.
384   DagNode sourceTree = pattern.getSourcePattern();
385 
386   const Operator &rootOp = pattern.getSourceRootOp();
387   auto rootName = rootOp.getOperationName();
388 
389   // Collect the set of result operations.
390   llvm::SmallPtrSet<const Operator *, 4> resultOps;
391   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
392     collectOps(pattern.getResultPattern(i), resultOps);
393 
394   // Emit RewritePattern for Pattern.
395   auto locs = pattern.getLocation();
396   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
397                 make_range(locs.rbegin(), locs.rend()));
398   os << formatv(R"(struct {0} : public RewritePattern {
399   {0}(MLIRContext *context)
400       : RewritePattern("{1}", {{)",
401                 rewriteName, rootName);
402   interleaveComma(resultOps, os, [&](const Operator *op) {
403     os << '"' << op->getOperationName() << '"';
404   });
405   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
406 
407   // Emit matchAndRewrite() function.
408   os << R"(
409   PatternMatchResult matchAndRewrite(Operation *op0,
410                                      PatternRewriter &rewriter) const override {
411 )";
412 
413   // Register all symbols bound in the source pattern.
414   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
415 
416   os.indent(4) << "// Variables for capturing values and attributes used for "
417                   "creating ops\n";
418   // Create local variables for storing the arguments and results bound
419   // to symbols.
420   for (const auto &symbolInfoPair : symbolInfoMap) {
421     StringRef symbol = symbolInfoPair.getKey();
422     auto &info = symbolInfoPair.getValue();
423     os.indent(4) << info.getVarDecl(symbol);
424   }
425   // TODO(jpienaar): capture ops with consistent numbering so that it can be
426   // reused for fused loc.
427   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
428                           pattern.getSourcePattern().getNumOps());
429 
430   os.indent(4) << "// Match\n";
431   os.indent(4) << "tblgen_ops[0] = op0;\n";
432   emitMatchLogic(sourceTree);
433   os << "\n";
434 
435   os.indent(4) << "// Rewrite\n";
436   emitRewriteLogic();
437 
438   os.indent(4) << "return matchSuccess();\n";
439   os << "  };\n";
440   os << "};\n";
441 }
442 
443 void PatternEmitter::emitRewriteLogic() {
444   const Operator &rootOp = pattern.getSourceRootOp();
445   int numExpectedResults = rootOp.getNumResults();
446   int numResultPatterns = pattern.getNumResultPatterns();
447 
448   // First register all symbols bound to ops generated in result patterns.
449   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
450 
451   // Only the last N static values generated are used to replace the matched
452   // root N-result op. We need to calculate the starting index (of the results
453   // of the matched op) each result pattern is to replace.
454   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
455   // If we don't need to replace any value at all, set the replacement starting
456   // index as the number of result patterns so we skip all of them when trying
457   // to replace the matched op's results.
458   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
459   for (int i = numResultPatterns - 1; i >= 0; --i) {
460     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
461     offsets[i] = offsets[i + 1] - numValues;
462     if (offsets[i] == 0) {
463       if (replStartIndex == -1)
464         replStartIndex = i;
465     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
466       auto error = formatv(
467           "cannot use the same multi-result op '{0}' to generate both "
468           "auxiliary values and values to be used for replacing the matched op",
469           pattern.getResultPattern(i).getSymbol());
470       PrintFatalError(loc, error);
471     }
472   }
473 
474   if (offsets.front() > 0) {
475     const char error[] = "no enough values generated to replace the matched op";
476     PrintFatalError(loc, error);
477   }
478 
479   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
480   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
481     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
482   }
483   os << "}); (void)loc;\n";
484 
485   // Collect the replacement value for each result
486   llvm::SmallVector<std::string, 2> resultValues;
487   for (int i = 0; i < numResultPatterns; ++i) {
488     DagNode resultTree = pattern.getResultPattern(i);
489     resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
490   }
491 
492   // Emit the final replaceOp() statement
493   os.indent(4) << "rewriter.replaceOp(op0, {";
494   interleaveComma(
495       ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
496       [&](const std::string &symbol) { os << resolveSymbol(symbol); });
497   os << "});\n";
498 }
499 
500 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
501   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
502 }
503 
504 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
505                                                 int resultIndex, int depth) {
506   if (resultTree.isNativeCodeCall()) {
507     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
508     symbolInfoMap.bindValue(symbol);
509     return symbol;
510   }
511 
512   if (resultTree.isReplaceWithValue()) {
513     return handleReplaceWithValue(resultTree);
514   }
515 
516   // Normal op creation.
517   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
518   if (resultTree.getSymbol().empty()) {
519     // This is an op not explicitly bound to a symbol in the rewrite rule.
520     // Register the auto-generated symbol for it.
521     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
522   }
523   return symbol;
524 }
525 
526 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
527   assert(tree.isReplaceWithValue());
528 
529   if (tree.getNumArgs() != 1) {
530     PrintFatalError(
531         loc, "replaceWithValue directive must take exactly one argument");
532   }
533 
534   if (!tree.getSymbol().empty()) {
535     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
536   }
537 
538   return resolveSymbol(tree.getArgName(0));
539 }
540 
541 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
542   if (leaf.isConstantAttr()) {
543     auto constAttr = leaf.getAsConstantAttr();
544     return handleConstantAttr(constAttr.getAttribute(),
545                               constAttr.getConstantValue());
546   }
547   if (leaf.isEnumAttrCase()) {
548     auto enumCase = leaf.getAsEnumAttrCase();
549     if (enumCase.isStrCase())
550       return handleConstantAttr(enumCase, enumCase.getSymbol());
551     // This is an enum case backed by an IntegerAttr. We need to get its value
552     // to build the constant.
553     std::string val = std::to_string(enumCase.getValue());
554     return handleConstantAttr(enumCase, val);
555   }
556   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
557     return argName;
558   }
559   if (leaf.isNativeCodeCall()) {
560     return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
561   }
562   PrintFatalError(loc, "unhandled case when rewriting op");
563 }
564 
565 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
566   auto fmt = tree.getNativeCodeTemplate();
567   // TODO(b/138794486): replace formatv arguments with the exact specified args.
568   SmallVector<std::string, 8> attrs(8);
569   if (tree.getNumArgs() > 8) {
570     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
571                              Twine(tree.getNumArgs()));
572   }
573   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
574     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
575   }
576   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
577                attrs[5], attrs[6], attrs[7]);
578 }
579 
580 std::string PatternEmitter::resolveSymbol(StringRef symbol) {
581   auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
582   if (subst.empty()) {
583     PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
584   }
585   return subst;
586 }
587 
588 int PatternEmitter::getNodeValueCount(DagNode node) {
589   if (node.isOperation()) {
590     // If the op is bound to a symbol in the rewrite rule, query its result
591     // count from the symbol info map.
592     auto symbol = node.getSymbol();
593     if (!symbol.empty()) {
594       return symbolInfoMap.getStaticValueCount(symbol);
595     }
596     // Otherwise this is an unbound op; we will use all its results.
597     return pattern.getDialectOp(node).getNumResults();
598   }
599   // TODO(antiagainst): This considers all NativeCodeCall as returning one
600   // value. Enhance if multi-value ones are needed.
601   return 1;
602 }
603 
604 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
605                                              int depth) {
606   Operator &resultOp = tree.getDialectOp(opMap);
607   auto numOpArgs = resultOp.getNumArgs();
608 
609   if (resultOp.isVariadic()) {
610     PrintFatalError(loc, formatv("generating op '{0}' with variadic "
611                                  "operands/results is unsupported now",
612                                  resultOp.getOperationName()));
613   }
614 
615   if (numOpArgs != tree.getNumArgs()) {
616     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
617                                  "{1} in pattern vs. {2} in definition",
618                                  resultOp.getOperationName(), tree.getNumArgs(),
619                                  numOpArgs));
620   }
621 
622   // A map to collect all nested DAG child nodes' names, with operand index as
623   // the key. This includes both bound and unbound child nodes. Bound child
624   // nodes will additionally be tracked in `symbolResolver` so they can be
625   // referenced by other patterns. Unbound child nodes will only be used once
626   // to build this op.
627   llvm::DenseMap<unsigned, std::string> childNodeNames;
628 
629   // First go through all the child nodes who are nested DAG constructs to
630   // create ops for them, so that we can use the results in the current node.
631   // This happens in a recursive manner.
632   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
633     if (auto child = tree.getArgAsNestedDag(i)) {
634       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
635     }
636   }
637 
638   // Use the specified name for this op if available. Generate one otherwise.
639   std::string resultValue = tree.getSymbol();
640   if (resultValue.empty())
641     resultValue = getUniqueSymbol(&resultOp);
642   // Strip the index to get the name for the value pack. This will be used to
643   // name the local variable for the op.
644   StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
645 
646   // Then we build the new op corresponding to this DAG node.
647 
648   // Right now we don't have general type inference in MLIR. Except a few
649   // special cases listed below, we need to supply types for all results
650   // when building an op.
651   bool isSameOperandsAndResultType =
652       resultOp.hasTrait("SameOperandsAndResultType");
653   bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
654   bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
655   bool usePartialResults = valuePackName != resultValue;
656 
657   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
658       usePartialResults || depth > 0 || resultIndex < 0) {
659     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
660                             valuePackName, resultOp.getQualCppClassName());
661   } else {
662     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
663     // generated from the source pattern root op. Then we can use the source
664     // pattern's value types to determine the value type of the generated op
665     // here.
666 
667     // We need to specify the types for all results.
668     SmallVector<std::string, 4> resultTypes;
669     int numResults = resultOp.getNumResults();
670     resultTypes.reserve(numResults);
671     for (int i = 0; i < numResults; ++i) {
672       resultTypes.push_back(
673           formatv("op0->getResult({0})->getType()", resultIndex + i));
674     }
675 
676     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
677                             valuePackName, resultOp.getQualCppClassName())
678                  << (resultTypes.empty() ? "" : ", ")
679                  << llvm::join(resultTypes, ", ");
680   }
681 
682   // Create the builder call for the result.
683   // Add operands.
684   int argIndex = 0;
685   for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
686     const auto &operand = resultOp.getOperand(argIndex);
687 
688     // Start each operand on its own line.
689     (os << ",\n").indent(6);
690 
691     if (!operand.name.empty())
692       os << "/*" << operand.name << "=*/";
693 
694     if (tree.isNestedDagArg(argIndex)) {
695       os << childNodeNames[argIndex];
696     } else {
697       DagLeaf leaf = tree.getArgAsLeaf(argIndex);
698       auto symbol = resolveSymbol(tree.getArgName(argIndex));
699       if (leaf.isNativeCodeCall()) {
700         os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
701       } else {
702         os << symbol;
703       }
704     }
705     // TODO(jpienaar): verify types
706   }
707 
708   // Add attributes.
709   for (; argIndex != numOpArgs; ++argIndex) {
710     // Start each attribute on its own line.
711     (os << ",\n").indent(6);
712     // The argument in the op definition.
713     auto opArgName = resultOp.getArgName(argIndex);
714     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
715       if (!subTree.isNativeCodeCall())
716         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
717                              "for creating attribute");
718       os << formatv("/*{0}=*/{1}", opArgName,
719                     handleReplaceWithNativeCodeCall(subTree));
720     } else {
721       auto leaf = tree.getArgAsLeaf(argIndex);
722       // The argument in the result DAG pattern.
723       auto patArgName = tree.getArgName(argIndex);
724       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
725         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
726         auto argument = resultOp.getArg(argIndex);
727         if (!argument.is<NamedAttribute *>())
728           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
729         if (!patArgName.empty())
730           os << "/*" << patArgName << "=*/";
731       } else {
732         os << "/*" << opArgName << "=*/";
733       }
734       os << handleOpArgument(leaf, patArgName);
735     }
736   }
737   os << "\n    );\n";
738 
739   return resultValue;
740 }
741 
742 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
743   emitSourceFileHeader("Rewriters", os);
744 
745   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
746   auto numPatterns = patterns.size();
747 
748   // We put the map here because it can be shared among multiple patterns.
749   RecordOperatorMap recordOpMap;
750 
751   std::vector<std::string> rewriterNames;
752   rewriterNames.reserve(numPatterns);
753 
754   std::string baseRewriterName = "GeneratedConvert";
755   int rewriterIndex = 0;
756 
757   for (Record *p : patterns) {
758     std::string name;
759     if (p->isAnonymous()) {
760       // If no name is provided, ensure unique rewriter names simply by
761       // appending unique suffix.
762       name = baseRewriterName + llvm::utostr(rewriterIndex++);
763     } else {
764       name = p->getName();
765     }
766     PatternEmitter(p, &recordOpMap, os).emit(name);
767     rewriterNames.push_back(std::move(name));
768   }
769 
770   // Emit function to add the generated matchers to the pattern list.
771   os << "void populateWithGenerated(MLIRContext *context, "
772      << "OwningRewritePatternList *patterns) {\n";
773   for (const auto &name : rewriterNames) {
774     os << "  patterns->insert<" << name << ">(context);\n";
775   }
776   os << "}\n";
777 }
778 
779 static mlir::GenRegistration
780     genRewriters("gen-rewriters", "Generate pattern rewriters",
781                  [](const RecordKeeper &records, raw_ostream &os) {
782                    emitRewriters(records, os);
783                    return false;
784                  });
785