xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 3aae473658d62f5352a0e8b2fa77ebabcbf6a50a)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace llvm;
43 using namespace mlir;
44 using namespace mlir::tblgen;
45 
46 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
47 
48 namespace llvm {
49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51                      raw_ostream &os, StringRef style) {
52     os << v.first << ":" << v.second;
53   }
54 };
55 } // end namespace llvm
56 
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
60 
namespace {
// Generates the C++ RewritePattern struct for a single pattern Record. All
// generated code is streamed to the `raw_ostream` handed to the constructor.
class PatternEmitter {
public:
  // Creates an emitter for the pattern record `pat`. `mapper` is used to look
  // up the Operator wrappers of the ops referenced by the pattern, and `os` is
  // the stream the generated code is written to.
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern; it
  // is used both to name the generated `op<depth>`/`castedOp<depth>` variables
  // and to compute the indentation of the generated code.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand. The statements are indented by `indent` spaces.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute. The statements are wrapped in their own scope,
  // which is indented by `indent` spaces.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For a top-level result pattern, `resultIndex` indicates which
  // result of the matched root op this pattern is intended to replace, which
  // can be used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression that builds the replacement value by calling
  // native C++ code. Nothing is written to the output stream here; the caller
  // decides whether and where to materialize the expression.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`. Reports a fatal error if
  // `attr` is not const-buildable.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree into `ops`.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`. Each call
  // consumes one value id from `nextValueId`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values.
  unsigned nextValueId;

  // The stream all generated code is written to.
  raw_ostream &os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;

