xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision c7dab559bae8eeb0b14038aa592f9fb03025bcc1)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
// Returns the C++ expression naming the value bound to `symbol`, where
// `symbol` is an op argument or op name bound in the source pattern.
//
// Arguments and ops bound in the source pattern are grouped into a
// transient `PatternState` struct. This struct can be accessed in both
// `match()` and `rewrite()` via the local variable named `s`, so the result
// is simply `s.<symbol>`.
//
// NOTE(review): the result is a Twine built from a temporary and from
// `symbol`; presumably it must be consumed (streamed or `.str()`-ed) within
// the same full expression as the call -- confirm against the Twine docs.
static Twine getBoundSymbol(const StringRef &symbol) {
  return Twine("s.") + symbol;
}
62 
63 // Gets the dynamic value pack's name by removing the index suffix from
64 // `symbol`. Returns `symbol` itself if it does not contain an index.
65 //
66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic
67 // value pack bound to `name`. `name` is typically the results of an
68 // multi-result op.
69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
70   StringRef name, indexStr;
71   unsigned idx = 0;
72   std::tie(name, indexStr) = symbol.rsplit("__");
73   if (indexStr.consumeInteger(10, idx)) {
74     // The second part is not an index.
75     return symbol;
76   }
77   if (index)
78     *index = idx;
79   return name;
80 }
81 
82 // Formats all values from a dynamic value pack `symbol` according to the given
83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
84 // and `{1}` as a placeholder for the value index, which will be offsetted by
85 // `offset`. The `symbol` value pack has a total of `count` values.
86 //
87 // This extracts one value from the pack if `symbol` contains an index,
88 // otherwise it extracts all values sequentially and returns them as a
89 // comma-separated list.
90 static std::string formatValuePack(const char *fmt, StringRef symbol,
91                                    unsigned count, unsigned offset) {
92   auto getNthValue = [fmt, offset](StringRef results,
93                                    unsigned index) -> std::string {
94     return formatv(fmt, results, index + offset);
95   };
96 
97   unsigned index = 0;
98   StringRef name = getValuePackName(symbol, &index);
99   if (name != symbol) {
100     // The symbol contains an index.
101     return getNthValue(name, index);
102   }
103 
104   // The symbol does not contain an index. Treat the symbol as a whole.
105   SmallVector<std::string, 4> values;
106   values.reserve(count);
107   for (unsigned i = 0; i < count; ++i)
108     values.emplace_back(getNthValue(symbol, i));
109   return llvm::join(values, ", ");
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // PatternSymbolResolver
114 //===----------------------------------------------------------------------===//
115 
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
//               [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to the whole value pack generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th result.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
class PatternSymbolResolver {
public:
  // Creates a resolver over the symbols bound in the source pattern:
  // `srcArgs` maps symbols to op arguments and `srcOperations` maps symbols to
  // ops. Both maps are held by reference, not copied.
  PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                        const StringMap<const Operator *> &srcOperations);

  // Marks the given `symbol` as bound to a value pack with `numValues` values
  // and returns true on success. Returns false if the `symbol` is already
  // bound. Any `__N` index suffix on `symbol` is ignored for this purpose.
  bool add(StringRef symbol, int numValues);

  // Queries the substitution for the given `symbol`. Returns an empty string
  // if the symbol is not found. If the symbol represents a value pack, returns
  // all the values separated by ", ".
  std::string query(StringRef symbol) const;

private:
  // Symbols bound to arguments in source pattern.
  const StringMap<Argument> &sourceArguments;
  // Symbols bound to ops (for their results) in source pattern.
  const StringMap<const Operator *> &sourceOps;
  // Symbols bound to ops (for their results) in result patterns.
  // Key: symbol; value: number of values inside the pack
  StringMap<int> resultOps;
};
} // end anonymous namespace
166 
// Binds the resolver to the source pattern's argument and op symbol tables.
// Both members are references, so the maps passed in must stay alive for as
// long as this resolver is used. The result-op table starts out empty.
PatternSymbolResolver::PatternSymbolResolver(
    const StringMap<Argument> &srcArgs,
    const StringMap<const Operator *> &srcOperations)
    : sourceArguments(srcArgs), sourceOps(srcOperations) {}
