xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision e62a69561fb9d7b1013d2853da68d79a7907fead)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/STLExtras.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
44   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
45                      raw_ostream &os, StringRef style) {
46     os << v.first << ":" << v.second;
47   }
48 };
49 } // end namespace llvm
50 
51 //===----------------------------------------------------------------------===//
52 // PatternEmitter
53 //===----------------------------------------------------------------------===//
54 
namespace {
// Generates the C++ source of one mlir::RewritePattern subclass from a single
// DRR `Pattern` TableGen record. emit() writes the struct, its constructor and
// its matchAndRewrite() override to the output stream.
class PatternEmitter {
public:
  // `pat` is the pattern record to generate code for, `mapper` resolves op
  // records to Operator wrappers, and generated code is written to `os`.
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops: the op tree rooted at `tree`, followed
  // by the pattern's additional (multi-entity) constraints.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops: creating the result patterns and then
  // replacing or erasing the matched root op.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern; the
  // root (depth 0) is matched against the generated `op0`.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `argIndex` indexes the op's full argument list
  // (operands and attributes together); `indent` is the column each emitted
  // line starts at.
  void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute. `indent` is the column each emitted line
  // starts at.
  void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression obtained by substituting the DAG's arguments
  // into its NativeCodeCall template. Nothing is written to the output stream
  // here; callers decide whether and where the expression is materialized.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Maps an operand index of an op being created to the symbol naming the
  // nested DAG child node that supplies that operand.
  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the name the argument is bound to in the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`. Each call
  // consumes one ID from `nextValueId`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  // Non-op nodes (e.g. NativeCodeCall) are counted as a single value.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values; consumed by
  // getUniqueSymbol().
  unsigned nextValueId;

  // Stream the generated C++ code is written to.
  raw_ostream &os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;

  // Number of nested source ops matched so far; also the last slot of the
  // generated `tblgen_ops` array that has been assigned.
  int opCounter = 0;
};
} // end anonymous namespace
186 
187 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
188                                raw_ostream &os)
189     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
190       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
191   fmtCtx.withBuilder("rewriter");
192 }
193 
194 std::string PatternEmitter::handleConstantAttr(Attribute attr,
195                                                StringRef value) {
196   if (!attr.isConstBuildable())
197     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
198                              " does not have the 'constBuilderCall' field");
199 
200   // TODO(jpienaar): Verify the constants here
201   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
202 }
203 
204 // Helper function to match patterns.
205 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
206   Operator &op = tree.getDialectOp(opMap);
207   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
208                           << op.getOperationName() << "' at depth " << depth
209                           << '\n');
210 
211   int indent = 4 + 2 * depth;
212   os.indent(indent) << formatv(
213       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
214       depth, op.getQualCppClassName());
215   // Skip the operand matching at depth 0 as the pattern rewriter already does.
216   if (depth != 0) {
217     // Skip if there is no defining operation (e.g., arguments to function).
218     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
219                                  depth);
220   }
221   if (tree.getNumArgs() != op.getNumArgs()) {
222     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
223                                  "pattern vs. {2} in definition",
224                                  op.getOperationName(), tree.getNumArgs(),
225                                  op.getNumArgs()));
226   }
227 
228   // If the operand's name is set, set to that variable.
229   auto name = tree.getSymbol();
230   if (!name.empty())
231     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
232 
233   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
234     auto opArg = op.getArg(i);
235 
236     // Handle nested DAG construct first
237     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
238       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
239         if (operand->isVariadic()) {
240           auto error = formatv("use nested DAG construct to match op {0}'s "
241                                "variadic operand #{1} unsupported now",
242                                op.getOperationName(), i);
243           PrintFatalError(loc, error);
244         }
245       }
246       os.indent(indent) << "{\n";
247 
248       os.indent(indent + 2) << formatv(
249           "auto *op{0} = "
250           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
251           depth + 1, depth, i);
252       emitOpMatch(argTree, depth + 1);
253       os.indent(indent + 2)
254           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
255       os.indent(indent) << "}\n";
256       continue;
257     }
258 
259     // Next handle DAG leaf: operand or attribute
260     if (opArg.is<NamedTypeConstraint *>()) {
261       emitOperandMatch(tree, i, depth, indent);
262     } else if (opArg.is<NamedAttribute *>()) {
263       emitAttributeMatch(tree, i, depth, indent);
264     } else {
265       PrintFatalError(loc, "unhandled case when matching op");
266     }
267   }
268   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
269                           << op.getOperationName() << "' at depth " << depth
270                           << '\n');
271 }
272 
273 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
274                                       int indent) {
275   Operator &op = tree.getDialectOp(opMap);
276   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
277   auto matcher = tree.getArgAsLeaf(argIndex);
278 
279   // If a constraint is specified, we need to generate C++ statements to
280   // check the constraint.
281   if (!matcher.isUnspecified()) {
282     if (!matcher.isOperandMatcher()) {
283       PrintFatalError(
284           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
285                        op.getOperationName(), argIndex + 1));
286     }
287 
288     // Only need to verify if the matcher's type is different from the one
289     // of op definition.
290     if (operand->constraint != matcher.getAsConstraint()) {
291       if (operand->isVariadic()) {
292         auto error = formatv(
293             "further constrain op {0}'s variadic operand #{1} unsupported now",
294             op.getOperationName(), argIndex);
295         PrintFatalError(loc, error);
296       }
297       auto self =
298           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
299                   depth, argIndex);
300       os.indent(indent) << "if (!("
301                         << tgfmt(matcher.getConditionTemplate(),
302                                  &fmtCtx.withSelf(self))
303                         << ")) return matchFailure();\n";
304     }
305   }
306 
307   // Capture the value
308   auto name = tree.getArgName(argIndex);
309   // `$_` is a special symbol to ignore op argument matching.
310   if (!name.empty() && name != "_") {
311     // We need to subtract the number of attributes before this operand to get
312     // the index in the operand list.
313     auto numPrevAttrs = std::count_if(
314         op.arg_begin(), op.arg_begin() + argIndex,
315         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
316 
317     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
318                                  name, depth, argIndex - numPrevAttrs);
319   }
320 }
321 
322 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
323                                         int indent) {
324 
325   Operator &op = tree.getDialectOp(opMap);
326   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
327   const auto &attr = namedAttr->attr;
328 
329   os.indent(indent) << "{\n";
330   indent += 2;
331   os.indent(indent) << formatv(
332       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
333       attr.getStorageType(), namedAttr->name);
334 
335   // TODO(antiagainst): This should use getter method to avoid duplication.
336   if (attr.hasDefaultValue()) {
337     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
338                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
339                                attr.getDefaultValue())
340                       << ";\n";
341   } else if (attr.isOptional()) {
342     // For a missing attribute that is optional according to definition, we
343     // should just capture a mlir::Attribute() to signal the missing state.
344     // That is precisely what getAttr() returns on missing attributes.
345   } else {
346     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
347   }
348 
349   auto matcher = tree.getArgAsLeaf(argIndex);
350   if (!matcher.isUnspecified()) {
351     if (!matcher.isAttrMatcher()) {
352       PrintFatalError(
353           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
354                        op.getOperationName(), argIndex + 1));
355     }
356 
357     // If a constraint is specified, we need to generate C++ statements to
358     // check the constraint.
359     os.indent(indent) << "if (!("
360                       << tgfmt(matcher.getConditionTemplate(),
361                                &fmtCtx.withSelf("tblgen_attr"))
362                       << ")) return matchFailure();\n";
363   }
364 
365   // Capture the value
366   auto name = tree.getArgName(argIndex);
367   // `$_` is a special symbol to ignore op argument matching.
368   if (!name.empty() && name != "_") {
369     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
370   }
371 
372   indent -= 2;
373   os.indent(indent) << "}\n";
374 }
375 
376 void PatternEmitter::emitMatchLogic(DagNode tree) {
377   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
378   emitOpMatch(tree, 0);
379 
380   for (auto &appliedConstraint : pattern.getConstraints()) {
381     auto &constraint = appliedConstraint.constraint;
382     auto &entities = appliedConstraint.entities;
383 
384     auto condition = constraint.getConditionTemplate();
385     auto cmd = "if (!({0})) return matchFailure();\n";
386 
387     if (isa<TypeConstraint>(constraint)) {
388       auto self = formatv("({0}->getType())",
389                           symbolInfoMap.getValueAndRangeUse(entities.front()));
390       os.indent(4) << formatv(cmd,
391                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
392     } else if (isa<AttrConstraint>(constraint)) {
393       PrintFatalError(
394           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
395     } else {
396       // TODO(b/138794486): replace formatv arguments with the exact specified
397       // args.
398       if (entities.size() > 4) {
399         PrintFatalError(loc, "only support up to 4-entity constraints now");
400       }
401       SmallVector<std::string, 4> names;
402       int i = 0;
403       for (int e = entities.size(); i < e; ++i)
404         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
405       std::string self = appliedConstraint.self;
406       if (!self.empty())
407         self = symbolInfoMap.getValueAndRangeUse(self);
408       for (; i < 4; ++i)
409         names.push_back("<unused>");
410       os.indent(4) << formatv(cmd,
411                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
412                                     names[1], names[2], names[3]));
413     }
414   }
415   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
416 }
417 
418 void PatternEmitter::collectOps(DagNode tree,
419                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
420   // Check if this tree is an operation.
421   if (tree.isOperation()) {
422     const Operator &op = tree.getDialectOp(opMap);
423     LLVM_DEBUG(llvm::dbgs()
424                << "found operation " << op.getOperationName() << '\n');
425     ops.insert(&op);
426   }
427 
428   // Recurse the arguments of the tree.
429   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
430     if (auto child = tree.getArgAsNestedDag(i))
431       collectOps(child, ops);
432 }
433 
434 void PatternEmitter::emit(StringRef rewriteName) {
435   // Get the DAG tree for the source pattern.
436   DagNode sourceTree = pattern.getSourcePattern();
437 
438   const Operator &rootOp = pattern.getSourceRootOp();
439   auto rootName = rootOp.getOperationName();
440 
441   // Collect the set of result operations.
442   llvm::SmallPtrSet<const Operator *, 4> resultOps;
443   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
444   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
445     collectOps(pattern.getResultPattern(i), resultOps);
446   }
447   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
448 
449   // Emit RewritePattern for Pattern.
450   auto locs = pattern.getLocation();
451   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
452                 make_range(locs.rbegin(), locs.rend()));
453   os << formatv(R"(struct {0} : public RewritePattern {
454   {0}(MLIRContext *context)
455       : RewritePattern("{1}", {{)",
456                 rewriteName, rootName);
457   // Sort result operators by name.
458   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
459                                                          resultOps.end());
460   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
461     return lhs->getOperationName() < rhs->getOperationName();
462   });
463   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
464     os << '"' << op->getOperationName() << '"';
465   });
466   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
467 
468   // Emit matchAndRewrite() function.
469   os << R"(
470   PatternMatchResult matchAndRewrite(Operation *op0,
471                                      PatternRewriter &rewriter) const override {
472 )";
473 
474   // Register all symbols bound in the source pattern.
475   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
476 
477   LLVM_DEBUG(
478       llvm::dbgs() << "start creating local variables for capturing matches\n");
479   os.indent(4) << "// Variables for capturing values and attributes used for "
480                   "creating ops\n";
481   // Create local variables for storing the arguments and results bound
482   // to symbols.
483   for (const auto &symbolInfoPair : symbolInfoMap) {
484     StringRef symbol = symbolInfoPair.getKey();
485     auto &info = symbolInfoPair.getValue();
486     os.indent(4) << info.getVarDecl(symbol);
487   }
488   // TODO(jpienaar): capture ops with consistent numbering so that it can be
489   // reused for fused loc.
490   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
491                           pattern.getSourcePattern().getNumOps());
492   LLVM_DEBUG(
493       llvm::dbgs() << "done creating local variables for capturing matches\n");
494 
495   os.indent(4) << "// Match\n";
496   os.indent(4) << "tblgen_ops[0] = op0;\n";
497   emitMatchLogic(sourceTree);
498   os << "\n";
499 
500   os.indent(4) << "// Rewrite\n";
501   emitRewriteLogic();
502 
503   os.indent(4) << "return matchSuccess();\n";
504   os << "  };\n";
505   os << "};\n";
506 }
507 
// Emits the rewrite half of the generated matchAndRewrite(). In order, it:
//  1. registers symbols bound by the result patterns;
//  2. splits the result patterns into auxiliary ones and ones whose values
//     replace the matched root op's results;
//  3. emits a fused `loc` built from every matched op;
//  4. emits the auxiliary patterns;
//  5. erases the root op (when it has no results) or replaces it with the
//     collected replacement values.
// Aborts with a fatal TableGen error when the result patterns cannot be split
// consistently.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  //
  // offsets[i] is that starting index for result pattern #i. It is computed
  // backwards from the sentinel offsets[numResultPatterns] ==
  // numExpectedResults. A negative offset means the pattern only yields
  // auxiliary values.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      // Keep the first (i.e. highest-index) pattern that reaches offset 0.
      // Earlier zero-value patterns also sit at offset 0 but are treated as
      // auxiliary.
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // This pattern's values straddle the auxiliary/replacement boundary.
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  // A positive offset at the front means the result patterns together produce
  // fewer values than the root op has results.
  if (offsets.front() > 0) {
    const char error[] = "no enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  // Fuse the locations of all ops recorded in `tblgen_ops` during matching.
  os.indent(4) << "auto loc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)loc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall())
      os.indent(4) << val << ";\n";
  }

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    // No result to replace. Just erase the op.
    os.indent(4) << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os.indent(4) << "SmallVector<Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os.indent(4) << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      os << symbolInfoMap.getAllRangeUse(
          val, "    for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }",
          "\n");
    }
    os.indent(4) << "\n";
    os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
586 
587 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
588   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
589 }
590 
591 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
592                                                 int resultIndex, int depth) {
593   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
594   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
595   LLVM_DEBUG(llvm::dbgs() << '\n');
596 
597   if (resultTree.isNativeCodeCall()) {
598     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
599     symbolInfoMap.bindValue(symbol);
600     return symbol;
601   }
602 
603   if (resultTree.isReplaceWithValue()) {
604     return handleReplaceWithValue(resultTree);
605   }
606 
607   // Normal op creation.
608   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
609   if (resultTree.getSymbol().empty()) {
610     // This is an op not explicitly bound to a symbol in the rewrite rule.
611     // Register the auto-generated symbol for it.
612     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
613   }
614   return symbol;
615 }
616 
617 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
618   assert(tree.isReplaceWithValue());
619 
620   if (tree.getNumArgs() != 1) {
621     PrintFatalError(
622         loc, "replaceWithValue directive must take exactly one argument");
623   }
624 
625   if (!tree.getSymbol().empty()) {
626     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
627   }
628 
629   return tree.getArgName(0);
630 }
631 
632 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
633                                              StringRef patArgName) {
634   if (leaf.isConstantAttr()) {
635     auto constAttr = leaf.getAsConstantAttr();
636     return handleConstantAttr(constAttr.getAttribute(),
637                               constAttr.getConstantValue());
638   }
639   if (leaf.isEnumAttrCase()) {
640     auto enumCase = leaf.getAsEnumAttrCase();
641     if (enumCase.isStrCase())
642       return handleConstantAttr(enumCase, enumCase.getSymbol());
643     // This is an enum case backed by an IntegerAttr. We need to get its value
644     // to build the constant.
645     std::string val = std::to_string(enumCase.getValue());
646     return handleConstantAttr(enumCase, val);
647   }
648 
649   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
650   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
651   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
652     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
653                             << "' (via symbol ref)\n");
654     return argName;
655   }
656   if (leaf.isNativeCodeCall()) {
657     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
658     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
659                             << "' (via NativeCodeCall)\n");
660     return repl;
661   }
662   PrintFatalError(loc, "unhandled case when rewriting op");
663 }
664 
665 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
666   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
667   LLVM_DEBUG(tree.print(llvm::dbgs()));
668   LLVM_DEBUG(llvm::dbgs() << '\n');
669 
670   auto fmt = tree.getNativeCodeTemplate();
671   // TODO(b/138794486): replace formatv arguments with the exact specified args.
672   SmallVector<std::string, 8> attrs(8);
673   if (tree.getNumArgs() > 8) {
674     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
675                              Twine(tree.getNumArgs()));
676   }
677   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
678     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
679     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
680                             << " replacement: " << attrs[i] << "\n");
681   }
682   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
683                attrs[5], attrs[6], attrs[7]);
684 }
685 
686 int PatternEmitter::getNodeValueCount(DagNode node) {
687   if (node.isOperation()) {
688     // If the op is bound to a symbol in the rewrite rule, query its result
689     // count from the symbol info map.
690     auto symbol = node.getSymbol();
691     if (!symbol.empty()) {
692       return symbolInfoMap.getStaticValueCount(symbol);
693     }
694     // Otherwise this is an unbound op; we will use all its results.
695     return pattern.getDialectOp(node).getNumResults();
696   }
697   // TODO(antiagainst): This considers all NativeCodeCall as returning one
698   // value. Enhance if multi-value ones are needed.
699   return 1;
700 }
701 
702 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
703                                              int depth) {
704   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
705   LLVM_DEBUG(tree.print(llvm::dbgs()));
706   LLVM_DEBUG(llvm::dbgs() << '\n');
707 
708   Operator &resultOp = tree.getDialectOp(opMap);
709   auto numOpArgs = resultOp.getNumArgs();
710 
711   if (numOpArgs != tree.getNumArgs()) {
712     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
713                                  "{1} in pattern vs. {2} in definition",
714                                  resultOp.getOperationName(), tree.getNumArgs(),
715                                  numOpArgs));
716   }
717 
718   // A map to collect all nested DAG child nodes' names, with operand index as
719   // the key. This includes both bound and unbound child nodes.
720   ChildNodeIndexNameMap childNodeNames;
721 
722   // First go through all the child nodes who are nested DAG constructs to
723   // create ops for them and remember the symbol names for them, so that we can
724   // use the results in the current node. This happens in a recursive manner.
725   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
726     if (auto child = tree.getArgAsNestedDag(i)) {
727       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
728     }
729   }
730 
731   // The name of the local variable holding this op.
732   std::string valuePackName;
733   // The symbol for holding the result of this pattern. Note that the result of
734   // this pattern is not necessarily the same as the variable created by this
735   // pattern because we can use `__N` suffix to refer only a specific result if
736   // the generated op is a multi-result op.
737   std::string resultValue;
738   if (tree.getSymbol().empty()) {
739     // No symbol is explicitly bound to this op in the pattern. Generate a
740     // unique name.
741     valuePackName = resultValue = getUniqueSymbol(&resultOp);
742   } else {
743     resultValue = tree.getSymbol();
744     // Strip the index to get the name for the value pack and use it to name the
745     // local variable for the op.
746     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
747   }
748 
749   // Create the local variable for this op.
750   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
751                           valuePackName);
752   os.indent(4) << "{\n";
753 
754   // Right now ODS don't have general type inference support. Except a few
755   // special cases listed below, DRR needs to supply types for all results
756   // when building an op.
757   bool isSameOperandsAndResultType =
758       resultOp.getTrait("OpTrait::SameOperandsAndResultType");
759   bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType");
760 
761   if (isSameOperandsAndResultType || useFirstAttr) {
762     // We know how to deduce the result type for ops with these traits and we've
763     // generated builders taking aggregate parameters. Use those builders to
764     // create the ops.
765 
766     // First prepare local variables for op arguments used in builder call.
767     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
768     // Then create the op.
769     os.indent(6) << formatv(
770         "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
771         valuePackName, resultOp.getQualCppClassName());
772     os.indent(4) << "}\n";
773     return resultValue;
774   }
775 
776   bool isBroadcastable =
777       resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
778   bool usePartialResults = valuePackName != resultValue;
779 
780   if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
781     // For these cases (broadcastable ops, op results used both as auxiliary
782     // values and replacement values, ops in nested patterns, auxiliary ops), we
783     // still need to supply the result types when building the op. But because
784     // we don't generate a builder automatically with ODS for them, it's the
785     // developer's responsiblity to make sure such a builder (with result type
786     // deduction ability) exists. We go through the separate-parameter builder
787     // here given that it's easier for developers to write compared to
788     // aggregate-parameter builders.
789     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
790     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
791                             resultOp.getQualCppClassName());
792     supplyValuesForOpArgs(tree, childNodeNames);
793     os << "\n      );\n";
794     os.indent(4) << "}\n";
795     return resultValue;
796   }
797 
798   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
799   // generated from the source pattern root op. Then we can use the source
800   // pattern's value types to determine the value type of the generated op
801   // here.
802 
803   // First prepare local variables for op arguments used in builder call.
804   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
805 
806   // Then prepare the result types. We need to specify the types for all
807   // results.
808   os.indent(6) << formatv(
809       "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
810   int numResults = resultOp.getNumResults();
811   if (numResults != 0) {
812     for (int i = 0; i < numResults; ++i)
813       os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
814                               "tblgen_types.push_back(v->getType()); }\n",
815                               resultIndex + i);
816   }
817   os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
818                           "tblgen_values, tblgen_attrs);\n",
819                           valuePackName, resultOp.getQualCppClassName());
820   os.indent(4) << "}\n";
821   return resultValue;
822 }
823 
824 void PatternEmitter::createSeparateLocalVarsForOpArgs(
825     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
826   Operator &resultOp = node.getDialectOp(opMap);
827 
828   // Now prepare operands used for building this op:
829   // * If the operand is non-variadic, we create a `Value` local variable.
830   // * If the operand is variadic, we create a `SmallVector<Value>` local
831   //   variable.
832 
833   int valueIndex = 0; // An index for uniquing local variable names.
834   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
835     const auto *operand =
836         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
837     if (!operand) {
838       // We do not need special handling for attributes.
839       continue;
840     }
841 
842     std::string varName;
843     if (operand->isVariadic()) {
844       varName = formatv("tblgen_values_{0}", valueIndex++);
845       os.indent(6) << formatv("SmallVector<Value, 4> {0};\n", varName);
846       std::string range;
847       if (node.isNestedDagArg(argIndex)) {
848         range = childNodeNames[argIndex];
849       } else {
850         range = node.getArgName(argIndex);
851       }
852       // Resolve the symbol for all range use so that we have a uniform way of
853       // capturing the values.
854       range = symbolInfoMap.getValueAndRangeUse(range);
855       os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
856                               varName);
857     } else {
858       varName = formatv("tblgen_value_{0}", valueIndex++);
859       os.indent(6) << formatv("Value {0} = ", varName);
860       if (node.isNestedDagArg(argIndex)) {
861         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
862       } else {
863         DagLeaf leaf = node.getArgAsLeaf(argIndex);
864         auto symbol =
865             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
866         if (leaf.isNativeCodeCall()) {
867           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
868         } else {
869           os << symbol;
870         }
871       }
872       os << ";\n";
873     }
874 
875     // Update to use the newly created local variable for building the op later.
876     childNodeNames[argIndex] = varName;
877   }
878 }
879 
880 void PatternEmitter::supplyValuesForOpArgs(
881     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
882   Operator &resultOp = node.getDialectOp(opMap);
883   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
884        argIndex != numOpArgs; ++argIndex) {
885     // Start each argument on its own line.
886     (os << ",\n").indent(8);
887 
888     Argument opArg = resultOp.getArg(argIndex);
889     // Handle the case of operand first.
890     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
891       if (!operand->name.empty())
892         os << "/*" << operand->name << "=*/";
893       os << childNodeNames.lookup(argIndex);
894       continue;
895     }
896 
897     // The argument in the op definition.
898     auto opArgName = resultOp.getArgName(argIndex);
899     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
900       if (!subTree.isNativeCodeCall())
901         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
902                              "for creating attribute");
903       os << formatv("/*{0}=*/{1}", opArgName,
904                     handleReplaceWithNativeCodeCall(subTree));
905     } else {
906       auto leaf = node.getArgAsLeaf(argIndex);
907       // The argument in the result DAG pattern.
908       auto patArgName = node.getArgName(argIndex);
909       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
910         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
911         if (!opArg.is<NamedAttribute *>())
912           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
913         if (!patArgName.empty())
914           os << "/*" << patArgName << "=*/";
915       } else {
916         os << "/*" << opArgName << "=*/";
917       }
918       os << handleOpArgument(leaf, patArgName);
919     }
920   }
921 }
922 
923 void PatternEmitter::createAggregateLocalVarsForOpArgs(
924     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
925   Operator &resultOp = node.getDialectOp(opMap);
926 
927   os.indent(6) << formatv(
928       "SmallVector<Value, 4> tblgen_values; (void)tblgen_values;\n");
929   os.indent(6) << formatv(
930       "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
931 
932   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
933     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
934       const char *addAttrCmd = "if ({1}) {{"
935                                "  tblgen_attrs.emplace_back(rewriter."
936                                "getIdentifier(\"{0}\"), {1}); }\n";
937       // The argument in the op definition.
938       auto opArgName = resultOp.getArgName(argIndex);
939       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
940         if (!subTree.isNativeCodeCall())
941           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
942                                "for creating attribute");
943         os.indent(6) << formatv(addAttrCmd, opArgName,
944                                 handleReplaceWithNativeCodeCall(subTree));
945       } else {
946         auto leaf = node.getArgAsLeaf(argIndex);
947         // The argument in the result DAG pattern.
948         auto patArgName = node.getArgName(argIndex);
949         os.indent(6) << formatv(addAttrCmd, opArgName,
950                                 handleOpArgument(leaf, patArgName));
951       }
952       continue;
953     }
954 
955     const auto *operand =
956         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
957     std::string varName;
958     if (operand->isVariadic()) {
959       std::string range;
960       if (node.isNestedDagArg(argIndex)) {
961         range = childNodeNames.lookup(argIndex);
962       } else {
963         range = node.getArgName(argIndex);
964       }
965       // Resolve the symbol for all range use so that we have a uniform way of
966       // capturing the values.
967       range = symbolInfoMap.getValueAndRangeUse(range);
968       os.indent(6) << formatv(
969           "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
970     } else {
971       os.indent(6) << formatv("tblgen_values.push_back(", varName);
972       if (node.isNestedDagArg(argIndex)) {
973         os << symbolInfoMap.getValueAndRangeUse(
974             childNodeNames.lookup(argIndex));
975       } else {
976         DagLeaf leaf = node.getArgAsLeaf(argIndex);
977         auto symbol =
978             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
979         if (leaf.isNativeCodeCall()) {
980           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
981         } else {
982           os << symbol;
983         }
984       }
985       os << ");\n";
986     }
987   }
988 }
989 
990 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
991   emitSourceFileHeader("Rewriters", os);
992 
993   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
994   auto numPatterns = patterns.size();
995 
996   // We put the map here because it can be shared among multiple patterns.
997   RecordOperatorMap recordOpMap;
998 
999   std::vector<std::string> rewriterNames;
1000   rewriterNames.reserve(numPatterns);
1001 
1002   std::string baseRewriterName = "GeneratedConvert";
1003   int rewriterIndex = 0;
1004 
1005   for (Record *p : patterns) {
1006     std::string name;
1007     if (p->isAnonymous()) {
1008       // If no name is provided, ensure unique rewriter names simply by
1009       // appending unique suffix.
1010       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1011     } else {
1012       name = p->getName();
1013     }
1014     LLVM_DEBUG(llvm::dbgs()
1015                << "=== start generating pattern '" << name << "' ===\n");
1016     PatternEmitter(p, &recordOpMap, os).emit(name);
1017     LLVM_DEBUG(llvm::dbgs()
1018                << "=== done generating pattern '" << name << "' ===\n");
1019     rewriterNames.push_back(std::move(name));
1020   }
1021 
1022   // Emit function to add the generated matchers to the pattern list.
1023   os << "void populateWithGenerated(MLIRContext *context, "
1024      << "OwningRewritePatternList *patterns) {\n";
1025   for (const auto &name : rewriterNames) {
1026     os << "  patterns->insert<" << name << ">(context);\n";
1027   }
1028   os << "}\n";
1029 }
1030 
1031 static mlir::GenRegistration
1032     genRewriters("gen-rewriters", "Generate pattern rewriters",
1033                  [](const RecordKeeper &records, raw_ostream &os) {
1034                    emitRewriters(records, os);
1035                    return false;
1036                  });
1037