xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision e032d0dc63578d5786cf9638c2c15ddd41f39274)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
54 // Returns the bound symbol for the given op argument or op named `symbol`.
55 //
56 // Arguments and ops bound in the source pattern are grouped into a
57 // transient `PatternState` struct. This struct can be accessed in both
58 // `match()` and `rewrite()` via the local variable named as `s`.
59 static Twine getBoundSymbol(const StringRef &symbol) {
60   return Twine("s.") + symbol;
61 }
62 
63 // Gets the dynamic value pack's name by removing the index suffix from
64 // `symbol`. Returns `symbol` itself if it does not contain an index.
65 //
66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic
67 // value pack bound to `name`. `name` is typically the results of an
68 // multi-result op.
69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
70   StringRef name, indexStr;
71   unsigned idx = 0;
72   std::tie(name, indexStr) = symbol.rsplit("__");
73   if (indexStr.consumeInteger(10, idx)) {
74     // The second part is not an index.
75     return symbol;
76   }
77   if (index)
78     *index = idx;
79   return name;
80 }
81 
82 // Formats all values from a dynamic value pack `symbol` according to the given
83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
84 // and `{1}` as a placeholder for the value index, which will be offsetted by
85 // `offset`. The `symbol` value pack has a total of `count` values.
86 //
87 // This extracts one value from the pack if `symbol` contains an index,
88 // otherwise it extracts all values sequentially and returns them as a
89 // comma-separated list.
90 static std::string formatValuePack(const char *fmt, StringRef symbol,
91                                    unsigned count, unsigned offset) {
92   auto getNthValue = [fmt, offset](StringRef results,
93                                    unsigned index) -> std::string {
94     return formatv(fmt, results, index + offset);
95   };
96 
97   unsigned index = 0;
98   StringRef name = getValuePackName(symbol, &index);
99   if (name != symbol) {
100     // The symbol contains an index.
101     return getNthValue(name, index);
102   }
103 
104   // The symbol does not contain an index. Treat the symbol as a whole.
105   SmallVector<std::string, 4> values;
106   values.reserve(count);
107   for (unsigned i = 0; i < count; ++i)
108     values.emplace_back(getNthValue(symbol, i));
109   return llvm::join(values, ", ");
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // PatternSymbolResolver
114 //===----------------------------------------------------------------------===//
115 
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
//               [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to the whole value pack generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th result.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
class PatternSymbolResolver {
public:
  // Creates a resolver over the symbols bound in the source pattern. Both maps
  // are stored by reference, so they must outlive this object.
  PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                        const StringMap<const Operator *> &srcOperations);

  // Marks the given `symbol` as bound to a value pack with `numValues` and
  // returns true on success. Returns false if the `symbol` is already bound.
  // Any `__N` index suffix on `symbol` is stripped before registration.
  bool add(StringRef symbol, int numValues);

  // Queries the substitution for the given `symbol`. Returns an empty string
  // if the symbol is not found. If the symbol represents a value pack, returns
  // all the values separated by commas.
  std::string query(StringRef symbol) const;

  // Returns how many static values the given `symbol` corresponds to. Returns
  // a negative value if the given symbol is not bound.
  //
  // Normally a symbol would correspond to just one value; for symbols bound to
  // multi-result ops, it can be more than one.
  int getValueCount(StringRef symbol) const;

private:
  // Symbols bound to arguments in source pattern.
  const StringMap<Argument> &sourceArguments;
  // Symbols bound to ops (for their results) in source pattern.
  const StringMap<const Operator *> &sourceOps;
  // Symbols bound to ops (for their results) in result patterns.
  // Key: symbol; value: number of values inside the pack
  StringMap<int> resultOps;
};
} // end anonymous namespace
173 
// Binds the resolver to the source pattern's argument and op symbol tables.
// The maps are captured by reference rather than copied; the table of symbols
// bound in result patterns starts out empty and is filled through add().
PatternSymbolResolver::PatternSymbolResolver(
    const StringMap<Argument> &srcArgs,
    const StringMap<const Operator *> &srcOperations)
    : sourceArguments(srcArgs), sourceOps(srcOperations) {}
