xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 31cfee60773c6277f652aab5fa1f8f824a551627)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
54 //===----------------------------------------------------------------------===//
55 // PatternEmitter
56 //===----------------------------------------------------------------------===//
57 
namespace {
// Emits the C++ definition of one mlir::RewritePattern subclass for a single
// TableGen pattern record. All generated text is written to the raw_ostream
// passed to the constructor.
class PatternEmitter {
public:
  // Prepares to emit the pattern defined by record `pat`. `mapper` maps op
  // records to their Operator wrappers; generated code is written to `os`.
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops, starting from the source root `tree`,
  // followed by the checks for the pattern's multi-entity constraints.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops, ending with the replacement of the
  // matched root op.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` (0 for the source root); the
  // generated code refers to the op as `op<depth>` and its casted form as
  // `castedOp<depth>`.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand. `indent` is the column the statements start at.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute. `indent` is the column the statements start at.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression built by calling native C++ code, with the
  // NativeCodeCall template's placeholders filled from `resultTree`'s
  // arguments.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the symbol the argument is bound to in the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values; consumed by
  // getUniqueSymbol().
  unsigned nextValueId;

  // Destination for all generated code.
  raw_ostream &os;

  // Format context containing placeholder substitutions.
  FmtContext fmtCtx;

