xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 13c6e419ca68cd4a5434f4349db5433395e6fbf0)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 using llvm::formatv;
46 using llvm::Record;
47 using llvm::RecordKeeper;
48 
49 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
50 
51 namespace llvm {
52 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
53   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
54                      raw_ostream &os, StringRef style) {
55     os << v.first << ":" << v.second;
56   }
57 };
58 } // end namespace llvm
59 
60 //===----------------------------------------------------------------------===//
61 // PatternEmitter
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 class PatternEmitter {
66 public:
67   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
68 
69   // Emits the mlir::RewritePattern struct named `rewriteName`.
70   void emit(StringRef rewriteName);
71 
72 private:
73   // Emits the code for matching ops.
74   void emitMatchLogic(DagNode tree);
75 
76   // Emits the code for rewriting ops.
77   void emitRewriteLogic();
78 
79   //===--------------------------------------------------------------------===//
80   // Match utilities
81   //===--------------------------------------------------------------------===//
82 
83   // Emits C++ statements for matching the op constrained by the given DAG
84   // `tree`.
85   void emitOpMatch(DagNode tree, int depth);
86 
87   // Emits C++ statements for matching the `argIndex`-th argument of the given
88   // DAG `tree` as an operand.
89   void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
90 
91   // Emits C++ statements for matching the `argIndex`-th argument of the given
92   // DAG `tree` as an attribute.
93   void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
94 
95   //===--------------------------------------------------------------------===//
96   // Rewrite utilities
97   //===--------------------------------------------------------------------===//
98 
99   // The entry point for handling a result pattern rooted at `resultTree`. This
100   // method dispatches to concrete handlers according to `resultTree`'s kind and
101   // returns a symbol representing the whole value pack. Callers are expected to
102   // further resolve the symbol according to the specific use case.
103   //
104   // `depth` is the nesting level of `resultTree`; 0 means top-level result
105   // pattern. For top-level result pattern, `resultIndex` indicates which result
106   // of the matched root op this pattern is intended to replace, which can be
107   // used to deduce the result type of the op generated from this result
108   // pattern.
109   std::string handleResultPattern(DagNode resultTree, int resultIndex,
110                                   int depth);
111 
112   // Emits the C++ statement to replace the matched DAG with a value built via
113   // calling native C++ code.
114   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
115 
116   // Returns the C++ expression referencing the old value serving as the
117   // replacement.
118   std::string handleReplaceWithValue(DagNode tree);
119 
120   // Emits the C++ statement to build a new op out of the given DAG `tree` and
121   // returns the variable name that this op is assigned to. If the root op in
122   // DAG `tree` has a specified name, the created op will be assigned to a
123   // variable of the given name. Otherwise, a unique name will be used as the
124   // result value name.
125   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
126 
127   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
128 
129   // Emits a local variable for each value and attribute to be used for creating
130   // an op.
131   void createSeparateLocalVarsForOpArgs(DagNode node,
132                                         ChildNodeIndexNameMap &childNodeNames);
133 
134   // Emits the concrete arguments used to call a op's builder.
135   void supplyValuesForOpArgs(DagNode node,
136                              const ChildNodeIndexNameMap &childNodeNames);
137 
138   // Emits the local variables for holding all values as a whole and all named
139   // attributes as a whole to be used for creating an op.
140   void createAggregateLocalVarsForOpArgs(
141       DagNode node, const ChildNodeIndexNameMap &childNodeNames);
142 
143   // Returns the C++ expression to construct a constant attribute of the given
144   // `value` for the given attribute kind `attr`.
145   std::string handleConstantAttr(Attribute attr, StringRef value);
146 
147   // Returns the C++ expression to build an argument from the given DAG `leaf`.
148   // `patArgName` is used to bound the argument to the source pattern.
149   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
150 
151   //===--------------------------------------------------------------------===//
152   // General utilities
153   //===--------------------------------------------------------------------===//
154 
155   // Collects all of the operations within the given dag tree.
156   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
157 
158   // Returns a unique symbol for a local variable of the given `op`.
159   std::string getUniqueSymbol(const Operator *op);
160 
161   //===--------------------------------------------------------------------===//
162   // Symbol utilities
163   //===--------------------------------------------------------------------===//
164 
165   // Returns how many static values the given DAG `node` correspond to.
166   int getNodeValueCount(DagNode node);
167 
168 private:
169   // Pattern instantiation location followed by the location of multiclass
170   // prototypes used. This is intended to be used as a whole to
171   // PrintFatalError() on errors.
172   ArrayRef<llvm::SMLoc> loc;
173 
174   // Op's TableGen Record to wrapper object.
175   RecordOperatorMap *opMap;
176 
177   // Handy wrapper for pattern being emitted.
178   Pattern pattern;
179 
180   // Map for all bound symbols' info.
181   SymbolInfoMap symbolInfoMap;
182 
183   // The next unused ID for newly created values.
184   unsigned nextValueId;
185 
186   raw_ostream &os;
187 
188   // Format contexts containing placeholder substitutions.
189   FmtContext fmtCtx;
190 
191   // Number of op processed.
192   int opCounter = 0;
193 };
194 } // end anonymous namespace
195 
196 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
197                                raw_ostream &os)
198     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
199       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
200   fmtCtx.withBuilder("rewriter");
201 }
202 
203 std::string PatternEmitter::handleConstantAttr(Attribute attr,
204                                                StringRef value) {
205   if (!attr.isConstBuildable())
206     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
207                              " does not have the 'constBuilderCall' field");
208 
209   // TODO(jpienaar): Verify the constants here
210   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
211 }
212 
213 // Helper function to match patterns.
214 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
215   Operator &op = tree.getDialectOp(opMap);
216   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
217                           << op.getOperationName() << "' at depth " << depth
218                           << '\n');
219 
220   int indent = 4 + 2 * depth;
221   os.indent(indent) << formatv(
222       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
223       depth, op.getQualCppClassName());
224   // Skip the operand matching at depth 0 as the pattern rewriter already does.
225   if (depth != 0) {
226     // Skip if there is no defining operation (e.g., arguments to function).
227     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
228                                  depth);
229   }
230   if (tree.getNumArgs() != op.getNumArgs()) {
231     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
232                                  "pattern vs. {2} in definition",
233                                  op.getOperationName(), tree.getNumArgs(),
234                                  op.getNumArgs()));
235   }
236 
237   // If the operand's name is set, set to that variable.
238   auto name = tree.getSymbol();
239   if (!name.empty())
240     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
241 
242   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
243     auto opArg = op.getArg(i);
244 
245     // Handle nested DAG construct first
246     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
247       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
248         if (operand->isVariadic()) {
249           auto error = formatv("use nested DAG construct to match op {0}'s "
250                                "variadic operand #{1} unsupported now",
251                                op.getOperationName(), i);
252           PrintFatalError(loc, error);
253         }
254       }
255       os.indent(indent) << "{\n";
256 
257       os.indent(indent + 2) << formatv(
258           "auto *op{0} = "
259           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
260           depth + 1, depth, i);
261       emitOpMatch(argTree, depth + 1);
262       os.indent(indent + 2)
263           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
264       os.indent(indent) << "}\n";
265       continue;
266     }
267 
268     // Next handle DAG leaf: operand or attribute
269     if (opArg.is<NamedTypeConstraint *>()) {
270       emitOperandMatch(tree, i, depth, indent);
271     } else if (opArg.is<NamedAttribute *>()) {
272       emitAttributeMatch(tree, i, depth, indent);
273     } else {
274       PrintFatalError(loc, "unhandled case when matching op");
275     }
276   }
277   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
278                           << op.getOperationName() << "' at depth " << depth
279                           << '\n');
280 }
281 
282 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
283                                       int indent) {
284   Operator &op = tree.getDialectOp(opMap);
285   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
286   auto matcher = tree.getArgAsLeaf(argIndex);
287 
288   // If a constraint is specified, we need to generate C++ statements to
289   // check the constraint.
290   if (!matcher.isUnspecified()) {
291     if (!matcher.isOperandMatcher()) {
292       PrintFatalError(
293           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
294                        op.getOperationName(), argIndex + 1));
295     }
296 
297     // Only need to verify if the matcher's type is different from the one
298     // of op definition.
299     if (operand->constraint != matcher.getAsConstraint()) {
300       if (operand->isVariadic()) {
301         auto error = formatv(
302             "further constrain op {0}'s variadic operand #{1} unsupported now",
303             op.getOperationName(), argIndex);
304         PrintFatalError(loc, error);
305       }
306       auto self =
307           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
308                   depth, argIndex);
309       os.indent(indent) << "if (!("
310                         << tgfmt(matcher.getConditionTemplate(),
311                                  &fmtCtx.withSelf(self))
312                         << ")) return matchFailure();\n";
313     }
314   }
315 
316   // Capture the value
317   auto name = tree.getArgName(argIndex);
318   if (!name.empty()) {
319     // We need to subtract the number of attributes before this operand to get
320     // the index in the operand list.
321     auto numPrevAttrs = std::count_if(
322         op.arg_begin(), op.arg_begin() + argIndex,
323         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
324 
325     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
326                                  name, depth, argIndex - numPrevAttrs);
327   }
328 }
329 
330 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
331                                         int indent) {
332   Operator &op = tree.getDialectOp(opMap);
333   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
334   const auto &attr = namedAttr->attr;
335 
336   os.indent(indent) << "{\n";
337   indent += 2;
338   os.indent(indent) << formatv(
339       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
340       attr.getStorageType(), namedAttr->name);
341 
342   // TODO(antiagainst): This should use getter method to avoid duplication.
343   if (attr.hasDefaultValueInitializer()) {
344     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
345                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
346                                attr.getDefaultValueInitializer())
347                       << ";\n";
348   } else if (attr.isOptional()) {
349     // For a missing attribute that is optional according to definition, we
350     // should just capture a mlir::Attribute() to signal the missing state.
351     // That is precisely what getAttr() returns on missing attributes.
352   } else {
353     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
354   }
355 
356   auto matcher = tree.getArgAsLeaf(argIndex);
357   if (!matcher.isUnspecified()) {
358     if (!matcher.isAttrMatcher()) {
359       PrintFatalError(
360           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
361                        op.getOperationName(), argIndex + 1));
362     }
363 
364     // If a constraint is specified, we need to generate C++ statements to
365     // check the constraint.
366     os.indent(indent) << "if (!("
367                       << tgfmt(matcher.getConditionTemplate(),
368                                &fmtCtx.withSelf("tblgen_attr"))
369                       << ")) return matchFailure();\n";
370   }
371 
372   // Capture the value
373   auto name = tree.getArgName(argIndex);
374   if (!name.empty()) {
375     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
376   }
377 
378   indent -= 2;
379   os.indent(indent) << "}\n";
380 }
381 
382 void PatternEmitter::emitMatchLogic(DagNode tree) {
383   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
384   emitOpMatch(tree, 0);
385 
386   for (auto &appliedConstraint : pattern.getConstraints()) {
387     auto &constraint = appliedConstraint.constraint;
388     auto &entities = appliedConstraint.entities;
389 
390     auto condition = constraint.getConditionTemplate();
391     auto cmd = "if (!({0})) return matchFailure();\n";
392 
393     if (isa<TypeConstraint>(constraint)) {
394       auto self = formatv("({0}->getType())",
395                           symbolInfoMap.getValueAndRangeUse(entities.front()));
396       os.indent(4) << formatv(cmd,
397                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
398     } else if (isa<AttrConstraint>(constraint)) {
399       PrintFatalError(
400           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
401     } else {
402       // TODO(b/138794486): replace formatv arguments with the exact specified
403       // args.
404       if (entities.size() > 4) {
405         PrintFatalError(loc, "only support up to 4-entity constraints now");
406       }
407       SmallVector<std::string, 4> names;
408       int i = 0;
409       for (int e = entities.size(); i < e; ++i)
410         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
411       std::string self = appliedConstraint.self;
412       if (!self.empty())
413         self = symbolInfoMap.getValueAndRangeUse(self);
414       for (; i < 4; ++i)
415         names.push_back("<unused>");
416       os.indent(4) << formatv(cmd,
417                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
418                                     names[1], names[2], names[3]));
419     }
420   }
421   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
422 }
423 
424 void PatternEmitter::collectOps(DagNode tree,
425                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
426   // Check if this tree is an operation.
427   if (tree.isOperation()) {
428     const Operator &op = tree.getDialectOp(opMap);
429     LLVM_DEBUG(llvm::dbgs()
430                << "found operation " << op.getOperationName() << '\n');
431     ops.insert(&op);
432   }
433 
434   // Recurse the arguments of the tree.
435   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
436     if (auto child = tree.getArgAsNestedDag(i))
437       collectOps(child, ops);
438 }
439 
440 void PatternEmitter::emit(StringRef rewriteName) {
441   // Get the DAG tree for the source pattern.
442   DagNode sourceTree = pattern.getSourcePattern();
443 
444   const Operator &rootOp = pattern.getSourceRootOp();
445   auto rootName = rootOp.getOperationName();
446 
447   // Collect the set of result operations.
448   llvm::SmallPtrSet<const Operator *, 4> resultOps;
449   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
450   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
451     collectOps(pattern.getResultPattern(i), resultOps);
452   }
453   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
454 
455   // Emit RewritePattern for Pattern.
456   auto locs = pattern.getLocation();
457   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
458                 make_range(locs.rbegin(), locs.rend()));
459   os << formatv(R"(struct {0} : public RewritePattern {
460   {0}(MLIRContext *context)
461       : RewritePattern("{1}", {{)",
462                 rewriteName, rootName);
463   // Sort result operators by name.
464   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
465                                                          resultOps.end());
466   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
467     return lhs->getOperationName() < rhs->getOperationName();
468   });
469   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
470     os << '"' << op->getOperationName() << '"';
471   });
472   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
473 
474   // Emit matchAndRewrite() function.
475   os << R"(
476   PatternMatchResult matchAndRewrite(Operation *op0,
477                                      PatternRewriter &rewriter) const override {
478 )";
479 
480   // Register all symbols bound in the source pattern.
481   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
482 
483   LLVM_DEBUG(
484       llvm::dbgs() << "start creating local variables for capturing matches\n");
485   os.indent(4) << "// Variables for capturing values and attributes used for "
486                   "creating ops\n";
487   // Create local variables for storing the arguments and results bound
488   // to symbols.
489   for (const auto &symbolInfoPair : symbolInfoMap) {
490     StringRef symbol = symbolInfoPair.getKey();
491     auto &info = symbolInfoPair.getValue();
492     os.indent(4) << info.getVarDecl(symbol);
493   }
494   // TODO(jpienaar): capture ops with consistent numbering so that it can be
495   // reused for fused loc.
496   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
497                           pattern.getSourcePattern().getNumOps());
498   LLVM_DEBUG(
499       llvm::dbgs() << "done creating local variables for capturing matches\n");
500 
501   os.indent(4) << "// Match\n";
502   os.indent(4) << "tblgen_ops[0] = op0;\n";
503   emitMatchLogic(sourceTree);
504   os << "\n";
505 
506   os.indent(4) << "// Rewrite\n";
507   emitRewriteLogic();
508 
509   os.indent(4) << "return matchSuccess();\n";
510   os << "  };\n";
511   os << "};\n";
512 }
513 
514 void PatternEmitter::emitRewriteLogic() {
515   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
516   const Operator &rootOp = pattern.getSourceRootOp();
517   int numExpectedResults = rootOp.getNumResults();
518   int numResultPatterns = pattern.getNumResultPatterns();
519 
520   // First register all symbols bound to ops generated in result patterns.
521   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
522 
523   // Only the last N static values generated are used to replace the matched
524   // root N-result op. We need to calculate the starting index (of the results
525   // of the matched op) each result pattern is to replace.
526   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
527   // If we don't need to replace any value at all, set the replacement starting
528   // index as the number of result patterns so we skip all of them when trying
529   // to replace the matched op's results.
530   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
531   for (int i = numResultPatterns - 1; i >= 0; --i) {
532     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
533     offsets[i] = offsets[i + 1] - numValues;
534     if (offsets[i] == 0) {
535       if (replStartIndex == -1)
536         replStartIndex = i;
537     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
538       auto error = formatv(
539           "cannot use the same multi-result op '{0}' to generate both "
540           "auxiliary values and values to be used for replacing the matched op",
541           pattern.getResultPattern(i).getSymbol());
542       PrintFatalError(loc, error);
543     }
544   }
545 
546   if (offsets.front() > 0) {
547     const char error[] = "no enough values generated to replace the matched op";
548     PrintFatalError(loc, error);
549   }
550 
551   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
552   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
553     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
554   }
555   os << "}); (void)loc;\n";
556 
557   // Process auxiliary result patterns.
558   for (int i = 0; i < replStartIndex; ++i) {
559     DagNode resultTree = pattern.getResultPattern(i);
560     auto val = handleResultPattern(resultTree, offsets[i], 0);
561     // Normal op creation will be streamed to `os` by the above call; but
562     // NativeCodeCall will only be materialized to `os` if it is used. Here
563     // we are handling auxiliary patterns so we want the side effect even if
564     // NativeCodeCall is not replacing matched root op's results.
565     if (resultTree.isNativeCodeCall())
566       os.indent(4) << val << ";\n";
567   }
568 
569   if (numExpectedResults == 0) {
570     assert(replStartIndex >= numResultPatterns &&
571            "invalid auxiliary vs. replacement pattern division!");
572     // No result to replace. Just erase the op.
573     os.indent(4) << "rewriter.eraseOp(op0);\n";
574   } else {
575     // Process replacement result patterns.
576     os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n";
577     for (int i = replStartIndex; i < numResultPatterns; ++i) {
578       DagNode resultTree = pattern.getResultPattern(i);
579       auto val = handleResultPattern(resultTree, offsets[i], 0);
580       os.indent(4) << "\n";
581       // Resolve each symbol for all range use so that we can loop over them.
582       os << symbolInfoMap.getAllRangeUse(
583           val, "    for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }",
584           "\n");
585     }
586     os.indent(4) << "\n";
587     os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
588   }
589 
590   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
591 }
592 
593 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
594   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
595 }
596 
597 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
598                                                 int resultIndex, int depth) {
599   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
600   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
601   LLVM_DEBUG(llvm::dbgs() << '\n');
602 
603   if (resultTree.isNativeCodeCall()) {
604     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
605     symbolInfoMap.bindValue(symbol);
606     return symbol;
607   }
608 
609   if (resultTree.isReplaceWithValue()) {
610     return handleReplaceWithValue(resultTree);
611   }
612 
613   // Normal op creation.
614   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
615   if (resultTree.getSymbol().empty()) {
616     // This is an op not explicitly bound to a symbol in the rewrite rule.
617     // Register the auto-generated symbol for it.
618     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
619   }
620   return symbol;
621 }
622 
623 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
624   assert(tree.isReplaceWithValue());
625 
626   if (tree.getNumArgs() != 1) {
627     PrintFatalError(
628         loc, "replaceWithValue directive must take exactly one argument");
629   }
630 
631   if (!tree.getSymbol().empty()) {
632     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
633   }
634 
635   return tree.getArgName(0);
636 }
637 
638 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
639                                              StringRef patArgName) {
640   if (leaf.isConstantAttr()) {
641     auto constAttr = leaf.getAsConstantAttr();
642     return handleConstantAttr(constAttr.getAttribute(),
643                               constAttr.getConstantValue());
644   }
645   if (leaf.isEnumAttrCase()) {
646     auto enumCase = leaf.getAsEnumAttrCase();
647     if (enumCase.isStrCase())
648       return handleConstantAttr(enumCase, enumCase.getSymbol());
649     // This is an enum case backed by an IntegerAttr. We need to get its value
650     // to build the constant.
651     std::string val = std::to_string(enumCase.getValue());
652     return handleConstantAttr(enumCase, val);
653   }
654 
655   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
656   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
657   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
658     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
659                             << "' (via symbol ref)\n");
660     return argName;
661   }
662   if (leaf.isNativeCodeCall()) {
663     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
664     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
665                             << "' (via NativeCodeCall)\n");
666     return repl;
667   }
668   PrintFatalError(loc, "unhandled case when rewriting op");
669 }
670 
671 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
672   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
673   LLVM_DEBUG(tree.print(llvm::dbgs()));
674   LLVM_DEBUG(llvm::dbgs() << '\n');
675 
676   auto fmt = tree.getNativeCodeTemplate();
677   // TODO(b/138794486): replace formatv arguments with the exact specified args.
678   SmallVector<std::string, 8> attrs(8);
679   if (tree.getNumArgs() > 8) {
680     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
681                              Twine(tree.getNumArgs()));
682   }
683   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
684     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
685     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
686                             << " replacement: " << attrs[i] << "\n");
687   }
688   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
689                attrs[5], attrs[6], attrs[7]);
690 }
691 
692 int PatternEmitter::getNodeValueCount(DagNode node) {
693   if (node.isOperation()) {
694     // If the op is bound to a symbol in the rewrite rule, query its result
695     // count from the symbol info map.
696     auto symbol = node.getSymbol();
697     if (!symbol.empty()) {
698       return symbolInfoMap.getStaticValueCount(symbol);
699     }
700     // Otherwise this is an unbound op; we will use all its results.
701     return pattern.getDialectOp(node).getNumResults();
702   }
703   // TODO(antiagainst): This considers all NativeCodeCall as returning one
704   // value. Enhance if multi-value ones are needed.
705   return 1;
706 }
707 
708 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
709                                              int depth) {
710   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
711   LLVM_DEBUG(tree.print(llvm::dbgs()));
712   LLVM_DEBUG(llvm::dbgs() << '\n');
713 
714   Operator &resultOp = tree.getDialectOp(opMap);
715   auto numOpArgs = resultOp.getNumArgs();
716 
717   if (numOpArgs != tree.getNumArgs()) {
718     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
719                                  "{1} in pattern vs. {2} in definition",
720                                  resultOp.getOperationName(), tree.getNumArgs(),
721                                  numOpArgs));
722   }
723 
724   // A map to collect all nested DAG child nodes' names, with operand index as
725   // the key. This includes both bound and unbound child nodes.
726   ChildNodeIndexNameMap childNodeNames;
727 
728   // First go through all the child nodes who are nested DAG constructs to
729   // create ops for them and remember the symbol names for them, so that we can
730   // use the results in the current node. This happens in a recursive manner.
731   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
732     if (auto child = tree.getArgAsNestedDag(i)) {
733       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
734     }
735   }
736 
737   // The name of the local variable holding this op.
738   std::string valuePackName;
739   // The symbol for holding the result of this pattern. Note that the result of
740   // this pattern is not necessarily the same as the variable created by this
741   // pattern because we can use `__N` suffix to refer only a specific result if
742   // the generated op is a multi-result op.
743   std::string resultValue;
744   if (tree.getSymbol().empty()) {
745     // No symbol is explicitly bound to this op in the pattern. Generate a
746     // unique name.
747     valuePackName = resultValue = getUniqueSymbol(&resultOp);
748   } else {
749     resultValue = tree.getSymbol();
750     // Strip the index to get the name for the value pack and use it to name the
751     // local variable for the op.
752     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
753   }
754 
755   // Create the local variable for this op.
756   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
757                           valuePackName);
758   os.indent(4) << "{\n";
759 
760   // Right now ODS don't have general type inference support. Except a few
761   // special cases listed below, DRR needs to supply types for all results
762   // when building an op.
763   bool isSameOperandsAndResultType =
764       resultOp.getTrait("OpTrait::SameOperandsAndResultType");
765   bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType");
766 
767   if (isSameOperandsAndResultType || useFirstAttr) {
768     // We know how to deduce the result type for ops with these traits and we've
769     // generated builders taking aggregrate parameters. Use those builders to
770     // create the ops.
771 
772     // First prepare local variables for op arguments used in builder call.
773     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
774     // Then create the op.
775     os.indent(6) << formatv(
776         "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
777         valuePackName, resultOp.getQualCppClassName());
778     os.indent(4) << "}\n";
779     return resultValue;
780   }
781 
782   bool isBroadcastable =
783       resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
784   bool usePartialResults = valuePackName != resultValue;
785 
786   if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
787     // For these cases (broadcastable ops, op results used both as auxiliary
788     // values and replacement values, ops in nested patterns, auxiliary ops), we
789     // still need to supply the result types when building the op. But because
790     // we don't generate a builder automatically with ODS for them, it's the
791     // developer's responsiblity to make sure such a builder (with result type
792     // deduction ability) exists. We go through the separate-parameter builder
793     // here given that it's easier for developers to write compared to
794     // aggregate-parameter builders.
795     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
796     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
797                             resultOp.getQualCppClassName());
798     supplyValuesForOpArgs(tree, childNodeNames);
799     os << "\n      );\n";
800     os.indent(4) << "}\n";
801     return resultValue;
802   }
803 
804   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
805   // generated from the source pattern root op. Then we can use the source
806   // pattern's value types to determine the value type of the generated op
807   // here.
808 
809   // First prepare local variables for op arguments used in builder call.
810   createAggregateLocalVarsForOpArgs(tree, childNodeNames);
811 
812   // Then prepare the result types. We need to specify the types for all
813   // results.
814   os.indent(6) << formatv(
815       "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
816   int numResults = resultOp.getNumResults();
817   if (numResults != 0) {
818     for (int i = 0; i < numResults; ++i)
819       os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{"
820                               "tblgen_types.push_back(v->getType()); }\n",
821                               resultIndex + i);
822   }
823   os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
824                           "tblgen_values, tblgen_attrs);\n",
825                           valuePackName, resultOp.getQualCppClassName());
826   os.indent(4) << "}\n";
827   return resultValue;
828 }
829 
830 void PatternEmitter::createSeparateLocalVarsForOpArgs(
831     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
832   Operator &resultOp = node.getDialectOp(opMap);
833 
834   // Now prepare operands used for building this op:
835   // * If the operand is non-variadic, we create a `Value*` local variable.
836   // * If the operand is variadic, we create a `SmallVector<Value*>` local
837   //   variable.
838 
839   int valueIndex = 0; // An index for uniquing local variable names.
840   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
841     const auto *operand =
842         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
843     if (!operand) {
844       // We do not need special handling for attributes.
845       continue;
846     }
847 
848     std::string varName;
849     if (operand->isVariadic()) {
850       varName = formatv("tblgen_values_{0}", valueIndex++);
851       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
852       std::string range;
853       if (node.isNestedDagArg(argIndex)) {
854         range = childNodeNames[argIndex];
855       } else {
856         range = node.getArgName(argIndex);
857       }
858       // Resolve the symbol for all range use so that we have a uniform way of
859       // capturing the values.
860       range = symbolInfoMap.getValueAndRangeUse(range);
861       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
862                               varName);
863     } else {
864       varName = formatv("tblgen_value_{0}", valueIndex++);
865       os.indent(6) << formatv("Value *{0} = ", varName);
866       if (node.isNestedDagArg(argIndex)) {
867         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
868       } else {
869         DagLeaf leaf = node.getArgAsLeaf(argIndex);
870         auto symbol =
871             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
872         if (leaf.isNativeCodeCall()) {
873           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
874         } else {
875           os << symbol;
876         }
877       }
878       os << ";\n";
879     }
880 
881     // Update to use the newly created local variable for building the op later.
882     childNodeNames[argIndex] = varName;
883   }
884 }
885 
886 void PatternEmitter::supplyValuesForOpArgs(
887     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
888   Operator &resultOp = node.getDialectOp(opMap);
889   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
890        argIndex != numOpArgs; ++argIndex) {
891     // Start each argment on its own line.
892     (os << ",\n").indent(8);
893 
894     Argument opArg = resultOp.getArg(argIndex);
895     // Handle the case of operand first.
896     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
897       if (!operand->name.empty())
898         os << "/*" << operand->name << "=*/";
899       os << childNodeNames.lookup(argIndex);
900       continue;
901     }
902 
903     // The argument in the op definition.
904     auto opArgName = resultOp.getArgName(argIndex);
905     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
906       if (!subTree.isNativeCodeCall())
907         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
908                              "for creating attribute");
909       os << formatv("/*{0}=*/{1}", opArgName,
910                     handleReplaceWithNativeCodeCall(subTree));
911     } else {
912       auto leaf = node.getArgAsLeaf(argIndex);
913       // The argument in the result DAG pattern.
914       auto patArgName = node.getArgName(argIndex);
915       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
916         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
917         if (!opArg.is<NamedAttribute *>())
918           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
919         if (!patArgName.empty())
920           os << "/*" << patArgName << "=*/";
921       } else {
922         os << "/*" << opArgName << "=*/";
923       }
924       os << handleOpArgument(leaf, patArgName);
925     }
926   }
927 }
928 
929 void PatternEmitter::createAggregateLocalVarsForOpArgs(
930     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
931   Operator &resultOp = node.getDialectOp(opMap);
932 
933   os.indent(6) << formatv(
934       "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n");
935   os.indent(6) << formatv(
936       "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
937 
938   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
939     if (const auto *attr =
940             resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) {
941       const char *addAttrCmd = "if ({1}) {{"
942                                "  tblgen_attrs.emplace_back(rewriter."
943                                "getIdentifier(\"{0}\"), {1}); }\n";
944       // The argument in the op definition.
945       auto opArgName = resultOp.getArgName(argIndex);
946       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
947         if (!subTree.isNativeCodeCall())
948           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
949                                "for creating attribute");
950         os.indent(6) << formatv(addAttrCmd, opArgName,
951                                 handleReplaceWithNativeCodeCall(subTree));
952       } else {
953         auto leaf = node.getArgAsLeaf(argIndex);
954         // The argument in the result DAG pattern.
955         auto patArgName = node.getArgName(argIndex);
956         os.indent(6) << formatv(addAttrCmd, opArgName,
957                                 handleOpArgument(leaf, patArgName));
958       }
959       continue;
960     }
961 
962     const auto *operand =
963         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
964     std::string varName;
965     if (operand->isVariadic()) {
966       std::string range;
967       if (node.isNestedDagArg(argIndex)) {
968         range = childNodeNames.lookup(argIndex);
969       } else {
970         range = node.getArgName(argIndex);
971       }
972       // Resolve the symbol for all range use so that we have a uniform way of
973       // capturing the values.
974       range = symbolInfoMap.getValueAndRangeUse(range);
975       os.indent(6) << formatv(
976           "for (auto *v : {0}) tblgen_values.push_back(v);\n", range);
977     } else {
978       os.indent(6) << formatv("tblgen_values.push_back(", varName);
979       if (node.isNestedDagArg(argIndex)) {
980         os << symbolInfoMap.getValueAndRangeUse(
981             childNodeNames.lookup(argIndex));
982       } else {
983         DagLeaf leaf = node.getArgAsLeaf(argIndex);
984         auto symbol =
985             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
986         if (leaf.isNativeCodeCall()) {
987           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
988         } else {
989           os << symbol;
990         }
991       }
992       os << ");\n";
993     }
994   }
995 }
996 
997 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
998   emitSourceFileHeader("Rewriters", os);
999 
1000   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1001   auto numPatterns = patterns.size();
1002 
1003   // We put the map here because it can be shared among multiple patterns.
1004   RecordOperatorMap recordOpMap;
1005 
1006   std::vector<std::string> rewriterNames;
1007   rewriterNames.reserve(numPatterns);
1008 
1009   std::string baseRewriterName = "GeneratedConvert";
1010   int rewriterIndex = 0;
1011 
1012   for (Record *p : patterns) {
1013     std::string name;
1014     if (p->isAnonymous()) {
1015       // If no name is provided, ensure unique rewriter names simply by
1016       // appending unique suffix.
1017       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1018     } else {
1019       name = p->getName();
1020     }
1021     LLVM_DEBUG(llvm::dbgs()
1022                << "=== start generating pattern '" << name << "' ===\n");
1023     PatternEmitter(p, &recordOpMap, os).emit(name);
1024     LLVM_DEBUG(llvm::dbgs()
1025                << "=== done generating pattern '" << name << "' ===\n");
1026     rewriterNames.push_back(std::move(name));
1027   }
1028 
1029   // Emit function to add the generated matchers to the pattern list.
1030   os << "void populateWithGenerated(MLIRContext *context, "
1031      << "OwningRewritePatternList *patterns) {\n";
1032   for (const auto &name : rewriterNames) {
1033     os << "  patterns->insert<" << name << ">(context);\n";
1034   }
1035   os << "}\n";
1036 }
1037 
1038 static mlir::GenRegistration
1039     genRewriters("gen-rewriters", "Generate pattern rewriters",
1040                  [](const RecordKeeper &records, raw_ostream &os) {
1041                    emitRewriters(records, os);
1042                    return false;
1043                  });
1044