  // Number of nested ops processed so far while emitting match code. It is
  // used as the index into the generated `tblgen_ops` array.
  int opCounter = 0;
};
} // end anonymous namespace
176 
// Builds an emitter for the pattern record `pat`. The record's location list
// is kept for error reporting and is also handed to the symbol info map; value
// ids start at 0. The format context's builder placeholder is set to
// "rewriter", the name of the PatternRewriter parameter in the generated
// matchAndRewrite() method.
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
  fmtCtx.withBuilder("rewriter");
}
183 
184 std::string PatternEmitter::handleConstantAttr(Attribute attr,
185                                                StringRef value) {
186   if (!attr.isConstBuildable())
187     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
188                              " does not have the 'constBuilderCall' field");
189 
190   // TODO(jpienaar): Verify the constants here
191   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
192 }
193 
194 // Helper function to match patterns.
195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
196   Operator &op = tree.getDialectOp(opMap);
197   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
198                           << op.getOperationName() << "' at depth " << depth
199                           << '\n');
200 
201   int indent = 4 + 2 * depth;
202   os.indent(indent) << formatv(
203       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
204       depth, op.getQualCppClassName());
205   // Skip the operand matching at depth 0 as the pattern rewriter already does.
206   if (depth != 0) {
207     // Skip if there is no defining operation (e.g., arguments to function).
208     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
209                                  depth);
210   }
211   if (tree.getNumArgs() != op.getNumArgs()) {
212     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
213                                  "pattern vs. {2} in definition",
214                                  op.getOperationName(), tree.getNumArgs(),
215                                  op.getNumArgs()));
216   }
217 
218   // If the operand's name is set, set to that variable.
219   auto name = tree.getSymbol();
220   if (!name.empty())
221     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
222 
223   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
224     auto opArg = op.getArg(i);
225 
226     // Handle nested DAG construct first
227     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
228       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
229         if (operand->isVariadic()) {
230           auto error = formatv("use nested DAG construct to match op {0}'s "
231                                "variadic operand #{1} unsupported now",
232                                op.getOperationName(), i);
233           PrintFatalError(loc, error);
234         }
235       }
236       os.indent(indent) << "{\n";
237 
238       os.indent(indent + 2) << formatv(
239           "auto *op{0} = "
240           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
241           depth + 1, depth, i);
242       emitOpMatch(argTree, depth + 1);
243       os.indent(indent + 2)
244           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
245       os.indent(indent) << "}\n";
246       continue;
247     }
248 
249     // Next handle DAG leaf: operand or attribute
250     if (opArg.is<NamedTypeConstraint *>()) {
251       emitOperandMatch(tree, i, depth, indent);
252     } else if (opArg.is<NamedAttribute *>()) {
253       emitAttributeMatch(tree, i, depth, indent);
254     } else {
255       PrintFatalError(loc, "unhandled case when matching op");
256     }
257   }
258   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
259                           << op.getOperationName() << "' at depth " << depth
260                           << '\n');
261 }
262 
263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
264                                       int indent) {
265   Operator &op = tree.getDialectOp(opMap);
266   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
267   auto matcher = tree.getArgAsLeaf(index);
268 
269   // If a constraint is specified, we need to generate C++ statements to
270   // check the constraint.
271   if (!matcher.isUnspecified()) {
272     if (!matcher.isOperandMatcher()) {
273       PrintFatalError(
274           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
275                        op.getOperationName(), index + 1));
276     }
277 
278     // Only need to verify if the matcher's type is different from the one
279     // of op definition.
280     if (operand->constraint != matcher.getAsConstraint()) {
281       if (operand->isVariadic()) {
282         auto error = formatv(
283             "further constrain op {0}'s variadic operand #{1} unsupported now",
284             op.getOperationName(), index);
285         PrintFatalError(loc, error);
286       }
287       auto self =
288           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
289                   depth, index);
290       os.indent(indent) << "if (!("
291                         << tgfmt(matcher.getConditionTemplate(),
292                                  &fmtCtx.withSelf(self))
293                         << ")) return matchFailure();\n";
294     }
295   }
296 
297   // Capture the value
298   auto name = tree.getArgName(index);
299   if (!name.empty()) {
300     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
301                                  name, depth, index);
302   }
303 }
304 
305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
306                                         int indent) {
307   Operator &op = tree.getDialectOp(opMap);
308   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
309   const auto &attr = namedAttr->attr;
310 
311   os.indent(indent) << "{\n";
312   indent += 2;
313   os.indent(indent) << formatv(
314       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
315       attr.getStorageType(), namedAttr->name);
316 
317   // TODO(antiagainst): This should use getter method to avoid duplication.
318   if (attr.hasDefaultValueInitializer()) {
319     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
320                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
321                                attr.getDefaultValueInitializer())
322                       << ";\n";
323   } else if (attr.isOptional()) {
324     // For a missing attribute that is optional according to definition, we
325     // should just capture a mlir::Attribute() to signal the missing state.
326     // That is precisely what getAttr() returns on missing attributes.
327   } else {
328     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
329   }
330 
331   auto matcher = tree.getArgAsLeaf(index);
332   if (!matcher.isUnspecified()) {
333     if (!matcher.isAttrMatcher()) {
334       PrintFatalError(
335           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
336                        op.getOperationName(), index + 1));
337     }
338 
339     // If a constraint is specified, we need to generate C++ statements to
340     // check the constraint.
341     os.indent(indent) << "if (!("
342                       << tgfmt(matcher.getConditionTemplate(),
343                                &fmtCtx.withSelf("tblgen_attr"))
344                       << ")) return matchFailure();\n";
345   }
346 
347   // Capture the value
348   auto name = tree.getArgName(index);
349   if (!name.empty()) {
350     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
351   }
352 
353   indent -= 2;
354   os.indent(indent) << "}\n";
355 }
356 
357 void PatternEmitter::emitMatchLogic(DagNode tree) {
358   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
359   emitOpMatch(tree, 0);
360 
361   for (auto &appliedConstraint : pattern.getConstraints()) {
362     auto &constraint = appliedConstraint.constraint;
363     auto &entities = appliedConstraint.entities;
364 
365     auto condition = constraint.getConditionTemplate();
366     auto cmd = "if (!({0})) return matchFailure();\n";
367 
368     if (isa<TypeConstraint>(constraint)) {
369       auto self = formatv("({0}->getType())",
370                           symbolInfoMap.getValueAndRangeUse(entities.front()));
371       os.indent(4) << formatv(cmd,
372                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
373     } else if (isa<AttrConstraint>(constraint)) {
374       PrintFatalError(
375           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
376     } else {
377       // TODO(b/138794486): replace formatv arguments with the exact specified
378       // args.
379       if (entities.size() > 4) {
380         PrintFatalError(loc, "only support up to 4-entity constraints now");
381       }
382       SmallVector<std::string, 4> names;
383       int i = 0;
384       for (int e = entities.size(); i < e; ++i)
385         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
386       std::string self = appliedConstraint.self;
387       if (!self.empty())
388         self = symbolInfoMap.getValueAndRangeUse(self);
389       for (; i < 4; ++i)
390         names.push_back("<unused>");
391       os.indent(4) << formatv(cmd,
392                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
393                                     names[1], names[2], names[3]));
394     }
395   }
396   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
397 }
398 
399 void PatternEmitter::collectOps(DagNode tree,
400                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
401   // Check if this tree is an operation.
402   if (tree.isOperation()) {
403     const Operator &op = tree.getDialectOp(opMap);
404     LLVM_DEBUG(llvm::dbgs()
405                << "found operation " << op.getOperationName() << '\n');
406     ops.insert(&op);
407   }
408 
409   // Recurse the arguments of the tree.
410   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
411     if (auto child = tree.getArgAsNestedDag(i))
412       collectOps(child, ops);
413 }
414 
415 void PatternEmitter::emit(StringRef rewriteName) {
416   // Get the DAG tree for the source pattern.
417   DagNode sourceTree = pattern.getSourcePattern();
418 
419   const Operator &rootOp = pattern.getSourceRootOp();
420   auto rootName = rootOp.getOperationName();
421 
422   // Collect the set of result operations.
423   llvm::SmallPtrSet<const Operator *, 4> resultOps;
424   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
425   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
426     collectOps(pattern.getResultPattern(i), resultOps);
427   }
428   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
429 
430   // Emit RewritePattern for Pattern.
431   auto locs = pattern.getLocation();
432   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
433                 make_range(locs.rbegin(), locs.rend()));
434   os << formatv(R"(struct {0} : public RewritePattern {
435   {0}(MLIRContext *context)
436       : RewritePattern("{1}", {{)",
437                 rewriteName, rootName);
438   // Sort result operators by name.
439   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
440                                                          resultOps.end());
441   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
442     return lhs->getOperationName() < rhs->getOperationName();
443   });
444   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
445     os << '"' << op->getOperationName() << '"';
446   });
447   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
448 
449   // Emit matchAndRewrite() function.
450   os << R"(
451   PatternMatchResult matchAndRewrite(Operation *op0,
452                                      PatternRewriter &rewriter) const override {
453 )";
454 
455   // Register all symbols bound in the source pattern.
456   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
457 
458   LLVM_DEBUG(
459       llvm::dbgs() << "start creating local variables for capturing matches\n");
460   os.indent(4) << "// Variables for capturing values and attributes used for "
461                   "creating ops\n";
462   // Create local variables for storing the arguments and results bound
463   // to symbols.
464   for (const auto &symbolInfoPair : symbolInfoMap) {
465     StringRef symbol = symbolInfoPair.getKey();
466     auto &info = symbolInfoPair.getValue();
467     os.indent(4) << info.getVarDecl(symbol);
468   }
469   // TODO(jpienaar): capture ops with consistent numbering so that it can be
470   // reused for fused loc.
471   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
472                           pattern.getSourcePattern().getNumOps());
473   LLVM_DEBUG(
474       llvm::dbgs() << "done creating local variables for capturing matches\n");
475 
476   os.indent(4) << "// Match\n";
477   os.indent(4) << "tblgen_ops[0] = op0;\n";
478   emitMatchLogic(sourceTree);
479   os << "\n";
480 
481   os.indent(4) << "// Rewrite\n";
482   emitRewriteLogic();
483 
484   os.indent(4) << "return matchSuccess();\n";
485   os << "  };\n";
486   os << "};\n";
487 }
488 
489 void PatternEmitter::emitRewriteLogic() {
490   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
491   const Operator &rootOp = pattern.getSourceRootOp();
492   int numExpectedResults = rootOp.getNumResults();
493   int numResultPatterns = pattern.getNumResultPatterns();
494 
495   // First register all symbols bound to ops generated in result patterns.
496   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
497 
498   // Only the last N static values generated are used to replace the matched
499   // root N-result op. We need to calculate the starting index (of the results
500   // of the matched op) each result pattern is to replace.
501   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
502   // If we don't need to replace any value at all, set the replacement starting
503   // index as the number of result patterns so we skip all of them when trying
504   // to replace the matched op's results.
505   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
506   for (int i = numResultPatterns - 1; i >= 0; --i) {
507     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
508     offsets[i] = offsets[i + 1] - numValues;
509     if (offsets[i] == 0) {
510       if (replStartIndex == -1)
511         replStartIndex = i;
512     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
513       auto error = formatv(
514           "cannot use the same multi-result op '{0}' to generate both "
515           "auxiliary values and values to be used for replacing the matched op",
516           pattern.getResultPattern(i).getSymbol());
517       PrintFatalError(loc, error);
518     }
519   }
520 
521   if (offsets.front() > 0) {
522     const char error[] = "no enough values generated to replace the matched op";
523     PrintFatalError(loc, error);
524   }
525 
526   os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
527   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
528   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
529     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
530   }
531   os << "}); (void)loc;\n";
532 
533   // Process auxiliary result patterns.
534   for (int i = 0; i < replStartIndex; ++i) {
535     DagNode resultTree = pattern.getResultPattern(i);
536     auto val = handleResultPattern(resultTree, offsets[i], 0);
537     // Normal op creation will be streamed to `os` by the above call; but
538     // NativeCodeCall will only be materialized to `os` if it is used. Here
539     // we are handling auxiliary patterns so we want the side effect even if
540     // NativeCodeCall is not replacing matched root op's results.
541     if (resultTree.isNativeCodeCall())
542       os.indent(4) << val << ";\n";
543   }
544 
545   if (numExpectedResults == 0) {
546     assert(replStartIndex >= numResultPatterns &&
547            "invalid auxiliary vs. replacement pattern division!");
548     // No result to replace. Just erase the op.
549     os.indent(4) << "rewriter.eraseOp(op0);\n";
550   } else {
551     // Process replacement result patterns.
552     os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
553     for (int i = replStartIndex; i < numResultPatterns; ++i) {
554       DagNode resultTree = pattern.getResultPattern(i);
555       auto val = handleResultPattern(resultTree, offsets[i], 0);
556       os.indent(4) << "\n";
557       // Resolve each symbol for all range use so that we can loop over them.
558       os << symbolInfoMap.getAllRangeUse(
559           val, "    for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
560     }
561     os.indent(4) << "\n";
562     os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
563   }
564 
565   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
566 }
567 
568 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
569   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
570 }
571 
572 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
573                                                 int resultIndex, int depth) {
574   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
575   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
576   LLVM_DEBUG(llvm::dbgs() << '\n');
577 
578   if (resultTree.isNativeCodeCall()) {
579     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
580     symbolInfoMap.bindValue(symbol);
581     return symbol;
582   }
583 
584   if (resultTree.isReplaceWithValue()) {
585     return handleReplaceWithValue(resultTree);
586   }
587 
588   // Normal op creation.
589   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
590   if (resultTree.getSymbol().empty()) {
591     // This is an op not explicitly bound to a symbol in the rewrite rule.
592     // Register the auto-generated symbol for it.
593     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
594   }
595   return symbol;
596 }
597 
598 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
599   assert(tree.isReplaceWithValue());
600 
601   if (tree.getNumArgs() != 1) {
602     PrintFatalError(
603         loc, "replaceWithValue directive must take exactly one argument");
604   }
605 
606   if (!tree.getSymbol().empty()) {
607     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
608   }
609 
610   return tree.getArgName(0);
611 }
612 
613 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
614                                              StringRef patArgName) {
615   if (leaf.isConstantAttr()) {
616     auto constAttr = leaf.getAsConstantAttr();
617     return handleConstantAttr(constAttr.getAttribute(),
618                               constAttr.getConstantValue());
619   }
620   if (leaf.isEnumAttrCase()) {
621     auto enumCase = leaf.getAsEnumAttrCase();
622     if (enumCase.isStrCase())
623       return handleConstantAttr(enumCase, enumCase.getSymbol());
624     // This is an enum case backed by an IntegerAttr. We need to get its value
625     // to build the constant.
626     std::string val = std::to_string(enumCase.getValue());
627     return handleConstantAttr(enumCase, val);
628   }
629 
630   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
631   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
632   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
633     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
634                             << "' (via symbol ref)\n");
635     return argName;
636   }
637   if (leaf.isNativeCodeCall()) {
638     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
639     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
640                             << "' (via NativeCodeCall)\n");
641     return repl;
642   }
643   PrintFatalError(loc, "unhandled case when rewriting op");
644 }
645 
646 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
647   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
648   LLVM_DEBUG(tree.print(llvm::dbgs()));
649   LLVM_DEBUG(llvm::dbgs() << '\n');
650 
651   auto fmt = tree.getNativeCodeTemplate();
652   // TODO(b/138794486): replace formatv arguments with the exact specified args.
653   SmallVector<std::string, 8> attrs(8);
654   if (tree.getNumArgs() > 8) {
655     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
656                              Twine(tree.getNumArgs()));
657   }
658   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
659     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
660     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
661                             << " replacement: " << attrs[i] << "\n");
662   }
663   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
664                attrs[5], attrs[6], attrs[7]);
665 }
666 
667 int PatternEmitter::getNodeValueCount(DagNode node) {
668   if (node.isOperation()) {
669     // If the op is bound to a symbol in the rewrite rule, query its result
670     // count from the symbol info map.
671     auto symbol = node.getSymbol();
672     if (!symbol.empty()) {
673       return symbolInfoMap.getStaticValueCount(symbol);
674     }
675     // Otherwise this is an unbound op; we will use all its results.
676     return pattern.getDialectOp(node).getNumResults();
677   }
678   // TODO(antiagainst): This considers all NativeCodeCall as returning one
679   // value. Enhance if multi-value ones are needed.
680   return 1;
681 }
682 
683 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
684                                              int depth) {
685   Operator &resultOp = tree.getDialectOp(opMap);
686   auto numOpArgs = resultOp.getNumArgs();
687 
688   if (numOpArgs != tree.getNumArgs()) {
689     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
690                                  "{1} in pattern vs. {2} in definition",
691                                  resultOp.getOperationName(), tree.getNumArgs(),
692                                  numOpArgs));
693   }
694 
695   // A map to collect all nested DAG child nodes' names, with operand index as
696   // the key. This includes both bound and unbound child nodes.
697   llvm::DenseMap<unsigned, std::string> childNodeNames;
698 
699   // First go through all the child nodes who are nested DAG constructs to
700   // create ops for them and remember the symbol names for them, so that we can
701   // use the results in the current node. This happens in a recursive manner.
702   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
703     if (auto child = tree.getArgAsNestedDag(i)) {
704       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
705     }
706   }
707 
708   // The name of the local variable holding this op.
709   std::string valuePackName;
710   // The symbol for holding the result of this pattern. Note that the result of
711   // this pattern is not necessarily the same as the variable created by this
712   // pattern because we can use `__N` suffix to refer only a specific result if
713   // the generated op is a multi-result op.
714   std::string resultValue;
715   if (tree.getSymbol().empty()) {
716     // No symbol is explicitly bound to this op in the pattern. Generate a
717     // unique name.
718     valuePackName = resultValue = getUniqueSymbol(&resultOp);
719   } else {
720     resultValue = tree.getSymbol();
721     // Strip the index to get the name for the value pack and use it to name the
722     // local variable for the op.
723     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
724   }
725 
726   // Create the local variable for this op.
727   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
728                           valuePackName);
729   os.indent(4) << "{\n";
730 
731   // Now prepare operands used for building this op:
732   // * If the operand is non-variadic, we create a `Value*` local variable.
733   // * If the operand is variadic, we create a `SmallVector<Value*>` local
734   //   variable.
735 
736   int argIndex = 0;   // The current index to this op's ODS argument
737   int valueIndex = 0; // An index for uniquing local variable names.
738   for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
739     const auto &operand = resultOp.getOperand(argIndex);
740     std::string varName;
741     if (operand.isVariadic()) {
742       varName = formatv("tblgen_values_{0}", valueIndex++);
743       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
744       std::string range;
745       if (tree.isNestedDagArg(argIndex)) {
746         range = childNodeNames[argIndex];
747       } else {
748         range = tree.getArgName(argIndex);
749       }
750       // Resolve the symbol for all range use so that we have a uniform way of
751       // capturing the values.
752       range = symbolInfoMap.getValueAndRangeUse(range);
753       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
754                               varName);
755     } else {
756       varName = formatv("tblgen_value_{0}", valueIndex++);
757       os.indent(6) << formatv("Value *{0} = ", varName);
758       if (tree.isNestedDagArg(argIndex)) {
759         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
760       } else {
761         DagLeaf leaf = tree.getArgAsLeaf(argIndex);
762         auto symbol =
763             symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
764         if (leaf.isNativeCodeCall()) {
765           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
766         } else {
767           os << symbol;
768         }
769       }
770       os << ";\n";
771     }
772 
773     // Update to use the newly created local variable for building the op later.
774     childNodeNames[argIndex] = varName;
775   }
776 
777   // Then we create the builder call.
778 
779   // Right now we don't have general type inference in MLIR. Except a few
780   // special cases listed below, we need to supply types for all results
781   // when building an op.
782   bool isSameOperandsAndResultType =
783       resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
784   bool isBroadcastable =
785       resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
786   bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
787   bool usePartialResults = valuePackName != resultValue;
788 
789   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
790       usePartialResults || depth > 0 || resultIndex < 0) {
791     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
792                             resultOp.getQualCppClassName());
793   } else {
794     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
795     // generated from the source pattern root op. Then we can use the source
796     // pattern's value types to determine the value type of the generated op
797     // here.
798 
799     // We need to specify the types for all results.
800     int numResults = resultOp.getNumResults();
801     if (numResults != 0) {
802       os.indent(6) << "tblgen_types.clear();\n";
803       for (int i = 0; i < numResults; ++i) {
804         os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
805                                 "tblgen_types.push_back(v->getType());\n",
806                                 resultIndex + i);
807       }
808     }
809 
810     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
811                             resultOp.getQualCppClassName());
812     if (numResults != 0)
813       os.indent(6) << ", tblgen_types";
814   }
815 
816   // Add operands for the builder call.
817   for (int i = 0; i < argIndex; ++i) {
818     const auto &operand = resultOp.getOperand(i);
819     // Start each operand on its own line.
820     (os << ",\n").indent(8);
821     if (!operand.name.empty()) {
822       os << "/*" << operand.name << "=*/";
823     }
824     os << childNodeNames[i];
825     // TODO(jpienaar): verify types
826   }
827 
828   // Add attributes for the builder call.
829   for (; argIndex != numOpArgs; ++argIndex) {
830     // Start each attribute on its own line.
831     (os << ",\n").indent(8);
832     // The argument in the op definition.
833     auto opArgName = resultOp.getArgName(argIndex);
834     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
835       if (!subTree.isNativeCodeCall())
836         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
837                              "for creating attribute");
838       os << formatv("/*{0}=*/{1}", opArgName,
839                     handleReplaceWithNativeCodeCall(subTree));
840     } else {
841       auto leaf = tree.getArgAsLeaf(argIndex);
842       // The argument in the result DAG pattern.
843       auto patArgName = tree.getArgName(argIndex);
844       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
845         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
846         auto argument = resultOp.getArg(argIndex);
847         if (!argument.is<NamedAttribute *>())
848           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
849         if (!patArgName.empty())
850           os << "/*" << patArgName << "=*/";
851       } else {
852         os << "/*" << opArgName << "=*/";
853       }
854       os << handleOpArgument(leaf, patArgName);
855     }
856   }
857   os << "\n      );\n";
858   os.indent(4) << "}\n";
859 
860   return resultValue;
861 }
862 
863 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
864   emitSourceFileHeader("Rewriters", os);
865 
866   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
867   auto numPatterns = patterns.size();
868 
869   // We put the map here because it can be shared among multiple patterns.
870   RecordOperatorMap recordOpMap;
871 
872   std::vector<std::string> rewriterNames;
873   rewriterNames.reserve(numPatterns);
874 
875   std::string baseRewriterName = "GeneratedConvert";
876   int rewriterIndex = 0;
877 
878   for (Record *p : patterns) {
879     std::string name;
880     if (p->isAnonymous()) {
881       // If no name is provided, ensure unique rewriter names simply by
882       // appending unique suffix.
883       name = baseRewriterName + llvm::utostr(rewriterIndex++);
884     } else {
885       name = p->getName();
886     }
887     LLVM_DEBUG(llvm::dbgs()
888                << "=== start generating pattern '" << name << "' ===\n");
889     PatternEmitter(p, &recordOpMap, os).emit(name);
890     LLVM_DEBUG(llvm::dbgs()
891                << "=== done generating pattern '" << name << "' ===\n");
892     rewriterNames.push_back(std::move(name));
893   }
894 
895   // Emit function to add the generated matchers to the pattern list.
896   os << "void populateWithGenerated(MLIRContext *context, "
897      << "OwningRewritePatternList *patterns) {\n";
898   for (const auto &name : rewriterNames) {
899     os << "  patterns->insert<" << name << ">(context);\n";
900   }
901   os << "}\n";
902 }
903 
904 static mlir::GenRegistration
905     genRewriters("gen-rewriters", "Generate pattern rewriters",
906                  [](const RecordKeeper &records, raw_ostream &os) {
907                    emitRewriters(records, os);
908                    return false;
909                  });
910