xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 84a6182ddd62a2ca8eee2d8470e3be1ef6147fce)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 using llvm::formatv;
46 using llvm::Record;
47 using llvm::RecordKeeper;
48 
49 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
50 
51 namespace llvm {
52 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
53   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
54                      raw_ostream &os, StringRef style) {
55     os << v.first << ":" << v.second;
56   }
57 };
58 } // end namespace llvm
59 
60 //===----------------------------------------------------------------------===//
61 // PatternEmitter
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 class PatternEmitter {
66 public:
67   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
68 
69   // Emits the mlir::RewritePattern struct named `rewriteName`.
70   void emit(StringRef rewriteName);
71 
72 private:
73   // Emits the code for matching ops.
74   void emitMatchLogic(DagNode tree);
75 
76   // Emits the code for rewriting ops.
77   void emitRewriteLogic();
78 
79   //===--------------------------------------------------------------------===//
80   // Match utilities
81   //===--------------------------------------------------------------------===//
82 
83   // Emits C++ statements for matching the op constrained by the given DAG
84   // `tree`.
85   void emitOpMatch(DagNode tree, int depth);
86 
87   // Emits C++ statements for matching the `argIndex`-th argument of the given
88   // DAG `tree` as an operand.
89   void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
90 
91   // Emits C++ statements for matching the `argIndex`-th argument of the given
92   // DAG `tree` as an attribute.
93   void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
94 
95   //===--------------------------------------------------------------------===//
96   // Rewrite utilities
97   //===--------------------------------------------------------------------===//
98 
99   // The entry point for handling a result pattern rooted at `resultTree`. This
100   // method dispatches to concrete handlers according to `resultTree`'s kind and
101   // returns a symbol representing the whole value pack. Callers are expected to
102   // further resolve the symbol according to the specific use case.
103   //
104   // `depth` is the nesting level of `resultTree`; 0 means top-level result
105   // pattern. For top-level result pattern, `resultIndex` indicates which result
106   // of the matched root op this pattern is intended to replace, which can be
107   // used to deduce the result type of the op generated from this result
108   // pattern.
109   std::string handleResultPattern(DagNode resultTree, int resultIndex,
110                                   int depth);
111 
112   // Emits the C++ statement to replace the matched DAG with a value built via
113   // calling native C++ code.
114   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
115 
116   // Returns the C++ expression referencing the old value serving as the
117   // replacement.
118   std::string handleReplaceWithValue(DagNode tree);
119 
120   // Emits the C++ statement to build a new op out of the given DAG `tree` and
121   // returns the variable name that this op is assigned to. If the root op in
122   // DAG `tree` has a specified name, the created op will be assigned to a
123   // variable of the given name. Otherwise, a unique name will be used as the
124   // result value name.
125   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
126 
127   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
128 
129   // Emits a local variable for each value and attribute to be used for creating
130   // an op.
131   void createSeparateLocalVarsForOpArgs(DagNode node,
132                                         ChildNodeIndexNameMap &childNodeNames);
133 
134   // Emits the concrete arguments used to call a op's builder.
135   void supplyValuesForOpArgs(DagNode node,
136                              const ChildNodeIndexNameMap &childNodeNames);
137 
138   // Emits the local variables for holding all values as a whole and all named
139   // attributes as a whole to be used for creating an op.
140   void createAggregateLocalVarsForOpArgs(
141       DagNode node, const ChildNodeIndexNameMap &childNodeNames);
142 
143   // Returns the C++ expression to construct a constant attribute of the given
144   // `value` for the given attribute kind `attr`.
145   std::string handleConstantAttr(Attribute attr, StringRef value);
146 
147   // Returns the C++ expression to build an argument from the given DAG `leaf`.
148   // `patArgName` is used to bound the argument to the source pattern.
149   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
150 
151   //===--------------------------------------------------------------------===//
152   // General utilities
153   //===--------------------------------------------------------------------===//
154 
155   // Collects all of the operations within the given dag tree.
156   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
157 
158   // Returns a unique symbol for a local variable of the given `op`.
159   std::string getUniqueSymbol(const Operator *op);
160 
161   //===--------------------------------------------------------------------===//
162   // Symbol utilities
163   //===--------------------------------------------------------------------===//
164 
165   // Returns how many static values the given DAG `node` correspond to.
166   int getNodeValueCount(DagNode node);
167 
168 private:
169   // Pattern instantiation location followed by the location of multiclass
170   // prototypes used. This is intended to be used as a whole to
171   // PrintFatalError() on errors.
172   ArrayRef<llvm::SMLoc> loc;
173 
174   // Op's TableGen Record to wrapper object.
175   RecordOperatorMap *opMap;
176 
177   // Handy wrapper for pattern being emitted.
178   Pattern pattern;
179 
180   // Map for all bound symbols' info.
181   SymbolInfoMap symbolInfoMap;
182 
183   // The next unused ID for newly created values.
184   unsigned nextValueId;
185 
186   raw_ostream &os;
187 
188   // Format contexts containing placeholder substitutions.
189   FmtContext fmtCtx;
190 
191   // Number of op processed.
192   int opCounter = 0;
193 };
194 } // end anonymous namespace
195 
196 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
197                                raw_ostream &os)
198     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
199       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
200   fmtCtx.withBuilder("rewriter");
201 }
202 
203 std::string PatternEmitter::handleConstantAttr(Attribute attr,
204                                                StringRef value) {
205   if (!attr.isConstBuildable())
206     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
207                              " does not have the 'constBuilderCall' field");
208 
209   // TODO(jpienaar): Verify the constants here
210   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
211 }
212 
213 // Helper function to match patterns.
214 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
215   Operator &op = tree.getDialectOp(opMap);
216   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
217                           << op.getOperationName() << "' at depth " << depth
218                           << '\n');
219 
220   int indent = 4 + 2 * depth;
221   os.indent(indent) << formatv(
222       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
223       depth, op.getQualCppClassName());
224   // Skip the operand matching at depth 0 as the pattern rewriter already does.
225   if (depth != 0) {
226     // Skip if there is no defining operation (e.g., arguments to function).
227     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
228                                  depth);
229   }
230   if (tree.getNumArgs() != op.getNumArgs()) {
231     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
232                                  "pattern vs. {2} in definition",
233                                  op.getOperationName(), tree.getNumArgs(),
234                                  op.getNumArgs()));
235   }
236 
237   // If the operand's name is set, set to that variable.
238   auto name = tree.getSymbol();
239   if (!name.empty())
240     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
241 
242   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
243     auto opArg = op.getArg(i);
244 
245     // Handle nested DAG construct first
246     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
247       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
248         if (operand->isVariadic()) {
249           auto error = formatv("use nested DAG construct to match op {0}'s "
250                                "variadic operand #{1} unsupported now",
251                                op.getOperationName(), i);
252           PrintFatalError(loc, error);
253         }
254       }
255       os.indent(indent) << "{\n";
256 
257       os.indent(indent + 2) << formatv(
258           "auto *op{0} = "
259           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
260           depth + 1, depth, i);
261       emitOpMatch(argTree, depth + 1);
262       os.indent(indent + 2)
263           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
264       os.indent(indent) << "}\n";
265       continue;
266     }
267 
268     // Next handle DAG leaf: operand or attribute
269     if (opArg.is<NamedTypeConstraint *>()) {
270       emitOperandMatch(tree, i, depth, indent);
271     } else if (opArg.is<NamedAttribute *>()) {
272       emitAttributeMatch(tree, i, depth, indent);
273     } else {
274       PrintFatalError(loc, "unhandled case when matching op");
275     }
276   }
277   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
278                           << op.getOperationName() << "' at depth " << depth
279                           << '\n');
280 }
281 
282 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
283                                       int indent) {
284   Operator &op = tree.getDialectOp(opMap);
285   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
286   auto matcher = tree.getArgAsLeaf(argIndex);
287 
288   // If a constraint is specified, we need to generate C++ statements to
289   // check the constraint.
290   if (!matcher.isUnspecified()) {
291     if (!matcher.isOperandMatcher()) {
292       PrintFatalError(
293           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
294                        op.getOperationName(), argIndex + 1));
295     }
296 
297     // Only need to verify if the matcher's type is different from the one
298     // of op definition.
299     if (operand->constraint != matcher.getAsConstraint()) {
300       if (operand->isVariadic()) {
301         auto error = formatv(
302             "further constrain op {0}'s variadic operand #{1} unsupported now",
303             op.getOperationName(), argIndex);
304         PrintFatalError(loc, error);
305       }
306       auto self =
307           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
308                   depth, argIndex);
309       os.indent(indent) << "if (!("
310                         << tgfmt(matcher.getConditionTemplate(),
311                                  &fmtCtx.withSelf(self))
312                         << ")) return matchFailure();\n";
313     }
314   }
315 
316   // Capture the value
317   auto name = tree.getArgName(argIndex);
318   // `$_` is a special symbol to ignore op argument matching.
319   if (!name.empty() && name != "_") {
320     // We need to subtract the number of attributes before this operand to get
321     // the index in the operand list.
322     auto numPrevAttrs = std::count_if(
323         op.arg_begin(), op.arg_begin() + argIndex,
324         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
325 
326     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
327                                  name, depth, argIndex - numPrevAttrs);
328   }
329 }
330 
331 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
332                                         int indent) {
333 
334   Operator &op = tree.getDialectOp(opMap);
335   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
336   const auto &attr = namedAttr->attr;
337 
338   os.indent(indent) << "{\n";
339   indent += 2;
340   os.indent(indent) << formatv(
341       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
342       attr.getStorageType(), namedAttr->name);
343 
344   // TODO(antiagainst): This should use getter method to avoid duplication.
345   if (attr.hasDefaultValue()) {
346     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
347                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
348                                attr.getDefaultValue())
349                       << ";\n";
350   } else if (attr.isOptional()) {
351     // For a missing attribute that is optional according to definition, we
352     // should just capture a mlir::Attribute() to signal the missing state.
353     // That is precisely what getAttr() returns on missing attributes.
354   } else {
355     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
356   }
357 
358   auto matcher = tree.getArgAsLeaf(argIndex);
359   if (!matcher.isUnspecified()) {
360     if (!matcher.isAttrMatcher()) {
361       PrintFatalError(
362           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
363                        op.getOperationName(), argIndex + 1));
364     }
365 
366     // If a constraint is specified, we need to generate C++ statements to
367     // check the constraint.
368     os.indent(indent) << "if (!("
369                       << tgfmt(matcher.getConditionTemplate(),
370                                &fmtCtx.withSelf("tblgen_attr"))
371                       << ")) return matchFailure();\n";
372   }
373 
374   // Capture the value
375   auto name = tree.getArgName(argIndex);
376   // `$_` is a special symbol to ignore op argument matching.
377   if (!name.empty() && name != "_") {
378     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
379   }
380 
381   indent -= 2;
382   os.indent(indent) << "}\n";
383 }
384 
385 void PatternEmitter::emitMatchLogic(DagNode tree) {
386   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
387   emitOpMatch(tree, 0);
388 
389   for (auto &appliedConstraint : pattern.getConstraints()) {
390     auto &constraint = appliedConstraint.constraint;
391     auto &entities = appliedConstraint.entities;
392 
393     auto condition = constraint.getConditionTemplate();
394     auto cmd = "if (!({0})) return matchFailure();\n";
395 
396     if (isa<TypeConstraint>(constraint)) {
397       auto self = formatv("({0}->getType())",
398                           symbolInfoMap.getValueAndRangeUse(entities.front()));
399       os.indent(4) << formatv(cmd,
400                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
401     } else if (isa<AttrConstraint>(constraint)) {
402       PrintFatalError(
403           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
404     } else {
405       // TODO(b/138794486): replace formatv arguments with the exact specified
406       // args.
407       if (entities.size() > 4) {
408         PrintFatalError(loc, "only support up to 4-entity constraints now");
409       }
410       SmallVector<std::string, 4> names;
411       int i = 0;
412       for (int e = entities.size(); i < e; ++i)
413         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
414       std::string self = appliedConstraint.self;
415       if (!self.empty())
416         self = symbolInfoMap.getValueAndRangeUse(self);
417       for (; i < 4; ++i)
418         names.push_back("<unused>");
419       os.indent(4) << formatv(cmd,
420                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
421                                     names[1], names[2], names[3]));
422     }
423   }
424   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
425 }
426 
427 void PatternEmitter::collectOps(DagNode tree,
428                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
429   // Check if this tree is an operation.
430   if (tree.isOperation()) {
431     const Operator &op = tree.getDialectOp(opMap);
432     LLVM_DEBUG(llvm::dbgs()
433                << "found operation " << op.getOperationName() << '\n');
434     ops.insert(&op);
435   }
436 
437   // Recurse the arguments of the tree.
438   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
439     if (auto child = tree.getArgAsNestedDag(i))
440       collectOps(child, ops);
441 }
442 
443 void PatternEmitter::emit(StringRef rewriteName) {
444   // Get the DAG tree for the source pattern.
445   DagNode sourceTree = pattern.getSourcePattern();
446 
447   const Operator &rootOp = pattern.getSourceRootOp();
448   auto rootName = rootOp.getOperationName();
449 
450   // Collect the set of result operations.
451   llvm::SmallPtrSet<const Operator *, 4> resultOps;
452   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
453   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
454     collectOps(pattern.getResultPattern(i), resultOps);
455   }
456   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
457 
458   // Emit RewritePattern for Pattern.
459   auto locs = pattern.getLocation();
460   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
461                 make_range(locs.rbegin(), locs.rend()));
462   os << formatv(R"(struct {0} : public RewritePattern {
463   {0}(MLIRContext *context)
464       : RewritePattern("{1}", {{)",
465                 rewriteName, rootName);
466   // Sort result operators by name.
467   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
468                                                          resultOps.end());
469   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
470     return lhs->getOperationName() < rhs->getOperationName();
471   });
472   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
473     os << '"' << op->getOperationName() << '"';
474   });
475   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
476 
477   // Emit matchAndRewrite() function.
478   os << R"(
479   PatternMatchResult matchAndRewrite(Operation *op0,
480                                      PatternRewriter &rewriter) const override {
481 )";
482 
483   // Register all symbols bound in the source pattern.
484   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
485 
486   LLVM_DEBUG(
487       llvm::dbgs() << "start creating local variables for capturing matches\n");
488   os.indent(4) << "// Variables for capturing values and attributes used for "
489                   "creating ops\n";
490   // Create local variables for storing the arguments and results bound
491   // to symbols.
492   for (const auto &symbolInfoPair : symbolInfoMap) {
493     StringRef symbol = symbolInfoPair.getKey();
494     auto &info = symbolInfoPair.getValue();
495     os.indent(4) << info.getVarDecl(symbol);
496   }
497   // TODO(jpienaar): capture ops with consistent numbering so that it can be
498   // reused for fused loc.
499   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
500                           pattern.getSourcePattern().getNumOps());
501   LLVM_DEBUG(
502       llvm::dbgs() << "done creating local variables for capturing matches\n");
503 
504   os.indent(4) << "// Match\n";
505   os.indent(4) << "tblgen_ops[0] = op0;\n";
506   emitMatchLogic(sourceTree);
507   os << "\n";
508 
509   os.indent(4) << "// Rewrite\n";
510   emitRewriteLogic();
511 
512   os.indent(4) << "return matchSuccess();\n";
513   os << "  };\n";
514   os << "};\n";
515 }
516 
517 void PatternEmitter::emitRewriteLogic() {
518   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
519   const Operator &rootOp = pattern.getSourceRootOp();
520   int numExpectedResults = rootOp.getNumResults();
521   int numResultPatterns = pattern.getNumResultPatterns();
522 
523   // First register all symbols bound to ops generated in result patterns.
524   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
525 
526   // Only the last N static values generated are used to replace the matched
527   // root N-result op. We need to calculate the starting index (of the results
528   // of the matched op) each result pattern is to replace.
529   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
530   // If we don't need to replace any value at all, set the replacement starting
531   // index as the number of result patterns so we skip all of them when trying
532   // to replace the matched op's results.
533   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
534   for (int i = numResultPatterns - 1; i >= 0; --i) {
535     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
536     offsets[i] = offsets[i + 1] - numValues;
537     if (offsets[i] == 0) {
538       if (replStartIndex == -1)
539         replStartIndex = i;
540     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
541       auto error = formatv(
542           "cannot use the same multi-result op '{0}' to generate both "
543           "auxiliary values and values to be used for replacing the matched op",
544           pattern.getResultPattern(i).getSymbol());
545       PrintFatalError(loc, error);
546     }
547   }
548 
549   if (offsets.front() > 0) {
550     const char error[] = "no enough values generated to replace the matched op";
551     PrintFatalError(loc, error);
552   }
553 
554   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
555   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
556     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
557   }
558   os << "}); (void)loc;\n";
559 
560   // Process auxiliary result patterns.
561   for (int i = 0; i < replStartIndex; ++i) {
562     DagNode resultTree = pattern.getResultPattern(i);
563     auto val = handleResultPattern(resultTree, offsets[i], 0);
564     // Normal op creation will be streamed to `os` by the above call; but
565     // NativeCodeCall will only be materialized to `os` if it is used. Here
566     // we are handling auxiliary patterns so we want the side effect even if
567     // NativeCodeCall is not replacing matched root op's results.
568     if (resultTree.isNativeCodeCall())
569       os.indent(4) << val << ";\n";
570   }
571 
572   if (numExpectedResults == 0) {
573     assert(replStartIndex >= numResultPatterns &&
574            "invalid auxiliary vs. replacement pattern division!");
575     // No result to replace. Just erase the op.
576     os.indent(4) << "rewriter.eraseOp(op0);\n";
577   } else {
578     // Process replacement result patterns.
579     os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n";
580     for (int i = replStartIndex; i < numResultPatterns; ++i) {
581       DagNode resultTree = pattern.getResultPattern(i);
582       auto val = handleResultPattern(resultTree, offsets[i], 0);
583       os.indent(4) << "\n";
584       // Resolve each symbol for all range use so that we can loop over them.
585       os << symbolInfoMap.getAllRangeUse(
586           val, "    for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }",
587           "\n");
588     }
589     os.indent(4) << "\n";
590     os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
591   }
592 
593   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
594 }
595 
596 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
597   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
598 }
599 
600 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
601                                                 int resultIndex, int depth) {
602   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
603   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
604   LLVM_DEBUG(llvm::dbgs() << '\n');
605 
606   if (resultTree.isNativeCodeCall()) {
607     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
608     symbolInfoMap.bindValue(symbol);
609     return symbol;
610   }
611 
612   if (resultTree.isReplaceWithValue()) {
613     return handleReplaceWithValue(resultTree);
614   }
615 
616   // Normal op creation.
617   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
618   if (resultTree.getSymbol().empty()) {
619     // This is an op not explicitly bound to a symbol in the rewrite rule.
620     // Register the auto-generated symbol for it.
621     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
622   }
623   return symbol;
624 }
625 
626 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
627   assert(tree.isReplaceWithValue());
628 
629   if (tree.getNumArgs() != 1) {
630     PrintFatalError(
631         loc, "replaceWithValue directive must take exactly one argument");
632   }
633 
634   if (!tree.getSymbol().empty()) {
635     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
636   }
637 
638   return tree.getArgName(0);
639 }
640 
641 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
642                                              StringRef patArgName) {
643   if (leaf.isConstantAttr()) {
644     auto constAttr = leaf.getAsConstantAttr();
645     return handleConstantAttr(constAttr.getAttribute(),
646                               constAttr.getConstantValue());
647   }
648   if (leaf.isEnumAttrCase()) {
649     auto enumCase = leaf.getAsEnumAttrCase();
650     if (enumCase.isStrCase())
651       return handleConstantAttr(enumCase, enumCase.getSymbol());
652     // This is an enum case backed by an IntegerAttr. We need to get its value
653     // to build the constant.
654     std::string val = std::to_string(enumCase.getValue());
655     return handleConstantAttr(enumCase, val);
656   }
657 
658   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
659   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
660   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
661     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
662                             << "' (via symbol ref)\n");
663     return argName;
664   }
665   if (leaf.isNativeCodeCall()) {
666     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
667     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
668                             << "' (via NativeCodeCall)\n");
669     return repl;
670   }
671   PrintFatalError(loc, "unhandled case when rewriting op");
672 }
673 
674 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
675   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
676   LLVM_DEBUG(tree.print(llvm::dbgs()));
677   LLVM_DEBUG(llvm::dbgs() << '\n');
678 
679   auto fmt = tree.getNativeCodeTemplate();
680   // TODO(b/138794486): replace formatv arguments with the exact specified args.
681   SmallVector<std::string, 8> attrs(8);
682   if (tree.getNumArgs() > 8) {
683     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
684                              Twine(tree.getNumArgs()));
685   }
686   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
687     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
688     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
689                             << " replacement: " << attrs[i] << "\n");
690   }
691   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
692                attrs[5], attrs[6], attrs[7]);
693 }
694 
695 int PatternEmitter::getNodeValueCount(DagNode node) {
696   if (node.isOperation()) {
697     // If the op is bound to a symbol in the rewrite rule, query its result
698     // count from the symbol info map.
699     auto symbol = node.getSymbol();
700     if (!symbol.empty()) {
701       return symbolInfoMap.getStaticValueCount(symbol);
702     }
703     // Otherwise this is an unbound op; we will use all its results.
704     return pattern.getDialectOp(node).getNumResults();
705   }
706   // TODO(antiagainst): This considers all NativeCodeCall as returning one
707   // value. Enhance if multi-value ones are needed.
708   return 1;
709 }
710 
711 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
712                                              int depth) {
713   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
714   LLVM_DEBUG(tree.print(llvm::dbgs()));
715   LLVM_DEBUG(llvm::dbgs() << '\n');
716 
717   Operator &resultOp = tree.getDialectOp(opMap);
718   auto numOpArgs = resultOp.getNumArgs();
719 
720   if (numOpArgs != tree.getNumArgs()) {
721     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
722                                  "{1} in pattern vs. {2} in definition",
723                                  resultOp.getOperationName(), tree.getNumArgs(),
724                                  numOpArgs));
725   }
726 
727   // A map to collect all nested DAG child nodes' names, with operand index as
728   // the key. This includes both bound and unbound child nodes.
729   ChildNodeIndexNameMap childNodeNames;
730 
731   // First go through all the child nodes who are nested DAG constructs to
732   // create ops for them and remember the symbol names for them, so that we can
733   // use the results in the current node. This happens in a recursive manner.
734   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
735     if (auto child = tree.getArgAsNestedDag(i)) {
736       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
737     }
738   }
739 
740   // The name of the local variable holding this op.
741   std::string valuePackName;
742   // The symbol for holding the result of this pattern. Note that the result of
743   // this pattern is not necessarily the same as the variable created by this
744   // pattern because we can use `__N` suffix to refer only a specific result if
745   // the generated op is a multi-result op.
746   std::string resultValue;
747   if (tree.getSymbol().empty()) {
748     // No symbol is explicitly bound to this op in the pattern. Generate a
749     // unique name.
750     valuePackName = resultValue = getUniqueSymbol(&resultOp);
751   } else {
752     resultValue = tree.getSymbol();
753     // Strip the index to get the name for the value pack and use it to name the
754     // local variable for the op.
755     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
756   }
757 
758   // Create the local variable for this op.
759   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
760                           valuePackName);
761   os.indent(4) << "{\n";
762 
763   // Right now ODS don't have general type inference support. Except a few
764   // special cases listed below, DRR needs to supply types for all results
765   // when building an op.
766   bool isSameOperandsAndResultType =
767       resultOp.getTrait("OpTrait::SameOperandsAndResultType");
768   bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType");
769 
770   if (isSameOperandsAndResultType || useFirstAttr) {
771     // We know how to deduce the result type for ops with these traits and we've
772     // generated builders taking aggregate parameters. Use those builders to
773     // create the ops.
774 
775     // First prepare local variables for op arguments used in builder call.
776     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
777     // Then create the op.
778     os.indent(6) << formatv(
779         "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
780         valuePackName, resultOp.getQualCppClassName());
781     os.indent(4) << "}\n";
782     return resultValue;
783   }
784 
785   bool isBroadcastable =
786       resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
787   bool usePartialResults = valuePackName != resultValue;
788 
789   if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
790     // For these cases (broadcastable ops, op results used both as auxiliary
791     // values and replacement values, ops in nested patterns, auxiliary ops), we
792     // still need to supply the result types when building the op. But because
793     // we don't generate a builder automatically with ODS for them, it's the
794     // developer's responsiblity to make sure such a builder (with result type
795     // deduction ability) exists. We go through the separate-parameter builder
796     // here given that it's easier for developers to write compared to
797     // aggregate-parameter builders.
798     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
799     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
800                             resultOp.getQualCppClassName());
801     supplyValuesForOpArgs(tree, childNodeNames);
802     os << "\n      );\n";
803     os.indent(4) << "}\n";
804     return resultValue;
805   }
806 
807   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
808   // generated from the source pattern root op. Then we can use the source
809   // pattern's value types to determine the value type of the generated op
810   // here.
811 
812   // First prepare local variables for op arguments used in builder call.
813   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
814 
815   // Then prepare the result types. We need to specify the types for all
816   // results.
817   os.indent(6) << formatv(
818       "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
819   int numResults = resultOp.getNumResults();
820   if (numResults != 0) {
821     for (int i = 0; i < numResults; ++i)
822       os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{"
823                               "tblgen_types.push_back(v->getType()); }\n",
824                               resultIndex + i);
825   }
826   os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
827                           "tblgen_values, tblgen_attrs);\n",
828                           valuePackName, resultOp.getQualCppClassName());
829   os.indent(4) << "}\n";
830   return resultValue;
831 }
832 
833 void PatternEmitter::createSeparateLocalVarsForOpArgs(
834     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
835   Operator &resultOp = node.getDialectOp(opMap);
836 
837   // Now prepare operands used for building this op:
838   // * If the operand is non-variadic, we create a `Value*` local variable.
839   // * If the operand is variadic, we create a `SmallVector<Value*>` local
840   //   variable.
841 
842   int valueIndex = 0; // An index for uniquing local variable names.
843   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
844     const auto *operand =
845         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
846     if (!operand) {
847       // We do not need special handling for attributes.
848       continue;
849     }
850 
851     std::string varName;
852     if (operand->isVariadic()) {
853       varName = formatv("tblgen_values_{0}", valueIndex++);
854       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
855       std::string range;
856       if (node.isNestedDagArg(argIndex)) {
857         range = childNodeNames[argIndex];
858       } else {
859         range = node.getArgName(argIndex);
860       }
861       // Resolve the symbol for all range use so that we have a uniform way of
862       // capturing the values.
863       range = symbolInfoMap.getValueAndRangeUse(range);
864       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
865                               varName);
866     } else {
867       varName = formatv("tblgen_value_{0}", valueIndex++);
868       os.indent(6) << formatv("Value *{0} = ", varName);
869       if (node.isNestedDagArg(argIndex)) {
870         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
871       } else {
872         DagLeaf leaf = node.getArgAsLeaf(argIndex);
873         auto symbol =
874             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
875         if (leaf.isNativeCodeCall()) {
876           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
877         } else {
878           os << symbol;
879         }
880       }
881       os << ";\n";
882     }
883 
884     // Update to use the newly created local variable for building the op later.
885     childNodeNames[argIndex] = varName;
886   }
887 }
888 
889 void PatternEmitter::supplyValuesForOpArgs(
890     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
891   Operator &resultOp = node.getDialectOp(opMap);
892   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
893        argIndex != numOpArgs; ++argIndex) {
894     // Start each argument on its own line.
895     (os << ",\n").indent(8);
896 
897     Argument opArg = resultOp.getArg(argIndex);
898     // Handle the case of operand first.
899     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
900       if (!operand->name.empty())
901         os << "/*" << operand->name << "=*/";
902       os << childNodeNames.lookup(argIndex);
903       continue;
904     }
905 
906     // The argument in the op definition.
907     auto opArgName = resultOp.getArgName(argIndex);
908     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
909       if (!subTree.isNativeCodeCall())
910         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
911                              "for creating attribute");
912       os << formatv("/*{0}=*/{1}", opArgName,
913                     handleReplaceWithNativeCodeCall(subTree));
914     } else {
915       auto leaf = node.getArgAsLeaf(argIndex);
916       // The argument in the result DAG pattern.
917       auto patArgName = node.getArgName(argIndex);
918       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
919         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
920         if (!opArg.is<NamedAttribute *>())
921           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
922         if (!patArgName.empty())
923           os << "/*" << patArgName << "=*/";
924       } else {
925         os << "/*" << opArgName << "=*/";
926       }
927       os << handleOpArgument(leaf, patArgName);
928     }
929   }
930 }
931 
932 void PatternEmitter::createAggregateLocalVarsForOpArgs(
933     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
934   Operator &resultOp = node.getDialectOp(opMap);
935 
936   os.indent(6) << formatv(
937       "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n");
938   os.indent(6) << formatv(
939       "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
940 
941   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
942     if (const auto *attr =
943             resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) {
944       const char *addAttrCmd = "if ({1}) {{"
945                                "  tblgen_attrs.emplace_back(rewriter."
946                                "getIdentifier(\"{0}\"), {1}); }\n";
947       // The argument in the op definition.
948       auto opArgName = resultOp.getArgName(argIndex);
949       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
950         if (!subTree.isNativeCodeCall())
951           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
952                                "for creating attribute");
953         os.indent(6) << formatv(addAttrCmd, opArgName,
954                                 handleReplaceWithNativeCodeCall(subTree));
955       } else {
956         auto leaf = node.getArgAsLeaf(argIndex);
957         // The argument in the result DAG pattern.
958         auto patArgName = node.getArgName(argIndex);
959         os.indent(6) << formatv(addAttrCmd, opArgName,
960                                 handleOpArgument(leaf, patArgName));
961       }
962       continue;
963     }
964 
965     const auto *operand =
966         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
967     std::string varName;
968     if (operand->isVariadic()) {
969       std::string range;
970       if (node.isNestedDagArg(argIndex)) {
971         range = childNodeNames.lookup(argIndex);
972       } else {
973         range = node.getArgName(argIndex);
974       }
975       // Resolve the symbol for all range use so that we have a uniform way of
976       // capturing the values.
977       range = symbolInfoMap.getValueAndRangeUse(range);
978       os.indent(6) << formatv(
979           "for (auto *v : {0}) tblgen_values.push_back(v);\n", range);
980     } else {
981       os.indent(6) << formatv("tblgen_values.push_back(", varName);
982       if (node.isNestedDagArg(argIndex)) {
983         os << symbolInfoMap.getValueAndRangeUse(
984             childNodeNames.lookup(argIndex));
985       } else {
986         DagLeaf leaf = node.getArgAsLeaf(argIndex);
987         auto symbol =
988             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
989         if (leaf.isNativeCodeCall()) {
990           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
991         } else {
992           os << symbol;
993         }
994       }
995       os << ");\n";
996     }
997   }
998 }
999 
1000 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1001   emitSourceFileHeader("Rewriters", os);
1002 
1003   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1004   auto numPatterns = patterns.size();
1005 
1006   // We put the map here because it can be shared among multiple patterns.
1007   RecordOperatorMap recordOpMap;
1008 
1009   std::vector<std::string> rewriterNames;
1010   rewriterNames.reserve(numPatterns);
1011 
1012   std::string baseRewriterName = "GeneratedConvert";
1013   int rewriterIndex = 0;
1014 
1015   for (Record *p : patterns) {
1016     std::string name;
1017     if (p->isAnonymous()) {
1018       // If no name is provided, ensure unique rewriter names simply by
1019       // appending unique suffix.
1020       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1021     } else {
1022       name = p->getName();
1023     }
1024     LLVM_DEBUG(llvm::dbgs()
1025                << "=== start generating pattern '" << name << "' ===\n");
1026     PatternEmitter(p, &recordOpMap, os).emit(name);
1027     LLVM_DEBUG(llvm::dbgs()
1028                << "=== done generating pattern '" << name << "' ===\n");
1029     rewriterNames.push_back(std::move(name));
1030   }
1031 
1032   // Emit function to add the generated matchers to the pattern list.
1033   os << "void populateWithGenerated(MLIRContext *context, "
1034      << "OwningRewritePatternList *patterns) {\n";
1035   for (const auto &name : rewriterNames) {
1036     os << "  patterns->insert<" << name << ">(context);\n";
1037   }
1038   os << "}\n";
1039 }
1040 
1041 static mlir::GenRegistration
1042     genRewriters("gen-rewriters", "Generate pattern rewriters",
1043                  [](const RecordKeeper &records, raw_ostream &os) {
1044                    emitRewriters(records, os);
1045                    return false;
1046                  });
1047