171 
172 bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
173   StringRef name = getValuePackName(symbol);
174   return resultOps.try_emplace(name, numValues).second;
175 }
176 
177 std::string PatternSymbolResolver::query(StringRef symbol) const {
178   {
179     StringRef name = getValuePackName(symbol);
180     auto it = resultOps.find(name);
181     if (it != resultOps.end())
182       return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
183                              it->second, /*offset=*/0);
184   }
185   {
186     auto it = sourceArguments.find(symbol);
187     if (it != sourceArguments.end())
188       return getBoundSymbol(symbol).str();
189   }
190   {
191     StringRef name = getValuePackName(symbol);
192     auto it = sourceOps.find(name);
193     if (it != sourceOps.end())
194       return formatValuePack("{0}->getResult({1})",
195                              getBoundSymbol(symbol).str(),
196                              it->second->getNumResults(), /*offset=*/0);
197   }
198   return {};
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // PatternEmitter
203 //===----------------------------------------------------------------------===//
204 
namespace {
// Generates the C++ `RewritePattern` struct for one TableGen pattern record.
// The generated code is written to the output stream given at construction.
class PatternEmitter {
public:
  // Static entry point for emitting the pattern `p` under the name
  // `rewriteName` to `os`.
  // NOTE(review): the definition is outside this view; presumably it builds a
  // PatternEmitter and calls the private emit() -- confirm.
  static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
                   raw_ostream &os);

private:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits the match() method for the source pattern rooted at `tree`.
  void emitMatchMethod(DagNode tree);

  // Collects all of the operations within the given dag tree into `ops`.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Emits the rewrite() method.
  void emitRewriteMethod();

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern, with
  // 0 meaning the root op.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand. The statements are indented by `indent` spaces.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute. The statements are indented by `indent` spaces.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  // Returns a unique name for a value of the given `op`.
  std::string getUniqueValueName(const Operator *op);

  // Entry point for handling a rewrite pattern rooted at `resultTree` and
  // dispatches to concrete handlers. The given tree is the `resultIndex`-th
  // argument of the enclosing DAG.
  std::string handleRewritePattern(DagNode resultTree, int resultIndex,
                                   int depth);

  // Returns the C++ expression that builds the replacement value by calling
  // native C++ code, as described by `resultTree`.
  std::string emitReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Handles the `verifyUnusedValue` directive: emitting C++ statements to check
  // the `index`-th result of the source op is not used.
  void handleVerifyUnusedValue(DagNode tree, int index);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string emitOpCreate(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);

  // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
  // is already bound. Nodes without a symbol are ignored.
  void addSymbol(DagNode node);

  // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
  std::string resolveSymbol(StringRef symbol);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;
  // Op's TableGen Record to wrapper object
  RecordOperatorMap *opMap;
  // Handy wrapper for pattern being emitted
  Pattern pattern;
  // Resolves symbols bound in the source and result patterns. It is
  // constructed from `pattern`, so it must be declared after it.
  PatternSymbolResolver symbolResolver;
  // The next unused ID for newly created values
  unsigned nextValueId;
  // The stream that all generated C++ code is written to.
  raw_ostream &os;

  // Format contexts containing placeholder substitutations for match().
  FmtContext matchCtx;
  // Format contexts containing placeholder substitutations for rewrite().
  FmtContext rewriteCtx;

