xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 1358df19ca73165cdbd64d099cb5c7ccfd23e477)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace llvm;
43 using namespace mlir;
44 using namespace mlir::tblgen;
45 
46 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
47 
48 namespace llvm {
49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51                      raw_ostream &os, StringRef style) {
52     os << v.first << ":" << v.second;
53   }
54 };
55 } // end namespace llvm
56 
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 class PatternEmitter {
63 public:
64   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
65 
66   // Emits the mlir::RewritePattern struct named `rewriteName`.
67   void emit(StringRef rewriteName);
68 
69 private:
70   // Emits the code for matching ops.
71   void emitMatchLogic(DagNode tree);
72 
73   // Emits the code for rewriting ops.
74   void emitRewriteLogic();
75 
76   //===--------------------------------------------------------------------===//
77   // Match utilities
78   //===--------------------------------------------------------------------===//
79 
80   // Emits C++ statements for matching the op constrained by the given DAG
81   // `tree`.
82   void emitOpMatch(DagNode tree, int depth);
83 
84   // Emits C++ statements for matching the `index`-th argument of the given DAG
85   // `tree` as an operand.
86   void emitOperandMatch(DagNode tree, int index, int depth, int indent);
87 
88   // Emits C++ statements for matching the `index`-th argument of the given DAG
89   // `tree` as an attribute.
90   void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
91 
92   //===--------------------------------------------------------------------===//
93   // Rewrite utilities
94   //===--------------------------------------------------------------------===//
95 
96   // The entry point for handling a result pattern rooted at `resultTree`. This
97   // method dispatches to concrete handlers according to `resultTree`'s kind and
98   // returns a symbol representing the whole value pack. Callers are expected to
99   // further resolve the symbol according to the specific use case.
100   //
101   // `depth` is the nesting level of `resultTree`; 0 means top-level result
102   // pattern. For top-level result pattern, `resultIndex` indicates which result
103   // of the matched root op this pattern is intended to replace, which can be
104   // used to deduce the result type of the op generated from this result
105   // pattern.
106   std::string handleResultPattern(DagNode resultTree, int resultIndex,
107                                   int depth);
108 
109   // Emits the C++ statement to replace the matched DAG with a value built via
110   // calling native C++ code.
111   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
112 
113   // Returns the C++ expression referencing the old value serving as the
114   // replacement.
115   std::string handleReplaceWithValue(DagNode tree);
116 
117   // Emits the C++ statement to build a new op out of the given DAG `tree` and
118   // returns the variable name that this op is assigned to. If the root op in
119   // DAG `tree` has a specified name, the created op will be assigned to a
120   // variable of the given name. Otherwise, a unique name will be used as the
121   // result value name.
122   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
123 
124   // Returns the C++ expression to construct a constant attribute of the given
125   // `value` for the given attribute kind `attr`.
126   std::string handleConstantAttr(Attribute attr, StringRef value);
127 
128   // Returns the C++ expression to build an argument from the given DAG `leaf`.
129   // `patArgName` is used to bound the argument to the source pattern.
130   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
131 
132   //===--------------------------------------------------------------------===//
133   // General utilities
134   //===--------------------------------------------------------------------===//
135 
136   // Collects all of the operations within the given dag tree.
137   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
138 
139   // Returns a unique symbol for a local variable of the given `op`.
140   std::string getUniqueSymbol(const Operator *op);
141 
142   //===--------------------------------------------------------------------===//
143   // Symbol utilities
144   //===--------------------------------------------------------------------===//
145 
146   // Returns how many static values the given DAG `node` correspond to.
147   int getNodeValueCount(DagNode node);
148 
149 private:
150   // Pattern instantiation location followed by the location of multiclass
151   // prototypes used. This is intended to be used as a whole to
152   // PrintFatalError() on errors.
153   ArrayRef<llvm::SMLoc> loc;
154 
155   // Op's TableGen Record to wrapper object.
156   RecordOperatorMap *opMap;
157 
158   // Handy wrapper for pattern being emitted.
159   Pattern pattern;
160 
161   // Map for all bound symbols' info.
162   SymbolInfoMap symbolInfoMap;
163 
164   // The next unused ID for newly created values.
165   unsigned nextValueId;
166 
167   raw_ostream &os;
168 
169   // Format contexts containing placeholder substitutions.
170   FmtContext fmtCtx;
171 
172   // Number of op processed.
173   int opCounter = 0;
174 };
175 } // end anonymous namespace
176 
177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
178                                raw_ostream &os)
179     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
180       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
181   fmtCtx.withBuilder("rewriter");
182 }
183 
184 std::string PatternEmitter::handleConstantAttr(Attribute attr,
185                                                StringRef value) {
186   if (!attr.isConstBuildable())
187     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
188                              " does not have the 'constBuilderCall' field");
189 
190   // TODO(jpienaar): Verify the constants here
191   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
192 }
193 
194 // Helper function to match patterns.
195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
196   Operator &op = tree.getDialectOp(opMap);
197   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
198                           << op.getOperationName() << "' at depth " << depth
199                           << '\n');
200 
201   int indent = 4 + 2 * depth;
202   os.indent(indent) << formatv(
203       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
204       depth, op.getQualCppClassName());
205   // Skip the operand matching at depth 0 as the pattern rewriter already does.
206   if (depth != 0) {
207     // Skip if there is no defining operation (e.g., arguments to function).
208     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
209                                  depth);
210   }
211   if (tree.getNumArgs() != op.getNumArgs()) {
212     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
213                                  "pattern vs. {2} in definition",
214                                  op.getOperationName(), tree.getNumArgs(),
215                                  op.getNumArgs()));
216   }
217 
218   // If the operand's name is set, set to that variable.
219   auto name = tree.getSymbol();
220   if (!name.empty())
221     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
222 
223   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
224     auto opArg = op.getArg(i);
225 
226     // Handle nested DAG construct first
227     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
228       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
229         if (operand->isVariadic()) {
230           auto error = formatv("use nested DAG construct to match op {0}'s "
231                                "variadic operand #{1} unsupported now",
232                                op.getOperationName(), i);
233           PrintFatalError(loc, error);
234         }
235       }
236       os.indent(indent) << "{\n";
237 
238       os.indent(indent + 2) << formatv(
239           "auto *op{0} = "
240           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
241           depth + 1, depth, i);
242       emitOpMatch(argTree, depth + 1);
243       os.indent(indent + 2)
244           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
245       os.indent(indent) << "}\n";
246       continue;
247     }
248 
249     // Next handle DAG leaf: operand or attribute
250     if (opArg.is<NamedTypeConstraint *>()) {
251       emitOperandMatch(tree, i, depth, indent);
252     } else if (opArg.is<NamedAttribute *>()) {
253       emitAttributeMatch(tree, i, depth, indent);
254     } else {
255       PrintFatalError(loc, "unhandled case when matching op");
256     }
257   }
258   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
259                           << op.getOperationName() << "' at depth " << depth
260                           << '\n');
261 }
262 
263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
264                                       int indent) {
265   Operator &op = tree.getDialectOp(opMap);
266   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
267   auto matcher = tree.getArgAsLeaf(index);
268 
269   // If a constraint is specified, we need to generate C++ statements to
270   // check the constraint.
271   if (!matcher.isUnspecified()) {
272     if (!matcher.isOperandMatcher()) {
273       PrintFatalError(
274           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
275                        op.getOperationName(), index + 1));
276     }
277 
278     // Only need to verify if the matcher's type is different from the one
279     // of op definition.
280     if (operand->constraint != matcher.getAsConstraint()) {
281       if (operand->isVariadic()) {
282         auto error = formatv(
283             "further constrain op {0}'s variadic operand #{1} unsupported now",
284             op.getOperationName(), index);
285         PrintFatalError(loc, error);
286       }
287       auto self =
288           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
289                   depth, index);
290       os.indent(indent) << "if (!("
291                         << tgfmt(matcher.getConditionTemplate(),
292                                  &fmtCtx.withSelf(self))
293                         << ")) return matchFailure();\n";
294     }
295   }
296 
297   // Capture the value
298   auto name = tree.getArgName(index);
299   if (!name.empty()) {
300     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
301                                  name, depth, index);
302   }
303 }
304 
305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
306                                         int indent) {
307   Operator &op = tree.getDialectOp(opMap);
308   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
309   const auto &attr = namedAttr->attr;
310 
311   os.indent(indent) << "{\n";
312   indent += 2;
313   os.indent(indent) << formatv(
314       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
315       attr.getStorageType(), namedAttr->name);
316 
317   // TODO(antiagainst): This should use getter method to avoid duplication.
318   if (attr.hasDefaultValueInitializer()) {
319     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
320                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
321                                attr.getDefaultValueInitializer())
322                       << ";\n";
323   } else if (attr.isOptional()) {
324     // For a missing attribute that is optional according to definition, we
325     // should just capture a mlir::Attribute() to signal the missing state.
326     // That is precisely what getAttr() returns on missing attributes.
327   } else {
328     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
329   }
330 
331   auto matcher = tree.getArgAsLeaf(index);
332   if (!matcher.isUnspecified()) {
333     if (!matcher.isAttrMatcher()) {
334       PrintFatalError(
335           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
336                        op.getOperationName(), index + 1));
337     }
338 
339     // If a constraint is specified, we need to generate C++ statements to
340     // check the constraint.
341     os.indent(indent) << "if (!("
342                       << tgfmt(matcher.getConditionTemplate(),
343                                &fmtCtx.withSelf("tblgen_attr"))
344                       << ")) return matchFailure();\n";
345   }
346 
347   // Capture the value
348   auto name = tree.getArgName(index);
349   if (!name.empty()) {
350     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
351   }
352 
353   indent -= 2;
354   os.indent(indent) << "}\n";
355 }
356 
357 void PatternEmitter::emitMatchLogic(DagNode tree) {
358   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
359   emitOpMatch(tree, 0);
360 
361   for (auto &appliedConstraint : pattern.getConstraints()) {
362     auto &constraint = appliedConstraint.constraint;
363     auto &entities = appliedConstraint.entities;
364 
365     auto condition = constraint.getConditionTemplate();
366     auto cmd = "if (!({0})) return matchFailure();\n";
367 
368     if (isa<TypeConstraint>(constraint)) {
369       auto self = formatv("({0}->getType())",
370                           symbolInfoMap.getValueAndRangeUse(entities.front()));
371       os.indent(4) << formatv(cmd,
372                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
373     } else if (isa<AttrConstraint>(constraint)) {
374       PrintFatalError(
375           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
376     } else {
377       // TODO(b/138794486): replace formatv arguments with the exact specified
378       // args.
379       if (entities.size() > 4) {
380         PrintFatalError(loc, "only support up to 4-entity constraints now");
381       }
382       SmallVector<std::string, 4> names;
383       int i = 0;
384       for (int e = entities.size(); i < e; ++i)
385         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
386       std::string self = appliedConstraint.self;
387       if (!self.empty())
388         self = symbolInfoMap.getValueAndRangeUse(self);
389       for (; i < 4; ++i)
390         names.push_back("<unused>");
391       os.indent(4) << formatv(cmd,
392                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
393                                     names[1], names[2], names[3]));
394     }
395   }
396   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
397 }
398 
399 void PatternEmitter::collectOps(DagNode tree,
400                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
401   // Check if this tree is an operation.
402   if (tree.isOperation()) {
403     const Operator &op = tree.getDialectOp(opMap);
404     LLVM_DEBUG(llvm::dbgs()
405                << "found operation " << op.getOperationName() << '\n');
406     ops.insert(&op);
407   }
408 
409   // Recurse the arguments of the tree.
410   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
411     if (auto child = tree.getArgAsNestedDag(i))
412       collectOps(child, ops);
413 }
414 
415 void PatternEmitter::emit(StringRef rewriteName) {
416   // Get the DAG tree for the source pattern.
417   DagNode sourceTree = pattern.getSourcePattern();
418 
419   const Operator &rootOp = pattern.getSourceRootOp();
420   auto rootName = rootOp.getOperationName();
421 
422   // Collect the set of result operations.
423   llvm::SmallPtrSet<const Operator *, 4> resultOps;
424   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
425   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
426     collectOps(pattern.getResultPattern(i), resultOps);
427   }
428   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
429 
430   // Emit RewritePattern for Pattern.
431   auto locs = pattern.getLocation();
432   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
433                 make_range(locs.rbegin(), locs.rend()));
434   os << formatv(R"(struct {0} : public RewritePattern {
435   {0}(MLIRContext *context)
436       : RewritePattern("{1}", {{)",
437                 rewriteName, rootName);
438   // Sort result operators by name.
439   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
440                                                          resultOps.end());
441   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
442     return lhs->getOperationName() < rhs->getOperationName();
443   });
444   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
445     os << '"' << op->getOperationName() << '"';
446   });
447   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
448 
449   // Emit matchAndRewrite() function.
450   os << R"(
451   PatternMatchResult matchAndRewrite(Operation *op0,
452                                      PatternRewriter &rewriter) const override {
453 )";
454 
455   // Register all symbols bound in the source pattern.
456   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
457 
458   LLVM_DEBUG(
459       llvm::dbgs() << "start creating local variables for capturing matches\n");
460   os.indent(4) << "// Variables for capturing values and attributes used for "
461                   "creating ops\n";
462   // Create local variables for storing the arguments and results bound
463   // to symbols.
464   for (const auto &symbolInfoPair : symbolInfoMap) {
465     StringRef symbol = symbolInfoPair.getKey();
466     auto &info = symbolInfoPair.getValue();
467     os.indent(4) << info.getVarDecl(symbol);
468   }
469   // TODO(jpienaar): capture ops with consistent numbering so that it can be
470   // reused for fused loc.
471   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
472                           pattern.getSourcePattern().getNumOps());
473   LLVM_DEBUG(
474       llvm::dbgs() << "done creating local variables for capturing matches\n");
475 
476   os.indent(4) << "// Match\n";
477   os.indent(4) << "tblgen_ops[0] = op0;\n";
478   emitMatchLogic(sourceTree);
479   os << "\n";
480 
481   os.indent(4) << "// Rewrite\n";
482   emitRewriteLogic();
483 
484   os.indent(4) << "return matchSuccess();\n";
485   os << "  };\n";
486   os << "};\n";
487 }
488 
489 void PatternEmitter::emitRewriteLogic() {
490   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
491   const Operator &rootOp = pattern.getSourceRootOp();
492   int numExpectedResults = rootOp.getNumResults();
493   int numResultPatterns = pattern.getNumResultPatterns();
494 
495   // First register all symbols bound to ops generated in result patterns.
496   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
497 
498   // Only the last N static values generated are used to replace the matched
499   // root N-result op. We need to calculate the starting index (of the results
500   // of the matched op) each result pattern is to replace.
501   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
502   // If we don't need to replace any value at all, set the replacement starting
503   // index as the number of result patterns so we skip all of them when trying
504   // to replace the matched op's results.
505   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
506   for (int i = numResultPatterns - 1; i >= 0; --i) {
507     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
508     offsets[i] = offsets[i + 1] - numValues;
509     if (offsets[i] == 0) {
510       if (replStartIndex == -1)
511         replStartIndex = i;
512     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
513       auto error = formatv(
514           "cannot use the same multi-result op '{0}' to generate both "
515           "auxiliary values and values to be used for replacing the matched op",
516           pattern.getResultPattern(i).getSymbol());
517       PrintFatalError(loc, error);
518     }
519   }
520 
521   if (offsets.front() > 0) {
522     const char error[] = "no enough values generated to replace the matched op";
523     PrintFatalError(loc, error);
524   }
525 
526   os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
527   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
528   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
529     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
530   }
531   os << "}); (void)loc;\n";
532 
533   // Process each result pattern and record the result symbol.
534   llvm::SmallVector<std::string, 2> resultValues;
535   for (int i = 0; i < numResultPatterns; ++i) {
536     DagNode resultTree = pattern.getResultPattern(i);
537     resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
538   }
539 
540   os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
541   // Only use the last portion for replacing the matched root op's results.
542   auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
543   for (const auto &val : range) {
544     os.indent(4) << "\n";
545     // Resolve each symbol for all range use so that we can loop over them.
546     os << symbolInfoMap.getAllRangeUse(
547         val, "    for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
548   }
549   os.indent(4) << "\n";
550   os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
551   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
552 }
553 
554 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
555   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
556 }
557 
558 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
559                                                 int resultIndex, int depth) {
560   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
561   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
562   LLVM_DEBUG(llvm::dbgs() << '\n');
563 
564   if (resultTree.isNativeCodeCall()) {
565     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
566     symbolInfoMap.bindValue(symbol);
567     return symbol;
568   }
569 
570   if (resultTree.isReplaceWithValue()) {
571     return handleReplaceWithValue(resultTree);
572   }
573 
574   // Normal op creation.
575   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
576   if (resultTree.getSymbol().empty()) {
577     // This is an op not explicitly bound to a symbol in the rewrite rule.
578     // Register the auto-generated symbol for it.
579     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
580   }
581   return symbol;
582 }
583 
584 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
585   assert(tree.isReplaceWithValue());
586 
587   if (tree.getNumArgs() != 1) {
588     PrintFatalError(
589         loc, "replaceWithValue directive must take exactly one argument");
590   }
591 
592   if (!tree.getSymbol().empty()) {
593     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
594   }
595 
596   return tree.getArgName(0);
597 }
598 
599 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
600                                              StringRef patArgName) {
601   if (leaf.isConstantAttr()) {
602     auto constAttr = leaf.getAsConstantAttr();
603     return handleConstantAttr(constAttr.getAttribute(),
604                               constAttr.getConstantValue());
605   }
606   if (leaf.isEnumAttrCase()) {
607     auto enumCase = leaf.getAsEnumAttrCase();
608     if (enumCase.isStrCase())
609       return handleConstantAttr(enumCase, enumCase.getSymbol());
610     // This is an enum case backed by an IntegerAttr. We need to get its value
611     // to build the constant.
612     std::string val = std::to_string(enumCase.getValue());
613     return handleConstantAttr(enumCase, val);
614   }
615 
616   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
617   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
618   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
619     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
620                             << "' (via symbol ref)\n");
621     return argName;
622   }
623   if (leaf.isNativeCodeCall()) {
624     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
625     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
626                             << "' (via NativeCodeCall)\n");
627     return repl;
628   }
629   PrintFatalError(loc, "unhandled case when rewriting op");
630 }
631 
632 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
633   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
634   LLVM_DEBUG(tree.print(llvm::dbgs()));
635   LLVM_DEBUG(llvm::dbgs() << '\n');
636 
637   auto fmt = tree.getNativeCodeTemplate();
638   // TODO(b/138794486): replace formatv arguments with the exact specified args.
639   SmallVector<std::string, 8> attrs(8);
640   if (tree.getNumArgs() > 8) {
641     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
642                              Twine(tree.getNumArgs()));
643   }
644   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
645     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
646     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
647                             << " replacement: " << attrs[i] << "\n");
648   }
649   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
650                attrs[5], attrs[6], attrs[7]);
651 }
652 
653 int PatternEmitter::getNodeValueCount(DagNode node) {
654   if (node.isOperation()) {
655     // If the op is bound to a symbol in the rewrite rule, query its result
656     // count from the symbol info map.
657     auto symbol = node.getSymbol();
658     if (!symbol.empty()) {
659       return symbolInfoMap.getStaticValueCount(symbol);
660     }
661     // Otherwise this is an unbound op; we will use all its results.
662     return pattern.getDialectOp(node).getNumResults();
663   }
664   // TODO(antiagainst): This considers all NativeCodeCall as returning one
665   // value. Enhance if multi-value ones are needed.
666   return 1;
667 }
668 
669 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
670                                              int depth) {
671   Operator &resultOp = tree.getDialectOp(opMap);
672   auto numOpArgs = resultOp.getNumArgs();
673 
674   if (numOpArgs != tree.getNumArgs()) {
675     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
676                                  "{1} in pattern vs. {2} in definition",
677                                  resultOp.getOperationName(), tree.getNumArgs(),
678                                  numOpArgs));
679   }
680 
681   // A map to collect all nested DAG child nodes' names, with operand index as
682   // the key. This includes both bound and unbound child nodes.
683   llvm::DenseMap<unsigned, std::string> childNodeNames;
684 
685   // First go through all the child nodes who are nested DAG constructs to
686   // create ops for them and remember the symbol names for them, so that we can
687   // use the results in the current node. This happens in a recursive manner.
688   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
689     if (auto child = tree.getArgAsNestedDag(i)) {
690       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
691     }
692   }
693 
694   // The name of the local variable holding this op.
695   std::string valuePackName;
696   // The symbol for holding the result of this pattern. Note that the result of
697   // this pattern is not necessarily the same as the variable created by this
698   // pattern because we can use `__N` suffix to refer only a specific result if
699   // the generated op is a multi-result op.
700   std::string resultValue;
701   if (tree.getSymbol().empty()) {
702     // No symbol is explicitly bound to this op in the pattern. Generate a
703     // unique name.
704     valuePackName = resultValue = getUniqueSymbol(&resultOp);
705   } else {
706     resultValue = tree.getSymbol();
707     // Strip the index to get the name for the value pack and use it to name the
708     // local variable for the op.
709     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
710   }
711 
712   // Create the local variable for this op.
713   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
714                           valuePackName);
715   os.indent(4) << "{\n";
716 
717   // Now prepare operands used for building this op:
718   // * If the operand is non-variadic, we create a `Value*` local variable.
719   // * If the operand is variadic, we create a `SmallVector<Value*>` local
720   //   variable.
721 
722   int argIndex = 0;   // The current index to this op's ODS argument
723   int valueIndex = 0; // An index for uniquing local variable names.
724   for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
725     const auto &operand = resultOp.getOperand(argIndex);
726     std::string varName;
727     if (operand.isVariadic()) {
728       varName = formatv("tblgen_values_{0}", valueIndex++);
729       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
730       std::string range;
731       if (tree.isNestedDagArg(argIndex)) {
732         range = childNodeNames[argIndex];
733       } else {
734         range = tree.getArgName(argIndex);
735       }
736       // Resolve the symbol for all range use so that we have a uniform way of
737       // capturing the values.
738       range = symbolInfoMap.getValueAndRangeUse(range);
739       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
740                               varName);
741     } else {
742       varName = formatv("tblgen_value_{0}", valueIndex++);
743       os.indent(6) << formatv("Value *{0} = ", varName);
744       if (tree.isNestedDagArg(argIndex)) {
745         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
746       } else {
747         DagLeaf leaf = tree.getArgAsLeaf(argIndex);
748         auto symbol =
749             symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
750         if (leaf.isNativeCodeCall()) {
751           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
752         } else {
753           os << symbol;
754         }
755       }
756       os << ";\n";
757     }
758 
759     // Update to use the newly created local variable for building the op later.
760     childNodeNames[argIndex] = varName;
761   }
762 
763   // Then we create the builder call.
764 
765   // Right now we don't have general type inference in MLIR. Except a few
766   // special cases listed below, we need to supply types for all results
767   // when building an op.
768   bool isSameOperandsAndResultType =
769       resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
770   bool isBroadcastable =
771       resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
772   bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
773   bool usePartialResults = valuePackName != resultValue;
774 
775   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
776       usePartialResults || depth > 0 || resultIndex < 0) {
777     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
778                             resultOp.getQualCppClassName());
779   } else {
780     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
781     // generated from the source pattern root op. Then we can use the source
782     // pattern's value types to determine the value type of the generated op
783     // here.
784 
785     // We need to specify the types for all results.
786     int numResults = resultOp.getNumResults();
787     if (numResults != 0) {
788       os.indent(6) << "tblgen_types.clear();\n";
789       for (int i = 0; i < numResults; ++i) {
790         os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
791                                 "tblgen_types.push_back(v->getType());\n",
792                                 resultIndex + i);
793       }
794     }
795 
796     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
797                             resultOp.getQualCppClassName());
798     if (numResults != 0)
799       os.indent(6) << ", tblgen_types";
800   }
801 
802   // Add operands for the builder call.
803   for (int i = 0; i < argIndex; ++i) {
804     const auto &operand = resultOp.getOperand(i);
805     // Start each operand on its own line.
806     (os << ",\n").indent(8);
807     if (!operand.name.empty()) {
808       os << "/*" << operand.name << "=*/";
809     }
810     os << childNodeNames[i];
811     // TODO(jpienaar): verify types
812   }
813 
814   // Add attributes for the builder call.
815   for (; argIndex != numOpArgs; ++argIndex) {
816     // Start each attribute on its own line.
817     (os << ",\n").indent(8);
818     // The argument in the op definition.
819     auto opArgName = resultOp.getArgName(argIndex);
820     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
821       if (!subTree.isNativeCodeCall())
822         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
823                              "for creating attribute");
824       os << formatv("/*{0}=*/{1}", opArgName,
825                     handleReplaceWithNativeCodeCall(subTree));
826     } else {
827       auto leaf = tree.getArgAsLeaf(argIndex);
828       // The argument in the result DAG pattern.
829       auto patArgName = tree.getArgName(argIndex);
830       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
831         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
832         auto argument = resultOp.getArg(argIndex);
833         if (!argument.is<NamedAttribute *>())
834           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
835         if (!patArgName.empty())
836           os << "/*" << patArgName << "=*/";
837       } else {
838         os << "/*" << opArgName << "=*/";
839       }
840       os << handleOpArgument(leaf, patArgName);
841     }
842   }
843   os << "\n      );\n";
844   os.indent(4) << "}\n";
845 
846   return resultValue;
847 }
848 
849 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
850   emitSourceFileHeader("Rewriters", os);
851 
852   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
853   auto numPatterns = patterns.size();
854 
855   // We put the map here because it can be shared among multiple patterns.
856   RecordOperatorMap recordOpMap;
857 
858   std::vector<std::string> rewriterNames;
859   rewriterNames.reserve(numPatterns);
860 
861   std::string baseRewriterName = "GeneratedConvert";
862   int rewriterIndex = 0;
863 
864   for (Record *p : patterns) {
865     std::string name;
866     if (p->isAnonymous()) {
867       // If no name is provided, ensure unique rewriter names simply by
868       // appending unique suffix.
869       name = baseRewriterName + llvm::utostr(rewriterIndex++);
870     } else {
871       name = p->getName();
872     }
873     LLVM_DEBUG(llvm::dbgs()
874                << "=== start generating pattern '" << name << "' ===\n");
875     PatternEmitter(p, &recordOpMap, os).emit(name);
876     LLVM_DEBUG(llvm::dbgs()
877                << "=== done generating pattern '" << name << "' ===\n");
878     rewriterNames.push_back(std::move(name));
879   }
880 
881   // Emit function to add the generated matchers to the pattern list.
882   os << "void populateWithGenerated(MLIRContext *context, "
883      << "OwningRewritePatternList *patterns) {\n";
884   for (const auto &name : rewriterNames) {
885     os << "  patterns->insert<" << name << ">(context);\n";
886   }
887   os << "}\n";
888 }
889 
890 static mlir::GenRegistration
891     genRewriters("gen-rewriters", "Generate pattern rewriters",
892                  [](const RecordKeeper &records, raw_ostream &os) {
893                    emitRewriters(records, os);
894                    return false;
895                  });
896