178 
179 bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
180   StringRef name = getValuePackName(symbol);
181   return resultOps.try_emplace(name, numValues).second;
182 }
183 
184 std::string PatternSymbolResolver::query(StringRef symbol) const {
185   StringRef name = getValuePackName(symbol);
186   // Handle symbols bound to generated ops
187   auto resOpIt = resultOps.find(name);
188   if (resOpIt != resultOps.end())
189     return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
190                            resOpIt->second, /*offset=*/0);
191 
192   // Handle symbols bound to matched op arguments
193   auto srcArgIt = sourceArguments.find(symbol);
194   if (srcArgIt != sourceArguments.end())
195     return getBoundSymbol(symbol).str();
196 
197   // Handle symbols bound to matched op results
198   auto srcOpIt = sourceOps.find(name);
199   if (srcOpIt != sourceOps.end())
200     return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(),
201                            srcOpIt->second->getNumResults(), /*offset=*/0);
202   return {};
203 }
204 
205 int PatternSymbolResolver::getValueCount(StringRef symbol) const {
206   StringRef name = getValuePackName(symbol);
207   // Handle symbols bound to generated ops
208   auto resOpIt = resultOps.find(name);
209   if (resOpIt != resultOps.end())
210     return name == symbol ? resOpIt->second : 1;
211 
212   // Handle symbols bound to matched op arguments
213   if (sourceArguments.count(symbol))
214     return 1;
215 
216   // Handle symbols bound to matched op results
217   auto srcOpIt = sourceOps.find(name);
218   if (srcOpIt != sourceOps.end())
219     return name == symbol ? srcOpIt->second->getNumResults() : 1;
220   return -1;
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // PatternEmitter
225 //===----------------------------------------------------------------------===//
226 
namespace {
// Generates the C++ `RewritePattern` struct (its matched-state struct,
// `match()` and `rewrite()` methods) for a single TableGen pattern record.
class PatternEmitter {
public:
  // Static entry point taking the pattern record `p`, the record-to-operator
  // map and the output stream.
  // NOTE(review): presumably constructs an emitter for `p` and forwards to the
  // private emit(rewriteName) -- confirm against the definition.
  static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
                   raw_ostream &os);

private:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits the match() method.
  void emitMatchMethod(DagNode tree);

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Emits the rewrite() method.
  void emitRewriteMethod();

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  // Returns a unique name for a value of the given `op`.
  std::string getUniqueValueName(const Operator *op);

  // Entry point for handling a rewrite pattern rooted at `resultTree` and
  // dispatches to concrete handlers. The given tree is the `resultIndex`-th
  // argument of the enclosing DAG.
  std::string handleRewritePattern(DagNode resultTree, int resultIndex,
                                   int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string emitReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Handles the `verifyUnusedValue` directive: emitting C++ statements to check
  // the `index`-th result of the source op is not used.
  void handleVerifyUnusedValue(DagNode tree, int index);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string emitOpCreate(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the name under which the argument is bound in the source
  // pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  // Marks `symbol` as bound to a pack of `numValues` values. Aborts if the
  // symbol is already bound.
  void addSymbol(StringRef symbol, int numValues);

  // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
  std::string resolveSymbol(StringRef symbol);

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // NOTE: the declaration order below is the initialization order. `pattern`
  // must stay ahead of `symbolResolver`, which is constructed from it.

  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;
  // Op's TableGen Record to wrapper object
  RecordOperatorMap *opMap;
  // Handy wrapper for pattern being emitted
  Pattern pattern;
  // Resolves symbols bound in the source and result patterns.
  PatternSymbolResolver symbolResolver;
  // The next unused ID for newly created values
  unsigned nextValueId;
  // Stream receiving the generated C++ code.
  raw_ostream &os;

  // Format contexts containing placeholder substitutations for match().
  FmtContext matchCtx;
  // Format contexts containing placeholder substitutations for rewrite().
  FmtContext rewriteCtx;

  // Number of nested ops processed so far while emitting match(); used as the
  // slot index into the generated `autogeneratedRewritePatternOps` array.
  int opCounter = 0;
};
} // end anonymous namespace
329 
// Sets up an emitter for the pattern record `pat`, writing generated code to
// `os`.
//
// `symbolResolver` is initialized from `pattern`'s bound-argument and bound-op
// tables and keeps references to the maps it is given, so `pattern` has to be
// initialized first (members are initialized in declaration order).
// NOTE(review): `loc` is a non-owning ArrayRef over whatever `pat->getLoc()`
// returns -- presumably storage owned by the record; confirm it outlives this
// object.
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      symbolResolver(pattern.getSourcePatternBoundArgs(),
                     pattern.getSourcePatternBoundOps()),
      nextValueId(0), os(os) {
  // Code emitted into match() builds attributes through a fresh Builder on the
  // op's context; code emitted into rewrite() uses the pattern rewriter.
  matchCtx.withBuilder("mlir::Builder(ctx)");
  rewriteCtx.withBuilder("rewriter");
}
339 
340 std::string PatternEmitter::handleConstantAttr(Attribute attr,
341                                                StringRef value) {
342   if (!attr.isConstBuildable())
343     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
344                              " does not have the 'constBuilderCall' field");
345 
346   // TODO(jpienaar): Verify the constants here
347   return tgfmt(attr.getConstBuilderTemplate(),
348                &rewriteCtx.withBuilder("rewriter"), value);
349 }
350 
351 // Helper function to match patterns.
352 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
353   Operator &op = tree.getDialectOp(opMap);
354   if (op.isVariadic()) {
355     PrintFatalError(loc, formatv("matching op '{0}' with variadic "
356                                  "operands/results is unsupported right now",
357                                  op.getOperationName()));
358   }
359 
360   int indent = 4 + 2 * depth;
361   // Skip the operand matching at depth 0 as the pattern rewriter already does.
362   if (depth != 0) {
363     // Skip if there is no defining operation (e.g., arguments to function).
364     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
365     os.indent(indent) << formatv(
366         "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
367         op.getQualCppClassName());
368   }
369   if (tree.getNumArgs() != op.getNumArgs()) {
370     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
371                                  "pattern vs. {2} in definition",
372                                  op.getOperationName(), tree.getNumArgs(),
373                                  op.getNumArgs()));
374   }
375 
376   // If the operand's name is set, set to that variable.
377   auto name = tree.getSymbol();
378   if (!name.empty())
379     os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
380 
381   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
382     auto opArg = op.getArg(i);
383 
384     // Handle nested DAG construct first
385     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
386       os.indent(indent) << "{\n";
387       os.indent(indent + 2)
388           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
389                      depth + 1, depth, i);
390       emitOpMatch(argTree, depth + 1);
391       os.indent(indent + 2)
392           << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
393                      ++opCounter, depth + 1);
394       os.indent(indent) << "}\n";
395       continue;
396     }
397 
398     // Next handle DAG leaf: operand or attribute
399     if (opArg.is<NamedTypeConstraint *>()) {
400       emitOperandMatch(tree, i, depth, indent);
401     } else if (opArg.is<NamedAttribute *>()) {
402       emitAttributeMatch(tree, i, depth, indent);
403     } else {
404       PrintFatalError(loc, "unhandled case when matching op");
405     }
406   }
407 }
408 
409 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
410                                       int indent) {
411   Operator &op = tree.getDialectOp(opMap);
412   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
413   auto matcher = tree.getArgAsLeaf(index);
414 
415   // If a constraint is specified, we need to generate C++ statements to
416   // check the constraint.
417   if (!matcher.isUnspecified()) {
418     if (!matcher.isOperandMatcher()) {
419       PrintFatalError(
420           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
421                        op.getOperationName(), index + 1));
422     }
423 
424     // Only need to verify if the matcher's type is different from the one
425     // of op definition.
426     if (operand->constraint != matcher.getAsConstraint()) {
427       auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
428       os.indent(indent) << "if (!("
429                         << tgfmt(matcher.getConditionTemplate(),
430                                  &matchCtx.withSelf(self))
431                         << ")) return matchFailure();\n";
432     }
433   }
434 
435   // Capture the value
436   auto name = tree.getArgName(index);
437   if (!name.empty()) {
438     os.indent(indent) << getBoundSymbol(name) << " = op" << depth
439                       << "->getOperand(" << index << ");\n";
440   }
441 }
442 
443 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
444                                         int indent) {
445   Operator &op = tree.getDialectOp(opMap);
446   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
447   const auto &attr = namedAttr->attr;
448 
449   os.indent(indent) << "{\n";
450   indent += 2;
451   os.indent(indent) << formatv(
452       "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
453       attr.getStorageType(), namedAttr->name);
454 
455   // TODO(antiagainst): This should use getter method to avoid duplication.
456   if (attr.hasDefaultValueInitializer()) {
457     os.indent(indent) << "if (!attr) attr = "
458                       << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
459                                attr.getDefaultValueInitializer())
460                       << ";\n";
461   } else if (attr.isOptional()) {
462     // For a missing attribut that is optional according to definition, we
463     // should just capature a mlir::Attribute() to signal the missing state.
464     // That is precisely what getAttr() returns on missing attributes.
465   } else {
466     os.indent(indent) << "if (!attr) return matchFailure();\n";
467   }
468 
469   auto matcher = tree.getArgAsLeaf(index);
470   if (!matcher.isUnspecified()) {
471     if (!matcher.isAttrMatcher()) {
472       PrintFatalError(
473           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
474                        op.getOperationName(), index + 1));
475     }
476 
477     // If a constraint is specified, we need to generate C++ statements to
478     // check the constraint.
479     os.indent(indent) << "if (!("
480                       << tgfmt(matcher.getConditionTemplate(),
481                                &matchCtx.withSelf("attr"))
482                       << ")) return matchFailure();\n";
483   }
484 
485   // Capture the value
486   auto name = tree.getArgName(index);
487   if (!name.empty()) {
488     os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
489   }
490 
491   indent -= 2;
492   os.indent(indent) << "}\n";
493 }
494 
495 void PatternEmitter::emitMatchMethod(DagNode tree) {
496   // Emit the heading.
497   os << R"(
498   PatternMatchResult match(Operation *op0) const override {
499     auto ctx = op0->getContext(); (void)ctx;
500     auto state = llvm::make_unique<MatchedState>();
501     auto &s = *state;
502     s.autogeneratedRewritePatternOps[0] = op0;
503 )";
504 
505   // The rewrite pattern may specify that certain outputs should be unused in
506   // the source IR. Check it here.
507   for (int i = 0, e = pattern.getNumResultPatterns(); i < e; ++i) {
508     DagNode resultTree = pattern.getResultPattern(i);
509     if (resultTree.isVerifyUnusedValue()) {
510       handleVerifyUnusedValue(resultTree, i);
511     }
512   }
513 
514   emitOpMatch(tree, 0);
515 
516   for (auto &appliedConstraint : pattern.getConstraints()) {
517     auto &constraint = appliedConstraint.constraint;
518     auto &entities = appliedConstraint.entities;
519 
520     auto condition = constraint.getConditionTemplate();
521     auto cmd = "if (!({0})) return matchFailure();\n";
522 
523     if (isa<TypeConstraint>(constraint)) {
524       auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
525       os.indent(4) << formatv(cmd,
526                               tgfmt(condition, &matchCtx.withSelf(self.str())));
527     } else if (isa<AttrConstraint>(constraint)) {
528       PrintFatalError(
529           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
530     } else {
531       // TODO(fengliuai): replace formatv arguments with the exact specified
532       // args.
533       if (entities.size() > 4) {
534         PrintFatalError(loc, "only support up to 4-entity constraints now");
535       }
536       SmallVector<std::string, 4> names;
537       int i = 0;
538       for (int e = entities.size(); i < e; ++i)
539         names.push_back(resolveSymbol(entities[i]));
540       for (; i < 4; ++i)
541         names.push_back("<unused>");
542       os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0],
543                                          names[1], names[2], names[3]));
544     }
545   }
546 
547   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
548 }
549 
550 void PatternEmitter::collectOps(DagNode tree,
551                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
552   // Check if this tree is an operation.
553   if (tree.isOperation())
554     ops.insert(&tree.getDialectOp(opMap));
555 
556   // Recurse the arguments of the tree.
557   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
558     if (auto child = tree.getArgAsNestedDag(i))
559       collectOps(child, ops);
560 }
561 
562 void PatternEmitter::emit(StringRef rewriteName) {
563   // Get the DAG tree for the source pattern
564   DagNode tree = pattern.getSourcePattern();
565 
566   const Operator &rootOp = pattern.getSourceRootOp();
567   auto rootName = rootOp.getOperationName();
568 
569   // Collect the set of result operations.
570   llvm::SmallPtrSet<const Operator *, 4> results;
571   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
572     collectOps(pattern.getResultPattern(i), results);
573 
574   // Emit RewritePattern for Pattern.
575   auto locs = pattern.getLocation();
576   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
577                 make_range(locs.rbegin(), locs.rend()));
578   os << formatv(R"(struct {0} : public RewritePattern {
579   {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
580                 rewriteName, rootName);
581   interleaveComma(results, os, [&](const Operator *op) {
582     os << '"' << op->getOperationName() << '"';
583   });
584   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
585 
586   // Emit matched state.
587   os << "  struct MatchedState : public PatternState {\n";
588   for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
589     auto fieldName = arg.first();
590     if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
591       os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
592                    << ";\n";
593     } else {
594       os.indent(4) << "Value *" << fieldName << ";\n";
595     }
596   }
597   for (const auto &result : pattern.getSourcePatternBoundOps()) {
598     os.indent(4) << "Operation *" << result.getKey() << ";\n";
599   }
600   // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
601   // numbering so that it can be reused for fused loc.
602   os.indent(4) << "Operation* autogeneratedRewritePatternOps["
603                << pattern.getSourcePattern().getNumOps() << "];\n";
604   os << "  };\n";
605 
606   emitMatchMethod(tree);
607   emitRewriteMethod();
608 
609   os << "};\n";
610 }
611 
612 void PatternEmitter::emitRewriteMethod() {
613   const Operator &rootOp = pattern.getSourceRootOp();
614   int numExpectedResults = rootOp.getNumResults();
615   int numResultPatterns = pattern.getNumResultPatterns();
616 
617   // First register all symbols bound to ops generated in result patterns.
618   for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
619     addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
620   }
621 
622   // Only the last N static values generated are used to replace the matched
623   // root N-result op. We need to calculate the starting index (of the results
624   // of the matched op) each result pattern is to replace.
625   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
626   int replStartIndex = -1;
627   for (int i = numResultPatterns - 1; i >= 0; --i) {
628     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
629     offsets[i] = offsets[i + 1] - numValues;
630     if (offsets[i] == 0) {
631       replStartIndex = i;
632     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
633       auto error = formatv(
634           "cannot use the same multi-result op '{0}' to generate both "
635           "auxiliary values and values to be used for replacing the matched op",
636           pattern.getResultPattern(i).getSymbol());
637       PrintFatalError(loc, error);
638     }
639   }
640 
641   if (offsets.front() > 0) {
642     const char error[] = "no enough values generated to replace the matched op";
643     PrintFatalError(loc, error);
644   }
645 
646   os << R"(
647   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
648                PatternRewriter &rewriter) const override {
649     auto& s = *static_cast<MatchedState *>(state.get());
650     auto loc = rewriter.getFusedLoc({)";
651   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
652     os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
653        << "]->getLoc()";
654   }
655   os << "}); (void)loc;\n";
656 
657   // Collect the replacement value for each result
658   llvm::SmallVector<std::string, 2> resultValues;
659   for (int i = 0; i < numResultPatterns; ++i) {
660     DagNode resultTree = pattern.getResultPattern(i);
661     resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0));
662   }
663 
664   // Emit the final replaceOp() statement
665   os.indent(4) << "rewriter.replaceOp(op, {";
666   interleave(
667       ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
668       [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
669   os << "});\n  }\n";
670 }
671 
672 std::string PatternEmitter::getUniqueValueName(const Operator *op) {
673   return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
674 }
675 
676 std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
677                                                  int resultIndex, int depth) {
678   if (resultTree.isNativeCodeCall())
679     return emitReplaceWithNativeCodeCall(resultTree);
680 
681   if (resultTree.isVerifyUnusedValue()) {
682     if (depth > 0) {
683       // TODO: Revisit this when we have use cases of matching an intermediate
684       // multi-result op with no uses of its certain results.
685       PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
686                            "verify top-level result");
687     }
688 
689     if (!resultTree.getSymbol().empty()) {
690       PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
691     }
692 
693     // The C++ statements to check that this result value is unused are already
694     // emitted in the match() method. So returning a nullptr here directly
695     // should be safe because the C++ RewritePattern harness will use it to
696     // replace nothing.
697     return "nullptr";
698   }
699 
700   if (resultTree.isReplaceWithValue())
701     return handleReplaceWithValue(resultTree);
702 
703   // Create the op and get the local variable for it.
704   auto results = emitOpCreate(resultTree, resultIndex, depth);
705   // We need to get all the values out of this local variable if we've created a
706   // multi-result op.
707   const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
708   return formatValuePack("{0}.getOperation()->getResult({1})", results,
709                          numResults, /*offset=*/0);
710 }
711 
712 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
713   assert(tree.isReplaceWithValue());
714 
715   if (tree.getNumArgs() != 1) {
716     PrintFatalError(
717         loc, "replaceWithValue directive must take exactly one argument");
718   }
719 
720   if (!tree.getSymbol().empty()) {
721     PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
722   }
723 
724   return resolveSymbol(tree.getArgName(0));
725 }
726 
727 void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
728   assert(tree.isVerifyUnusedValue());
729 
730   os.indent(4) << "if (!op0->getResult(" << index
731                << ")->use_empty()) return matchFailure();\n";
732 }
733 
734 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
735   if (leaf.isConstantAttr()) {
736     auto constAttr = leaf.getAsConstantAttr();
737     return handleConstantAttr(constAttr.getAttribute(),
738                               constAttr.getConstantValue());
739   }
740   if (leaf.isEnumAttrCase()) {
741     auto enumCase = leaf.getAsEnumAttrCase();
742     if (enumCase.isStrCase())
743       return handleConstantAttr(enumCase, enumCase.getSymbol());
744     // This is an enum case backed by an IntegerAttr. We need to get its value
745     // to build the constant.
746     std::string val = std::to_string(enumCase.getValue());
747     return handleConstantAttr(enumCase, val);
748   }
749   pattern.ensureBoundInSourcePattern(argName);
750   std::string result = getBoundSymbol(argName).str();
751   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
752     return result;
753   }
754   if (leaf.isNativeCodeCall()) {
755     return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result));
756   }
757   PrintFatalError(loc, "unhandled case when rewriting op");
758 }
759 
760 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
761   auto fmt = tree.getNativeCodeTemplate();
762   // TODO(fengliuai): replace formatv arguments with the exact specified args.
763   SmallVector<std::string, 8> attrs(8);
764   if (tree.getNumArgs() > 8) {
765     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
766                              Twine(tree.getNumArgs()));
767   }
768   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
769     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
770   }
771   return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
772                attrs[4], attrs[5], attrs[6], attrs[7]);
773 }
774 
775 void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
776   if (!symbolResolver.add(symbol, numValues))
777     PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
778 }
779 
780 std::string PatternEmitter::resolveSymbol(StringRef symbol) {
781   auto subst = symbolResolver.query(symbol);
782   if (subst.empty())
783     PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
784   return subst;
785 }
786 
787 int PatternEmitter::getNodeValueCount(DagNode node) {
788   if (node.isOperation()) {
789     // First to see whether this op is bound and we just want a specific result
790     // of it with `__N` suffix in symbol.
791     int count = symbolResolver.getValueCount(node.getSymbol());
792     if (count >= 0)
793       return count;
794 
795     // No symbol. Then we are using all the results.
796     return pattern.getDialectOp(node).getNumResults();
797   }
798   // TODO(antiagainst): This considers all NativeCodeCall as returning one
799   // value. Enhance if multi-value ones are needed.
800   return 1;
801 }
802 
// Emits the `rewriter.create<Op>(loc, ...)` statement that builds the op
// described by the result-pattern node `tree`, and returns the name under
// which the created value(s) can be referenced (the node's bound symbol if it
// has one, otherwise a freshly generated unique name).
//
// Nested DAG operands are emitted first (recursively, with `depth + 1`) so
// that their names can be used as operands of this op. `resultIndex` and
// `depth` decide whether result types are passed explicitly to the builder:
// only for a top-level node (`depth == 0`) with `resultIndex >= 0` and none of
// the special traits checked below are the types taken from the results of
// `op`.
//
// Fatal errors are raised for variadic ops, for a mismatch between the number
// of arguments in the pattern and in the op definition, for non-NativeCodeCall
// nested DAG nodes used as attributes, and for constant/enum leaves that do
// not correspond to an attribute argument.
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
                                         int depth) {
  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();

  if (resultOp.isVariadic()) {
    PrintFatalError(loc, formatv("generating op '{0}' with variadic "
                                 "operands/results is unsupported now",
                                 resultOp.getOperationName()));
  }

  // The pattern must supply exactly as many arguments as the op definition
  // declares.
  if (numOpArgs != tree.getNumArgs()) {
    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
                                 "{1} in pattern vs. {2} in definition",
                                 resultOp.getOperationName(), tree.getNumArgs(),
                                 numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes. Bound child
  // nodes will additionally be tracked in `symbolResolver` so they can be
  // referenced by other patterns. Unbound child nodes will only be used once
  // to build this op.
  llvm::DenseMap<unsigned, std::string> childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them, so that we can use the results in the current node.
  // This happens in a recursive manner.
  for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i)) {
      childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
    }
  }

  // Use the specified name for this op if available. Generate one otherwise.
  std::string resultValue = tree.getSymbol();
  if (resultValue.empty())
    resultValue = getUniqueValueName(&resultOp);
  // Strip the index to get the name for the value pack. This will be used to
  // name the local variable for the op.
  StringRef valuePackName = getValuePackName(resultValue);

  // Then we build the new op corresponding to this DAG node.

  // Right now we don't have general type inference in MLIR. Except a few
  // special cases listed below, we need to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.hasTrait("SameOperandsAndResultType");
  bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
  bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
  // The value pack name differs from the result name only when stripping
  // changed it, i.e. the symbol refers to part of the op's results.
  bool usePartialResults = valuePackName != resultValue;

  if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
      usePartialResults || depth > 0 || resultIndex < 0) {
    // No explicit result types are emitted in this case.
    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
                            valuePackName, resultOp.getQualCppClassName());
  } else {
    // If depth == 0 and resultIndex >= 0, it means we are replacing the values
    // generated from the source pattern root op. Then we can use the source
    // pattern's value types to determine the value type of the generated op
    // here.

    // We need to specify the types for all results.
    auto resultTypes =
        formatValuePack("op->getResult({1})->getType()", valuePackName,
                        resultOp.getNumResults(), resultIndex);

    // Only emit the separating comma when there are result types to list.
    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
                            valuePackName, resultOp.getQualCppClassName())
                 << (resultTypes.empty() ? "" : ", ") << resultTypes;
  }

  // Create the builder call for the result.
  // Add operands.
  // `i` is declared outside the loop on purpose: the attribute loop below
  // continues from the index where the operand loop stops.
  int i = 0;
  for (int e = resultOp.getNumOperands(); i < e; ++i) {
    const auto &operand = resultOp.getOperand(i);

    // Start each operand on its own line.
    (os << ",\n").indent(6);

    if (!operand.name.empty())
      os << "/*" << operand.name << "=*/";

    if (tree.isNestedDagArg(i)) {
      // Use the name recorded when the nested op was emitted above.
      os << childNodeNames[i];
    } else {
      DagLeaf leaf = tree.getArgAsLeaf(i);
      auto symbol = resolveSymbol(tree.getArgName(i));
      if (leaf.isNativeCodeCall()) {
        os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
      } else {
        os << symbol;
      }
    }
    // TODO(jpienaar): verify types
  }

  // Add attributes.
  // Arguments from the operand count up to the total argument count are
  // handled here as attributes.
  for (int e = tree.getNumArgs(); i != e; ++i) {
    // Start each attribute on its own line.
    (os << ",\n").indent(6);
    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(i);
    if (auto subTree = tree.getArgAsNestedDag(i)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      os << formatv("/*{0}=*/{1}", opArgName,
                    emitReplaceWithNativeCodeCall(subTree));
    } else {
      auto leaf = tree.getArgAsLeaf(i);
      // The argument in the result DAG pattern.
      auto patArgName = tree.getArgName(i);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO(jpienaar): Refactor out into map to avoid recomputing these.
        auto argument = resultOp.getArg(i);
        if (!argument.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
        // Constant leaves are labeled with the pattern's argument name, and
        // only when one was given.
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        // Other leaves are labeled with the op definition's argument name.
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
  // Close the builder call.
  os << "\n    );\n";

  return resultValue;
}
935 
936 void PatternEmitter::emit(StringRef rewriteName, Record *p,
937                           RecordOperatorMap *mapper, raw_ostream &os) {
938   PatternEmitter(p, mapper, os).emit(rewriteName);
939 }
940 
941 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
942   emitSourceFileHeader("Rewriters", os);
943 
944   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
945   auto numPatterns = patterns.size();
946 
947   // We put the map here because it can be shared among multiple patterns.
948   RecordOperatorMap recordOpMap;
949 
950   std::vector<std::string> rewriterNames;
951   rewriterNames.reserve(numPatterns);
952 
953   std::string baseRewriterName = "GeneratedConvert";
954   int rewriterIndex = 0;
955 
956   for (Record *p : patterns) {
957     std::string name;
958     if (p->isAnonymous()) {
959       // If no name is provided, ensure unique rewriter names simply by
960       // appending unique suffix.
961       name = baseRewriterName + llvm::utostr(rewriterIndex++);
962     } else {
963       name = p->getName();
964     }
965     PatternEmitter::emit(name, p, &recordOpMap, os);
966     rewriterNames.push_back(std::move(name));
967   }
968 
969   // Emit function to add the generated matchers to the pattern list.
970   os << "void populateWithGenerated(MLIRContext *context, "
971      << "OwningRewritePatternList *patterns) {\n";
972   for (const auto &name : rewriterNames) {
973     os << "  patterns->push_back(llvm::make_unique<" << name
974        << ">(context));\n";
975   }
976   os << "}\n";
977 }
978 
979 static mlir::GenRegistration
980     genRewriters("gen-rewriters", "Generate pattern rewriters",
981                  [](const RecordKeeper &records, raw_ostream &os) {
982                    emitRewriters(records, os);
983                    return false;
984                  });
985