  // Number of nested ops processed so far by emitOpMatch(); used to index the
  // generated `autogeneratedRewritePatternOps` array.
  int opCounter = 0;
};
} // end anonymous namespace
304 
305 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
306                                raw_ostream &os)
307     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
308       symbolResolver(pattern.getSourcePatternBoundArgs(),
309                      pattern.getSourcePatternBoundOps()),
310       nextValueId(0), os(os) {
311   matchCtx.withBuilder("mlir::Builder(ctx)");
312   rewriteCtx.withBuilder("rewriter");
313 }
314 
315 std::string PatternEmitter::handleConstantAttr(Attribute attr,
316                                                StringRef value) {
317   if (!attr.isConstBuildable())
318     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
319                              " does not have the 'constBuilderCall' field");
320 
321   // TODO(jpienaar): Verify the constants here
322   return tgfmt(attr.getConstBuilderTemplate(),
323                &rewriteCtx.withBuilder("rewriter"), value);
324 }
325 
326 // Helper function to match patterns.
327 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
328   Operator &op = tree.getDialectOp(opMap);
329   if (op.isVariadic()) {
330     PrintFatalError(loc, formatv("matching op '{0}' with variadic "
331                                  "operands/results is unsupported right now",
332                                  op.getOperationName()));
333   }
334 
335   int indent = 4 + 2 * depth;
336   // Skip the operand matching at depth 0 as the pattern rewriter already does.
337   if (depth != 0) {
338     // Skip if there is no defining operation (e.g., arguments to function).
339     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
340     os.indent(indent) << formatv(
341         "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
342         op.getQualCppClassName());
343   }
344   if (tree.getNumArgs() != op.getNumArgs()) {
345     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
346                                  "pattern vs. {2} in definition",
347                                  op.getOperationName(), tree.getNumArgs(),
348                                  op.getNumArgs()));
349   }
350 
351   // If the operand's name is set, set to that variable.
352   auto name = tree.getOpName();
353   if (!name.empty())
354     os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
355 
356   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
357     auto opArg = op.getArg(i);
358 
359     // Handle nested DAG construct first
360     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
361       os.indent(indent) << "{\n";
362       os.indent(indent + 2)
363           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
364                      depth + 1, depth, i);
365       emitOpMatch(argTree, depth + 1);
366       os.indent(indent + 2)
367           << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
368                      ++opCounter, depth + 1);
369       os.indent(indent) << "}\n";
370       continue;
371     }
372 
373     // Next handle DAG leaf: operand or attribute
374     if (opArg.is<NamedTypeConstraint *>()) {
375       emitOperandMatch(tree, i, depth, indent);
376     } else if (opArg.is<NamedAttribute *>()) {
377       emitAttributeMatch(tree, i, depth, indent);
378     } else {
379       PrintFatalError(loc, "unhandled case when matching op");
380     }
381   }
382 }
383 
384 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
385                                       int indent) {
386   Operator &op = tree.getDialectOp(opMap);
387   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
388   auto matcher = tree.getArgAsLeaf(index);
389 
390   // If a constraint is specified, we need to generate C++ statements to
391   // check the constraint.
392   if (!matcher.isUnspecified()) {
393     if (!matcher.isOperandMatcher()) {
394       PrintFatalError(
395           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
396                        op.getOperationName(), index + 1));
397     }
398 
399     // Only need to verify if the matcher's type is different from the one
400     // of op definition.
401     if (operand->constraint != matcher.getAsConstraint()) {
402       auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
403       os.indent(indent) << "if (!("
404                         << tgfmt(matcher.getConditionTemplate(),
405                                  &matchCtx.withSelf(self))
406                         << ")) return matchFailure();\n";
407     }
408   }
409 
410   // Capture the value
411   auto name = tree.getArgName(index);
412   if (!name.empty()) {
413     os.indent(indent) << getBoundSymbol(name) << " = op" << depth
414                       << "->getOperand(" << index << ");\n";
415   }
416 }
417 
418 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
419                                         int indent) {
420   Operator &op = tree.getDialectOp(opMap);
421   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
422   const auto &attr = namedAttr->attr;
423 
424   os.indent(indent) << "{\n";
425   indent += 2;
426   os.indent(indent) << formatv(
427       "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
428       attr.getStorageType(), namedAttr->name);
429 
430   // TODO(antiagainst): This should use getter method to avoid duplication.
431   if (attr.hasDefaultValueInitializer()) {
432     os.indent(indent) << "if (!attr) attr = "
433                       << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
434                                attr.getDefaultValueInitializer())
435                       << ";\n";
436   } else if (attr.isOptional()) {
437     // For a missing attribut that is optional according to definition, we
438     // should just capature a mlir::Attribute() to signal the missing state.
439     // That is precisely what getAttr() returns on missing attributes.
440   } else {
441     os.indent(indent) << "if (!attr) return matchFailure();\n";
442   }
443 
444   auto matcher = tree.getArgAsLeaf(index);
445   if (!matcher.isUnspecified()) {
446     if (!matcher.isAttrMatcher()) {
447       PrintFatalError(
448           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
449                        op.getOperationName(), index + 1));
450     }
451 
452     // If a constraint is specified, we need to generate C++ statements to
453     // check the constraint.
454     os.indent(indent) << "if (!("
455                       << tgfmt(matcher.getConditionTemplate(),
456                                &matchCtx.withSelf("attr"))
457                       << ")) return matchFailure();\n";
458   }
459 
460   // Capture the value
461   auto name = tree.getArgName(index);
462   if (!name.empty()) {
463     os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
464   }
465 
466   indent -= 2;
467   os.indent(indent) << "}\n";
468 }
469 
470 void PatternEmitter::emitMatchMethod(DagNode tree) {
471   // Emit the heading.
472   os << R"(
473   PatternMatchResult match(Operation *op0) const override {
474     auto ctx = op0->getContext(); (void)ctx;
475     auto state = llvm::make_unique<MatchedState>();
476     auto &s = *state;
477     s.autogeneratedRewritePatternOps[0] = op0;
478 )";
479 
480   // The rewrite pattern may specify that certain outputs should be unused in
481   // the source IR. Check it here.
482   for (int i = 0, e = pattern.getNumResults(); i < e; ++i) {
483     DagNode resultTree = pattern.getResultPattern(i);
484     if (resultTree.isVerifyUnusedValue()) {
485       handleVerifyUnusedValue(resultTree, i);
486     }
487   }
488 
489   emitOpMatch(tree, 0);
490 
491   for (auto &appliedConstraint : pattern.getConstraints()) {
492     auto &constraint = appliedConstraint.constraint;
493     auto &entities = appliedConstraint.entities;
494 
495     auto condition = constraint.getConditionTemplate();
496     auto cmd = "if (!({0})) return matchFailure();\n";
497 
498     if (isa<TypeConstraint>(constraint)) {
499       auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
500       os.indent(4) << formatv(cmd,
501                               tgfmt(condition, &matchCtx.withSelf(self.str())));
502     } else if (isa<AttrConstraint>(constraint)) {
503       PrintFatalError(
504           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
505     } else {
506       // TODO(fengliuai): replace formatv arguments with the exact specified
507       // args.
508       if (entities.size() > 4) {
509         PrintFatalError(loc, "only support up to 4-entity constraints now");
510       }
511       SmallVector<std::string, 4> names;
512       int i = 0;
513       for (int e = entities.size(); i < e; ++i)
514         names.push_back(resolveSymbol(entities[i]));
515       for (; i < 4; ++i)
516         names.push_back("<unused>");
517       os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0],
518                                          names[1], names[2], names[3]));
519     }
520   }
521 
522   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
523 }
524 
525 void PatternEmitter::collectOps(DagNode tree,
526                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
527   // Check if this tree is an operation.
528   if (tree.isOperation())
529     ops.insert(&tree.getDialectOp(opMap));
530 
531   // Recurse the arguments of the tree.
532   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
533     if (auto child = tree.getArgAsNestedDag(i))
534       collectOps(child, ops);
535 }
536 
537 void PatternEmitter::emit(StringRef rewriteName) {
538   // Get the DAG tree for the source pattern
539   DagNode tree = pattern.getSourcePattern();
540 
541   const Operator &rootOp = pattern.getSourceRootOp();
542   auto rootName = rootOp.getOperationName();
543 
544   // Collect the set of result operations.
545   llvm::SmallPtrSet<const Operator *, 4> results;
546   for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
547     collectOps(pattern.getResultPattern(i), results);
548 
549   // Emit RewritePattern for Pattern.
550   auto locs = pattern.getLocation();
551   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
552                 make_range(locs.rbegin(), locs.rend()));
553   os << formatv(R"(struct {0} : public RewritePattern {
554   {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
555                 rewriteName, rootName);
556   interleaveComma(results, os, [&](const Operator *op) {
557     os << '"' << op->getOperationName() << '"';
558   });
559   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
560 
561   // Emit matched state.
562   os << "  struct MatchedState : public PatternState {\n";
563   for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
564     auto fieldName = arg.first();
565     if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
566       os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
567                    << ";\n";
568     } else {
569       os.indent(4) << "Value *" << fieldName << ";\n";
570     }
571   }
572   for (const auto &result : pattern.getSourcePatternBoundOps()) {
573     os.indent(4) << "Operation *" << result.getKey() << ";\n";
574   }
575   // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
576   // numbering so that it can be reused for fused loc.
577   os.indent(4) << "Operation* autogeneratedRewritePatternOps["
578                << pattern.getSourcePattern().getNumOps() << "];\n";
579   os << "  };\n";
580 
581   emitMatchMethod(tree);
582   emitRewriteMethod();
583 
584   os << "};\n";
585 }
586 
587 void PatternEmitter::emitRewriteMethod() {
588   const Operator &rootOp = pattern.getSourceRootOp();
589   int numExpectedResults = rootOp.getNumResults();
590   int numProvidedResults = pattern.getNumResults();
591 
592   os << R"(
593   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
594                PatternRewriter &rewriter) const override {
595     auto& s = *static_cast<MatchedState *>(state.get());
596     auto loc = rewriter.getFusedLoc({)";
597   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
598     os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
599        << "]->getLoc()";
600   }
601   os << "}); (void)loc;\n";
602 
603   // Collect the replacement value for each result
604   llvm::SmallVector<std::string, 2> resultValues;
605   for (int i = 0; i < numProvidedResults; ++i) {
606     DagNode resultTree = pattern.getResultPattern(i);
607     resultValues.push_back(handleRewritePattern(resultTree, i, 0));
608     // Keep track of bound symbols at the top-level DAG nodes
609     addSymbol(resultTree);
610   }
611 
612   // Emit the final replaceOp() statement
613   os.indent(4) << "rewriter.replaceOp(op, {";
614   interleave(
615       // We only use the last numExpectedResults ones to replace the root op.
616       ArrayRef<std::string>(resultValues).take_back(numExpectedResults),
617       [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
618   os << "});\n  }\n";
619 }
620 
621 std::string PatternEmitter::getUniqueValueName(const Operator *op) {
622   return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
623 }
624 
625 std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
626                                                  int resultIndex, int depth) {
627   if (resultTree.isNativeCodeCall())
628     return emitReplaceWithNativeCodeCall(resultTree);
629 
630   if (resultTree.isVerifyUnusedValue()) {
631     if (depth > 0) {
632       // TODO: Revisit this when we have use cases of matching an intermediate
633       // multi-result op with no uses of its certain results.
634       PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
635                            "verify top-level result");
636     }
637 
638     if (!resultTree.getOpName().empty()) {
639       PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
640     }
641 
642     // The C++ statements to check that this result value is unused are already
643     // emitted in the match() method. So returning a nullptr here directly
644     // should be safe because the C++ RewritePattern harness will use it to
645     // replace nothing.
646     return "nullptr";
647   }
648 
649   if (resultTree.isReplaceWithValue())
650     return handleReplaceWithValue(resultTree);
651 
652   // Create the op and get the local variable for it.
653   auto results = emitOpCreate(resultTree, resultIndex, depth);
654   // We need to get all the values out of this local variable if we've created a
655   // multi-result op.
656   const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
657   return formatValuePack("{0}.getOperation()->getResult({1})", results,
658                          numResults, /*offset=*/0);
659 }
660 
661 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
662   assert(tree.isReplaceWithValue());
663 
664   if (tree.getNumArgs() != 1) {
665     PrintFatalError(
666         loc, "replaceWithValue directive must take exactly one argument");
667   }
668 
669   if (!tree.getOpName().empty()) {
670     PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
671   }
672 
673   return resolveSymbol(tree.getArgName(0));
674 }
675 
676 void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
677   assert(tree.isVerifyUnusedValue());
678 
679   os.indent(4) << "if (!op0->getResult(" << index
680                << ")->use_empty()) return matchFailure();\n";
681 }
682 
683 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
684                                              llvm::StringRef argName) {
685   if (leaf.isConstantAttr()) {
686     auto constAttr = leaf.getAsConstantAttr();
687     return handleConstantAttr(constAttr.getAttribute(),
688                               constAttr.getConstantValue());
689   }
690   if (leaf.isEnumAttrCase()) {
691     auto enumCase = leaf.getAsEnumAttrCase();
692     if (enumCase.isStrCase())
693       return handleConstantAttr(enumCase, enumCase.getSymbol());
694     // This is an enum case backed by an IntegerAttr. We need to get its value
695     // to build the constant.
696     std::string val = std::to_string(enumCase.getValue());
697     return handleConstantAttr(enumCase, val);
698   }
699   pattern.ensureBoundInSourcePattern(argName);
700   std::string result = getBoundSymbol(argName).str();
701   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
702     return result;
703   }
704   if (leaf.isNativeCodeCall()) {
705     return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result));
706   }
707   PrintFatalError(loc, "unhandled case when rewriting op");
708 }
709 
710 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
711   auto fmt = tree.getNativeCodeTemplate();
712   // TODO(fengliuai): replace formatv arguments with the exact specified args.
713   SmallVector<std::string, 8> attrs(8);
714   if (tree.getNumArgs() > 8) {
715     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
716                              Twine(tree.getNumArgs()));
717   }
718   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
719     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
720   }
721   return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
722                attrs[4], attrs[5], attrs[6], attrs[7]);
723 }
724 
725 void PatternEmitter::addSymbol(DagNode node) {
726   StringRef symbol = node.getOpName();
727   // Skip empty-named symbols, which happen for unbound ops in result patterns.
728   if (symbol.empty())
729     return;
730   if (!symbolResolver.add(symbol, pattern.getDialectOp(node).getNumResults()))
731     PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
732 }
733 
734 std::string PatternEmitter::resolveSymbol(StringRef symbol) {
735   auto subst = symbolResolver.query(symbol);
736   if (subst.empty())
737     PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
738   return subst;
739 }
740 
741 std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
742                                          int depth) {
743   Operator &resultOp = tree.getDialectOp(opMap);
744   auto numOpArgs = resultOp.getNumArgs();
745 
746   if (resultOp.isVariadic()) {
747     PrintFatalError(loc, formatv("generating op '{0}' with variadic "
748                                  "operands/results is unsupported now",
749                                  resultOp.getOperationName()));
750   }
751 
752   if (numOpArgs != tree.getNumArgs()) {
753     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
754                                  "{1} in pattern vs. {2} in definition",
755                                  resultOp.getOperationName(), tree.getNumArgs(),
756                                  numOpArgs));
757   }
758 
759   // A map to collect all nested DAG child nodes' names, with operand index as
760   // the key. This includes both bound and unbound child nodes. Bound child
761   // nodes will additionally be tracked in `symbolResolver` so they can be
762   // referenced by other patterns. Unbound child nodes will only be used once
763   // to build this op.
764   llvm::DenseMap<unsigned, std::string> childNodeNames;
765 
766   // First go through all the child nodes who are nested DAG constructs to
767   // create ops for them, so that we can use the results in the current node.
768   // This happens in a recursive manner.
769   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
770     if (auto child = tree.getArgAsNestedDag(i)) {
771       childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
772       // Keep track of bound symbols at the middle-level DAG nodes
773       addSymbol(child);
774     }
775   }
776 
777   // Use the specified name for this op if available. Generate one otherwise.
778   std::string resultValue = tree.getOpName();
779   if (resultValue.empty())
780     resultValue = getUniqueValueName(&resultOp);
781   // Strip the index to get the name for the value pack. This will be used to
782   // name the local variable for the op.
783   StringRef valuePackName = getValuePackName(resultValue);
784 
785   // Then we build the new op corresponding to this DAG node.
786 
787   // Right now we don't have general type inference in MLIR. Except a few
788   // special cases listed below, we need to supply types for all results
789   // when building an op.
790   bool isSameOperandsAndResultType =
791       resultOp.hasTrait("SameOperandsAndResultType");
792   bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
793   bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
794   bool usePartialResults = valuePackName != resultValue;
795 
796   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
797       usePartialResults || depth > 0) {
798     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
799                             valuePackName, resultOp.getQualCppClassName());
800   } else {
801     // If depth == 0 we can use the equivalence of the source and target root
802     // ops in the pattern to determine the return type.
803 
804     // We need to specify the types for all results.
805     auto resultTypes =
806         formatValuePack("op->getResult({1})->getType()", valuePackName,
807                         resultOp.getNumResults(), resultIndex);
808 
809     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
810                             valuePackName, resultOp.getQualCppClassName())
811                  << (resultTypes.empty() ? "" : ", ") << resultTypes;
812   }
813 
814   // Create the builder call for the result.
815   // Add operands.
816   int i = 0;
817   for (int e = resultOp.getNumOperands(); i < e; ++i) {
818     const auto &operand = resultOp.getOperand(i);
819 
820     // Start each operand on its own line.
821     (os << ",\n").indent(6);
822 
823     if (!operand.name.empty())
824       os << "/*" << operand.name << "=*/";
825 
826     if (tree.isNestedDagArg(i)) {
827       os << childNodeNames[i];
828     } else {
829       DagLeaf leaf = tree.getArgAsLeaf(i);
830       auto symbol = resolveSymbol(tree.getArgName(i));
831       if (leaf.isNativeCodeCall()) {
832         os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
833       } else {
834         os << symbol;
835       }
836     }
837     // TODO(jpienaar): verify types
838   }
839 
840   // Add attributes.
841   for (int e = tree.getNumArgs(); i != e; ++i) {
842     // Start each attribute on its own line.
843     (os << ",\n").indent(6);
844     // The argument in the op definition.
845     auto opArgName = resultOp.getArgName(i);
846     if (auto subTree = tree.getArgAsNestedDag(i)) {
847       if (!subTree.isNativeCodeCall())
848         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
849                              "for creating attribute");
850       os << formatv("/*{0}=*/{1}", opArgName,
851                     emitReplaceWithNativeCodeCall(subTree));
852     } else {
853       auto leaf = tree.getArgAsLeaf(i);
854       // The argument in the result DAG pattern.
855       auto patArgName = tree.getArgName(i);
856       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
857         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
858         auto argument = resultOp.getArg(i);
859         if (!argument.is<NamedAttribute *>())
860           PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
861         if (!patArgName.empty())
862           os << "/*" << patArgName << "=*/";
863       } else {
864         os << "/*" << opArgName << "=*/";
865       }
866       os << handleOpArgument(leaf, patArgName);
867     }
868   }
869   os << "\n    );\n";
870 
871   return resultValue;
872 }
873 
874 void PatternEmitter::emit(StringRef rewriteName, Record *p,
875                           RecordOperatorMap *mapper, raw_ostream &os) {
876   PatternEmitter(p, mapper, os).emit(rewriteName);
877 }
878 
879 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
880   emitSourceFileHeader("Rewriters", os);
881 
882   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
883   auto numPatterns = patterns.size();
884 
885   // We put the map here because it can be shared among multiple patterns.
886   RecordOperatorMap recordOpMap;
887 
888   std::vector<std::string> rewriterNames;
889   rewriterNames.reserve(numPatterns);
890 
891   std::string baseRewriterName = "GeneratedConvert";
892   int rewriterIndex = 0;
893 
894   for (Record *p : patterns) {
895     std::string name;
896     if (p->isAnonymous()) {
897       // If no name is provided, ensure unique rewriter names simply by
898       // appending unique suffix.
899       name = baseRewriterName + llvm::utostr(rewriterIndex++);
900     } else {
901       name = p->getName();
902     }
903     PatternEmitter::emit(name, p, &recordOpMap, os);
904     rewriterNames.push_back(std::move(name));
905   }
906 
907   // Emit function to add the generated matchers to the pattern list.
908   os << "void populateWithGenerated(MLIRContext *context, "
909      << "OwningRewritePatternList *patterns) {\n";
910   for (const auto &name : rewriterNames) {
911     os << "  patterns->push_back(llvm::make_unique<" << name
912        << ">(context));\n";
913   }
914   os << "}\n";
915 }
916 
917 static mlir::GenRegistration
918     genRewriters("gen-rewriters", "Generate pattern rewriters",
919                  [](const RecordKeeper &records, raw_ostream &os) {
920                    emitRewriters(records, os);
921                    return false;
922                  });
923