  // Number of nested source-pattern ops matched so far. It is pre-incremented
  // to choose the `tblgen_ops` slot for each nested op; slot 0 is filled with
  // the root op by the generated code.
  int opCounter = 0;
};
} // end anonymous namespace
173 
174 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
175                                raw_ostream &os)
176     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
177       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
178   fmtCtx.withBuilder("rewriter");
179 }
180 
181 std::string PatternEmitter::handleConstantAttr(Attribute attr,
182                                                StringRef value) {
183   if (!attr.isConstBuildable())
184     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
185                              " does not have the 'constBuilderCall' field");
186 
187   // TODO(jpienaar): Verify the constants here
188   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
189 }
190 
191 // Helper function to match patterns.
192 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
193   Operator &op = tree.getDialectOp(opMap);
194 
195   int indent = 4 + 2 * depth;
196   os.indent(indent) << formatv(
197       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
198       depth, op.getQualCppClassName());
199   // Skip the operand matching at depth 0 as the pattern rewriter already does.
200   if (depth != 0) {
201     // Skip if there is no defining operation (e.g., arguments to function).
202     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
203                                  depth);
204   }
205   if (tree.getNumArgs() != op.getNumArgs()) {
206     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
207                                  "pattern vs. {2} in definition",
208                                  op.getOperationName(), tree.getNumArgs(),
209                                  op.getNumArgs()));
210   }
211 
212   // If the operand's name is set, set to that variable.
213   auto name = tree.getSymbol();
214   if (!name.empty())
215     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
216 
217   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
218     auto opArg = op.getArg(i);
219 
220     // Handle nested DAG construct first
221     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
222       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
223         if (operand->isVariadic()) {
224           auto error = formatv("use nested DAG construct to match op {0}'s "
225                                "variadic operand #{1} unsupported now",
226                                op.getOperationName(), i);
227           PrintFatalError(loc, error);
228         }
229       }
230       os.indent(indent) << "{\n";
231 
232       os.indent(indent + 2) << formatv(
233           "auto *op{0} = "
234           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
235           depth + 1, depth, i);
236       emitOpMatch(argTree, depth + 1);
237       os.indent(indent + 2)
238           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
239       os.indent(indent) << "}\n";
240       continue;
241     }
242 
243     // Next handle DAG leaf: operand or attribute
244     if (opArg.is<NamedTypeConstraint *>()) {
245       emitOperandMatch(tree, i, depth, indent);
246     } else if (opArg.is<NamedAttribute *>()) {
247       emitAttributeMatch(tree, i, depth, indent);
248     } else {
249       PrintFatalError(loc, "unhandled case when matching op");
250     }
251   }
252 }
253 
254 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
255                                       int indent) {
256   Operator &op = tree.getDialectOp(opMap);
257   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
258   auto matcher = tree.getArgAsLeaf(index);
259 
260   // If a constraint is specified, we need to generate C++ statements to
261   // check the constraint.
262   if (!matcher.isUnspecified()) {
263     if (!matcher.isOperandMatcher()) {
264       PrintFatalError(
265           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
266                        op.getOperationName(), index + 1));
267     }
268 
269     // Only need to verify if the matcher's type is different from the one
270     // of op definition.
271     if (operand->constraint != matcher.getAsConstraint()) {
272       if (operand->isVariadic()) {
273         auto error = formatv(
274             "further constrain op {0}'s variadic operand #{1} unsupported now",
275             op.getOperationName(), index);
276         PrintFatalError(loc, error);
277       }
278       auto self =
279           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
280                   depth, index);
281       os.indent(indent) << "if (!("
282                         << tgfmt(matcher.getConditionTemplate(),
283                                  &fmtCtx.withSelf(self))
284                         << ")) return matchFailure();\n";
285     }
286   }
287 
288   // Capture the value
289   auto name = tree.getArgName(index);
290   if (!name.empty()) {
291     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
292                                  name, depth, index);
293   }
294 }
295 
296 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
297                                         int indent) {
298   Operator &op = tree.getDialectOp(opMap);
299   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
300   const auto &attr = namedAttr->attr;
301 
302   os.indent(indent) << "{\n";
303   indent += 2;
304   os.indent(indent) << formatv(
305       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
306       attr.getStorageType(), namedAttr->name);
307 
308   // TODO(antiagainst): This should use getter method to avoid duplication.
309   if (attr.hasDefaultValueInitializer()) {
310     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
311                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
312                                attr.getDefaultValueInitializer())
313                       << ";\n";
314   } else if (attr.isOptional()) {
315     // For a missing attribute that is optional according to definition, we
316     // should just capature a mlir::Attribute() to signal the missing state.
317     // That is precisely what getAttr() returns on missing attributes.
318   } else {
319     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
320   }
321 
322   auto matcher = tree.getArgAsLeaf(index);
323   if (!matcher.isUnspecified()) {
324     if (!matcher.isAttrMatcher()) {
325       PrintFatalError(
326           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
327                        op.getOperationName(), index + 1));
328     }
329 
330     // If a constraint is specified, we need to generate C++ statements to
331     // check the constraint.
332     os.indent(indent) << "if (!("
333                       << tgfmt(matcher.getConditionTemplate(),
334                                &fmtCtx.withSelf("tblgen_attr"))
335                       << ")) return matchFailure();\n";
336   }
337 
338   // Capture the value
339   auto name = tree.getArgName(index);
340   if (!name.empty()) {
341     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
342   }
343 
344   indent -= 2;
345   os.indent(indent) << "}\n";
346 }
347 
348 void PatternEmitter::emitMatchLogic(DagNode tree) {
349   emitOpMatch(tree, 0);
350 
351   for (auto &appliedConstraint : pattern.getConstraints()) {
352     auto &constraint = appliedConstraint.constraint;
353     auto &entities = appliedConstraint.entities;
354 
355     auto condition = constraint.getConditionTemplate();
356     auto cmd = "if (!({0})) return matchFailure();\n";
357 
358     if (isa<TypeConstraint>(constraint)) {
359       auto self = formatv("({0}->getType())",
360                           symbolInfoMap.getValueAndRangeUse(entities.front()));
361       os.indent(4) << formatv(cmd,
362                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
363     } else if (isa<AttrConstraint>(constraint)) {
364       PrintFatalError(
365           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
366     } else {
367       // TODO(b/138794486): replace formatv arguments with the exact specified
368       // args.
369       if (entities.size() > 4) {
370         PrintFatalError(loc, "only support up to 4-entity constraints now");
371       }
372       SmallVector<std::string, 4> names;
373       int i = 0;
374       for (int e = entities.size(); i < e; ++i)
375         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
376       std::string self = appliedConstraint.self;
377       if (!self.empty())
378         self = symbolInfoMap.getValueAndRangeUse(self);
379       for (; i < 4; ++i)
380         names.push_back("<unused>");
381       os.indent(4) << formatv(cmd,
382                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
383                                     names[1], names[2], names[3]));
384     }
385   }
386 }
387 
388 void PatternEmitter::collectOps(DagNode tree,
389                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
390   // Check if this tree is an operation.
391   if (tree.isOperation())
392     ops.insert(&tree.getDialectOp(opMap));
393 
394   // Recurse the arguments of the tree.
395   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
396     if (auto child = tree.getArgAsNestedDag(i))
397       collectOps(child, ops);
398 }
399 
400 void PatternEmitter::emit(StringRef rewriteName) {
401   // Get the DAG tree for the source pattern.
402   DagNode sourceTree = pattern.getSourcePattern();
403 
404   const Operator &rootOp = pattern.getSourceRootOp();
405   auto rootName = rootOp.getOperationName();
406 
407   // Collect the set of result operations.
408   llvm::SmallPtrSet<const Operator *, 4> resultOps;
409   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
410     collectOps(pattern.getResultPattern(i), resultOps);
411 
412   // Emit RewritePattern for Pattern.
413   auto locs = pattern.getLocation();
414   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
415                 make_range(locs.rbegin(), locs.rend()));
416   os << formatv(R"(struct {0} : public RewritePattern {
417   {0}(MLIRContext *context)
418       : RewritePattern("{1}", {{)",
419                 rewriteName, rootName);
420   interleaveComma(resultOps, os, [&](const Operator *op) {
421     os << '"' << op->getOperationName() << '"';
422   });
423   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
424 
425   // Emit matchAndRewrite() function.
426   os << R"(
427   PatternMatchResult matchAndRewrite(Operation *op0,
428                                      PatternRewriter &rewriter) const override {
429 )";
430 
431   // Register all symbols bound in the source pattern.
432   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
433 
434   os.indent(4) << "// Variables for capturing values and attributes used for "
435                   "creating ops\n";
436   // Create local variables for storing the arguments and results bound
437   // to symbols.
438   for (const auto &symbolInfoPair : symbolInfoMap) {
439     StringRef symbol = symbolInfoPair.getKey();
440     auto &info = symbolInfoPair.getValue();
441     os.indent(4) << info.getVarDecl(symbol);
442   }
443   // TODO(jpienaar): capture ops with consistent numbering so that it can be
444   // reused for fused loc.
445   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
446                           pattern.getSourcePattern().getNumOps());
447 
448   os.indent(4) << "// Match\n";
449   os.indent(4) << "tblgen_ops[0] = op0;\n";
450   emitMatchLogic(sourceTree);
451   os << "\n";
452 
453   os.indent(4) << "// Rewrite\n";
454   emitRewriteLogic();
455 
456   os.indent(4) << "return matchSuccess();\n";
457   os << "  };\n";
458   os << "};\n";
459 }
460 
461 void PatternEmitter::emitRewriteLogic() {
462   const Operator &rootOp = pattern.getSourceRootOp();
463   int numExpectedResults = rootOp.getNumResults();
464   int numResultPatterns = pattern.getNumResultPatterns();
465 
466   // First register all symbols bound to ops generated in result patterns.
467   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
468 
469   // Only the last N static values generated are used to replace the matched
470   // root N-result op. We need to calculate the starting index (of the results
471   // of the matched op) each result pattern is to replace.
472   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
473   // If we don't need to replace any value at all, set the replacement starting
474   // index as the number of result patterns so we skip all of them when trying
475   // to replace the matched op's results.
476   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
477   for (int i = numResultPatterns - 1; i >= 0; --i) {
478     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
479     offsets[i] = offsets[i + 1] - numValues;
480     if (offsets[i] == 0) {
481       if (replStartIndex == -1)
482         replStartIndex = i;
483     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
484       auto error = formatv(
485           "cannot use the same multi-result op '{0}' to generate both "
486           "auxiliary values and values to be used for replacing the matched op",
487           pattern.getResultPattern(i).getSymbol());
488       PrintFatalError(loc, error);
489     }
490   }
491 
492   if (offsets.front() > 0) {
493     const char error[] = "no enough values generated to replace the matched op";
494     PrintFatalError(loc, error);
495   }
496 
497   os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
498   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
499   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
500     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
501   }
502   os << "}); (void)loc;\n";
503 
504   // Process each result pattern and record the result symbol.
505   llvm::SmallVector<std::string, 2> resultValues;
506   for (int i = 0; i < numResultPatterns; ++i) {
507     DagNode resultTree = pattern.getResultPattern(i);
508     resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
509   }
510 
511   os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
512   // Only use the last portion for replacing the matched root op's results.
513   auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
514   for (const auto &val : range) {
515     os.indent(4) << "\n";
516     // Resolve each symbol for all range use so that we can loop over them.
517     os << symbolInfoMap.getAllRangeUse(
518         val, "    for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
519   }
520   os.indent(4) << "\n";
521   os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
522 }
523 
524 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
525   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
526 }
527 
528 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
529                                                 int resultIndex, int depth) {
530   if (resultTree.isNativeCodeCall()) {
531     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
532     symbolInfoMap.bindValue(symbol);
533     return symbol;
534   }
535 
536   if (resultTree.isReplaceWithValue()) {
537     return handleReplaceWithValue(resultTree);
538   }
539 
540   // Normal op creation.
541   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
542   if (resultTree.getSymbol().empty()) {
543     // This is an op not explicitly bound to a symbol in the rewrite rule.
544     // Register the auto-generated symbol for it.
545     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
546   }
547   return symbol;
548 }
549 
550 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
551   assert(tree.isReplaceWithValue());
552 
553   if (tree.getNumArgs() != 1) {
554     PrintFatalError(
555         loc, "replaceWithValue directive must take exactly one argument");
556   }
557 
558   if (!tree.getSymbol().empty()) {
559     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
560   }
561 
562   return tree.getArgName(0);
563 }
564 
565 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
566                                              StringRef patArgName) {
567   if (leaf.isConstantAttr()) {
568     auto constAttr = leaf.getAsConstantAttr();
569     return handleConstantAttr(constAttr.getAttribute(),
570                               constAttr.getConstantValue());
571   }
572   if (leaf.isEnumAttrCase()) {
573     auto enumCase = leaf.getAsEnumAttrCase();
574     if (enumCase.isStrCase())
575       return handleConstantAttr(enumCase, enumCase.getSymbol());
576     // This is an enum case backed by an IntegerAttr. We need to get its value
577     // to build the constant.
578     std::string val = std::to_string(enumCase.getValue());
579     return handleConstantAttr(enumCase, val);
580   }
581 
582   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
583   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
584     return argName;
585   }
586   if (leaf.isNativeCodeCall()) {
587     return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
588   }
589   PrintFatalError(loc, "unhandled case when rewriting op");
590 }
591 
592 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
593   auto fmt = tree.getNativeCodeTemplate();
594   // TODO(b/138794486): replace formatv arguments with the exact specified args.
595   SmallVector<std::string, 8> attrs(8);
596   if (tree.getNumArgs() > 8) {
597     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
598                              Twine(tree.getNumArgs()));
599   }
600   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
601     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
602   }
603   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
604                attrs[5], attrs[6], attrs[7]);
605 }
606 
607 int PatternEmitter::getNodeValueCount(DagNode node) {
608   if (node.isOperation()) {
609     // If the op is bound to a symbol in the rewrite rule, query its result
610     // count from the symbol info map.
611     auto symbol = node.getSymbol();
612     if (!symbol.empty()) {
613       return symbolInfoMap.getStaticValueCount(symbol);
614     }
615     // Otherwise this is an unbound op; we will use all its results.
616     return pattern.getDialectOp(node).getNumResults();
617   }
618   // TODO(antiagainst): This considers all NativeCodeCall as returning one
619   // value. Enhance if multi-value ones are needed.
620   return 1;
621 }
622 
623 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
624                                              int depth) {
625   Operator &resultOp = tree.getDialectOp(opMap);
626   auto numOpArgs = resultOp.getNumArgs();
627 
628   if (numOpArgs != tree.getNumArgs()) {
629     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
630                                  "{1} in pattern vs. {2} in definition",
631                                  resultOp.getOperationName(), tree.getNumArgs(),
632                                  numOpArgs));
633   }
634 
635   // A map to collect all nested DAG child nodes' names, with operand index as
636   // the key. This includes both bound and unbound child nodes.
637   llvm::DenseMap<unsigned, std::string> childNodeNames;
638 
639   // First go through all the child nodes who are nested DAG constructs to
640   // create ops for them and remember the symbol names for them, so that we can
641   // use the results in the current node. This happens in a recursive manner.
642   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
643     if (auto child = tree.getArgAsNestedDag(i)) {
644       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
645     }
646   }
647 
648   // The name of the local variable holding this op.
649   std::string valuePackName;
650   // The symbol for holding the result of this pattern. Note that the result of
651   // this pattern is not necessarily the same as the variable created by this
652   // pattern because we can use `__N` suffix to refer only a specific result if
653   // the generated op is a multi-result op.
654   std::string resultValue;
655   if (tree.getSymbol().empty()) {
656     // No symbol is explicitly bound to this op in the pattern. Generate a
657     // unique name.
658     valuePackName = resultValue = getUniqueSymbol(&resultOp);
659   } else {
660     resultValue = tree.getSymbol();
661     // Strip the index to get the name for the value pack and use it to name the
662     // local variable for the op.
663     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
664   }
665 
666   // Create the local variable for this op.
667   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
668                           valuePackName);
669   os.indent(4) << "{\n";
670 
671   // Now prepare operands used for building this op:
672   // * If the operand is non-variadic, we create a `Value*` local variable.
673   // * If the operand is variadic, we create a `SmallVector<Value*>` local
674   //   variable.
675 
676   int argIndex = 0;   // The current index to this op's ODS argument
677   int valueIndex = 0; // An index for uniquing local variable names.
678   for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
679     const auto &operand = resultOp.getOperand(argIndex);
680     std::string varName;
681     if (operand.isVariadic()) {
682       varName = formatv("tblgen_values_{0}", valueIndex++);
683       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
684       std::string range;
685       if (tree.isNestedDagArg(argIndex)) {
686         range = childNodeNames[argIndex];
687       } else {
688         range = tree.getArgName(argIndex);
689       }
690       // Resolve the symbol for all range use so that we have a uniform way of
691       // capturing the values.
692       range = symbolInfoMap.getValueAndRangeUse(range);
693       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
694                               varName);
695     } else {
696       varName = formatv("tblgen_value_{0}", valueIndex++);
697       os.indent(6) << formatv("Value *{0} = ", varName);
698       if (tree.isNestedDagArg(argIndex)) {
699         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
700       } else {
701         DagLeaf leaf = tree.getArgAsLeaf(argIndex);
702         auto symbol =
703             symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
704         if (leaf.isNativeCodeCall()) {
705           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
706         } else {
707           os << symbol;
708         }
709       }
710       os << ";\n";
711     }
712 
713     // Update to use the newly created local variable for building the op later.
714     childNodeNames[argIndex] = varName;
715   }
716 
717   // Then we create the builder call.
718 
719   // Right now we don't have general type inference in MLIR. Except a few
720   // special cases listed below, we need to supply types for all results
721   // when building an op.
722   bool isSameOperandsAndResultType =
723       resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
724   bool isBroadcastable =
725       resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
726   bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
727   bool usePartialResults = valuePackName != resultValue;
728 
729   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
730       usePartialResults || depth > 0 || resultIndex < 0) {
731     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
732                             resultOp.getQualCppClassName());
733   } else {
734     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
735     // generated from the source pattern root op. Then we can use the source
736     // pattern's value types to determine the value type of the generated op
737     // here.
738 
739     // We need to specify the types for all results.
740     int numResults = resultOp.getNumResults();
741     if (numResults != 0) {
742       os.indent(6) << "tblgen_types.clear();\n";
743       for (int i = 0; i < numResults; ++i) {
744         os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
745                                 "tblgen_types.push_back(v->getType());\n",
746                                 resultIndex + i);
747       }
748     }
749 
750     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
751                             resultOp.getQualCppClassName());
752     if (numResults != 0)
753       os.indent(6) << ", tblgen_types";
754   }
755 
756   // Add operands for the builder all.
757   for (int i = 0; i < argIndex; ++i) {
758     const auto &operand = resultOp.getOperand(i);
759     // Start each operand on its own line.
760     (os << ",\n").indent(8);
761     if (!operand.name.empty()) {
762       os << "/*" << operand.name << "=*/";
763     }
764     os << childNodeNames[i];
765     // TODO(jpienaar): verify types
766   }
767 
768   // Add attributes for the builder call.
769   for (; argIndex != numOpArgs; ++argIndex) {
770     // Start each attribute on its own line.
771     (os << ",\n").indent(8);
772     // The argument in the op definition.
773     auto opArgName = resultOp.getArgName(argIndex);
774     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
775       if (!subTree.isNativeCodeCall())
776         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
777                              "for creating attribute");
778       os << formatv("/*{0}=*/{1}", opArgName,
779                     handleReplaceWithNativeCodeCall(subTree));
780     } else {
781       auto leaf = tree.getArgAsLeaf(argIndex);
782       // The argument in the result DAG pattern.
783       auto patArgName = tree.getArgName(argIndex);
784       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
785         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
786         auto argument = resultOp.getArg(argIndex);
787         if (!argument.is<NamedAttribute *>())
788           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
789         if (!patArgName.empty())
790           os << "/*" << patArgName << "=*/";
791       } else {
792         os << "/*" << opArgName << "=*/";
793       }
794       os << handleOpArgument(leaf, patArgName);
795     }
796   }
797   os << "\n      );\n";
798   os.indent(4) << "}\n";
799 
800   return resultValue;
801 }
802 
803 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
804   emitSourceFileHeader("Rewriters", os);
805 
806   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
807   auto numPatterns = patterns.size();
808 
809   // We put the map here because it can be shared among multiple patterns.
810   RecordOperatorMap recordOpMap;
811 
812   std::vector<std::string> rewriterNames;
813   rewriterNames.reserve(numPatterns);
814 
815   std::string baseRewriterName = "GeneratedConvert";
816   int rewriterIndex = 0;
817 
818   for (Record *p : patterns) {
819     std::string name;
820     if (p->isAnonymous()) {
821       // If no name is provided, ensure unique rewriter names simply by
822       // appending unique suffix.
823       name = baseRewriterName + llvm::utostr(rewriterIndex++);
824     } else {
825       name = p->getName();
826     }
827     PatternEmitter(p, &recordOpMap, os).emit(name);
828     rewriterNames.push_back(std::move(name));
829   }
830 
831   // Emit function to add the generated matchers to the pattern list.
832   os << "void populateWithGenerated(MLIRContext *context, "
833      << "OwningRewritePatternList *patterns) {\n";
834   for (const auto &name : rewriterNames) {
835     os << "  patterns->insert<" << name << ">(context);\n";
836   }
837   os << "}\n";
838 }
839 
840 static mlir::GenRegistration
841     genRewriters("gen-rewriters", "Generate pattern rewriters",
842                  [](const RecordKeeper &records, raw_ostream &os) {
843                    emitRewriters(records, os);
844                    return false;
845                  });
846