xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 60f78453d7e88fbd15dc03e83cf95e7b210a2f2c)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
54 // Gets the dynamic value pack's name by removing the index suffix from
55 // `symbol`. Returns `symbol` itself if it does not contain an index.
56 //
57 // We can use `name__<index>` to access the `<index>`-th value in the dynamic
58 // value pack bound to `name`. `name` is typically the results of an
59 // multi-result op.
60 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
61   StringRef name, indexStr;
62   unsigned idx = 0;
63   std::tie(name, indexStr) = symbol.rsplit("__");
64   if (indexStr.consumeInteger(10, idx)) {
65     // The second part is not an index.
66     return symbol;
67   }
68   if (index)
69     *index = idx;
70   return name;
71 }
72 
73 // Formats all values from a dynamic value pack `symbol` according to the given
74 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
75 // and `{1}` as a placeholder for the value index, which will be offsetted by
76 // `offset`. The `symbol` value pack has a total of `count` values.
77 //
78 // This extracts one value from the pack if `symbol` contains an index,
79 // otherwise it extracts all values sequentially and returns them as a
80 // comma-separated list.
81 static std::string formatValuePack(const char *fmt, StringRef symbol,
82                                    unsigned count, unsigned offset) {
83   auto getNthValue = [fmt, offset](StringRef results,
84                                    unsigned index) -> std::string {
85     return formatv(fmt, results, index + offset);
86   };
87 
88   unsigned index = 0;
89   StringRef name = getValuePackName(symbol, &index);
90   if (name != symbol) {
91     // The symbol contains an index.
92     return getNthValue(name, index);
93   }
94 
95   // The symbol does not contain an index. Treat the symbol as a whole.
96   SmallVector<std::string, 4> values;
97   values.reserve(count);
98   for (unsigned i = 0; i < count; ++i)
99     values.emplace_back(getNthValue(symbol, i));
100   return llvm::join(values, ", ");
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // PatternSymbolResolver
105 //===----------------------------------------------------------------------===//
106 
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
//               [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to the whole value pack generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th result.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
//
// The two source-pattern maps are held by reference (not copied), so they
// must outlive the resolver.
class PatternSymbolResolver {
public:
  PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                        const StringMap<const Operator *> &srcOperations);

  // Marks the given `symbol` as bound to a value pack with `numValues` and
  // returns true on success. Returns false if the `symbol` is already bound.
  // Only result-pattern ops are registered this way; any `__N` suffix on
  // `symbol` is stripped before it is recorded.
  bool add(StringRef symbol, int numValues);

  // Queries the substitution for the given `symbol`. Returns empty string if
  // symbol not found. If the symbol represents a value pack, returns all the
  // values separated via comma. Result-pattern ops are searched first, then
  // source-pattern arguments, then source-pattern ops.
  std::string query(StringRef symbol) const;

  // Returns how many static values the given `symbol` corresponds to. Returns
  // a negative value if the given symbol is not bound.
  //
  // Normally a symbol would correspond to just one value; for symbols bound to
  // multi-result ops, it can be more than one.
  int getValueCount(StringRef symbol) const;

private:
  // Symbols bound to arguments in source pattern.
  const StringMap<Argument> &sourceArguments;
  // Symbols bound to ops (for their results) in source pattern.
  const StringMap<const Operator *> &sourceOps;
  // Symbols bound to ops (for their results) in result patterns.
  // Key: symbol (without any `__N` suffix); value: number of values inside
  // the pack.
  StringMap<int> resultOps;
};
} // end anonymous namespace
164 
// Stores references to the source pattern's bound-argument and bound-op maps;
// nothing is copied, so both maps must outlive this resolver.
PatternSymbolResolver::PatternSymbolResolver(
    const StringMap<Argument> &srcArgs,
    const StringMap<const Operator *> &srcOperations)
    : sourceArguments(srcArgs), sourceOps(srcOperations) {}
169 
170 bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
171   StringRef name = getValuePackName(symbol);
172   return resultOps.try_emplace(name, numValues).second;
173 }
174 
175 std::string PatternSymbolResolver::query(StringRef symbol) const {
176   StringRef name = getValuePackName(symbol);
177   // Handle symbols bound to generated ops
178   auto resOpIt = resultOps.find(name);
179   if (resOpIt != resultOps.end())
180     return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
181                            resOpIt->second, /*offset=*/0);
182 
183   // Handle symbols bound to matched op arguments
184   auto srcArgIt = sourceArguments.find(symbol);
185   if (srcArgIt != sourceArguments.end())
186     return symbol;
187 
188   // Handle symbols bound to matched op results
189   auto srcOpIt = sourceOps.find(name);
190   if (srcOpIt != sourceOps.end())
191     return formatValuePack("{0}->getResult({1})", symbol,
192                            srcOpIt->second->getNumResults(), /*offset=*/0);
193   return {};
194 }
195 
196 int PatternSymbolResolver::getValueCount(StringRef symbol) const {
197   StringRef name = getValuePackName(symbol);
198   // Handle symbols bound to generated ops
199   auto resOpIt = resultOps.find(name);
200   if (resOpIt != resultOps.end())
201     return name == symbol ? resOpIt->second : 1;
202 
203   // Handle symbols bound to matched op arguments
204   if (sourceArguments.count(symbol))
205     return 1;
206 
207   // Handle symbols bound to matched op results
208   auto srcOpIt = sourceOps.find(name);
209   if (srcOpIt != sourceOps.end())
210     return name == symbol ? srcOpIt->second->getNumResults() : 1;
211   return -1;
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // PatternEmitter
216 //===----------------------------------------------------------------------===//
217 
namespace {
// Emits the C++ `RewritePattern` subclass for a single TableGen pattern
// record.
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree`; the op being matched is
  // named `op<depth>` in the generated code.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // Entry point for handling a result pattern rooted at `resultTree` and
  // dispatches to concrete handlers. `resultIndex` is forwarded to
  // handleOpCreation(); the top-level caller passes the index of the first
  // matched-op result this pattern replaces, which may be negative for
  // patterns that only produce auxiliary values.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression that builds the replacement value by calling
  // native C++ code (a NativeCodeCall).
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the symbol the argument is bound to; unless `leaf` is a
  // constant attribute or enum case, it must be bound in the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique name for a value of the given `op`.
  std::string getUniqueValueName(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Marks `symbol` as bound to a value pack of `numValues` values. Aborts if
  // the symbol is already bound.
  void addSymbol(StringRef symbol, int numValues);

  // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
  std::string resolveSymbol(StringRef symbol);

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;
  // Op's TableGen Record to wrapper object
  RecordOperatorMap *opMap;
  // Handy wrapper for pattern being emitted
  Pattern pattern;
  // Translates pattern symbols into C++ expressions; built from `pattern`,
  // so it must stay declared after it.
  PatternSymbolResolver symbolResolver;
  // The next unused ID for newly created values
  unsigned nextValueId;
  raw_ostream &os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;

  // Number of nested ops matched so far; each one is stored into
  // `tblgen_ops[opCounter]` in the generated code.
  int opCounter = 0;
};
} // end anonymous namespace
327 
// Wraps the pattern record `pat` and prepares the symbol resolver over the
// source pattern's bound arguments and ops. `symbolResolver` is initialized
// from `pattern`, which is declared (and therefore initialized) before it.
// The format context's builder substitution is set to `rewriter`, the name of
// the PatternRewriter parameter in the generated matchAndRewrite().
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      symbolResolver(pattern.getSourcePatternBoundArgs(),
                     pattern.getSourcePatternBoundOps()),
      nextValueId(0), os(os) {
  fmtCtx.withBuilder("rewriter");
}
336 
337 std::string PatternEmitter::handleConstantAttr(Attribute attr,
338                                                StringRef value) {
339   if (!attr.isConstBuildable())
340     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
341                              " does not have the 'constBuilderCall' field");
342 
343   // TODO(jpienaar): Verify the constants here
344   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
345 }
346 
347 // Helper function to match patterns.
348 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
349   Operator &op = tree.getDialectOp(opMap);
350   if (op.isVariadic()) {
351     PrintFatalError(loc, formatv("matching op '{0}' with variadic "
352                                  "operands/results is unsupported right now",
353                                  op.getOperationName()));
354   }
355 
356   int indent = 4 + 2 * depth;
357   // Skip the operand matching at depth 0 as the pattern rewriter already does.
358   if (depth != 0) {
359     // Skip if there is no defining operation (e.g., arguments to function).
360     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
361     os.indent(indent) << formatv(
362         "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
363         op.getQualCppClassName());
364   }
365   if (tree.getNumArgs() != op.getNumArgs()) {
366     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
367                                  "pattern vs. {2} in definition",
368                                  op.getOperationName(), tree.getNumArgs(),
369                                  op.getNumArgs()));
370   }
371 
372   // If the operand's name is set, set to that variable.
373   auto name = tree.getSymbol();
374   if (!name.empty())
375     os.indent(indent) << formatv("{0} = op{1};\n", name, depth);
376 
377   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
378     auto opArg = op.getArg(i);
379 
380     // Handle nested DAG construct first
381     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
382       os.indent(indent) << "{\n";
383       os.indent(indent + 2)
384           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
385                      depth + 1, depth, i);
386       emitOpMatch(argTree, depth + 1);
387       os.indent(indent + 2)
388           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
389       os.indent(indent) << "}\n";
390       continue;
391     }
392 
393     // Next handle DAG leaf: operand or attribute
394     if (opArg.is<NamedTypeConstraint *>()) {
395       emitOperandMatch(tree, i, depth, indent);
396     } else if (opArg.is<NamedAttribute *>()) {
397       emitAttributeMatch(tree, i, depth, indent);
398     } else {
399       PrintFatalError(loc, "unhandled case when matching op");
400     }
401   }
402 }
403 
404 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
405                                       int indent) {
406   Operator &op = tree.getDialectOp(opMap);
407   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
408   auto matcher = tree.getArgAsLeaf(index);
409 
410   // If a constraint is specified, we need to generate C++ statements to
411   // check the constraint.
412   if (!matcher.isUnspecified()) {
413     if (!matcher.isOperandMatcher()) {
414       PrintFatalError(
415           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
416                        op.getOperationName(), index + 1));
417     }
418 
419     // Only need to verify if the matcher's type is different from the one
420     // of op definition.
421     if (operand->constraint != matcher.getAsConstraint()) {
422       auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
423       os.indent(indent) << "if (!("
424                         << tgfmt(matcher.getConditionTemplate(),
425                                  &fmtCtx.withSelf(self))
426                         << ")) return matchFailure();\n";
427     }
428   }
429 
430   // Capture the value
431   auto name = tree.getArgName(index);
432   if (!name.empty()) {
433     os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth,
434                                  index);
435   }
436 }
437 
438 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
439                                         int indent) {
440   Operator &op = tree.getDialectOp(opMap);
441   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
442   const auto &attr = namedAttr->attr;
443 
444   os.indent(indent) << "{\n";
445   indent += 2;
446   os.indent(indent) << formatv(
447       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
448       attr.getStorageType(), namedAttr->name);
449 
450   // TODO(antiagainst): This should use getter method to avoid duplication.
451   if (attr.hasDefaultValueInitializer()) {
452     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
453                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
454                                attr.getDefaultValueInitializer())
455                       << ";\n";
456   } else if (attr.isOptional()) {
457     // For a missing attribute that is optional according to definition, we
458     // should just capature a mlir::Attribute() to signal the missing state.
459     // That is precisely what getAttr() returns on missing attributes.
460   } else {
461     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
462   }
463 
464   auto matcher = tree.getArgAsLeaf(index);
465   if (!matcher.isUnspecified()) {
466     if (!matcher.isAttrMatcher()) {
467       PrintFatalError(
468           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
469                        op.getOperationName(), index + 1));
470     }
471 
472     // If a constraint is specified, we need to generate C++ statements to
473     // check the constraint.
474     os.indent(indent) << "if (!("
475                       << tgfmt(matcher.getConditionTemplate(),
476                                &fmtCtx.withSelf("tblgen_attr"))
477                       << ")) return matchFailure();\n";
478   }
479 
480   // Capture the value
481   auto name = tree.getArgName(index);
482   if (!name.empty()) {
483     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
484   }
485 
486   indent -= 2;
487   os.indent(indent) << "}\n";
488 }
489 
490 void PatternEmitter::emitMatchLogic(DagNode tree) {
491   emitOpMatch(tree, 0);
492 
493   for (auto &appliedConstraint : pattern.getConstraints()) {
494     auto &constraint = appliedConstraint.constraint;
495     auto &entities = appliedConstraint.entities;
496 
497     auto condition = constraint.getConditionTemplate();
498     auto cmd = "if (!({0})) return matchFailure();\n";
499 
500     if (isa<TypeConstraint>(constraint)) {
501       auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
502       os.indent(4) << formatv(cmd,
503                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
504     } else if (isa<AttrConstraint>(constraint)) {
505       PrintFatalError(
506           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
507     } else {
508       // TODO(b/138794486): replace formatv arguments with the exact specified
509       // args.
510       if (entities.size() > 4) {
511         PrintFatalError(loc, "only support up to 4-entity constraints now");
512       }
513       SmallVector<std::string, 4> names;
514       int i = 0;
515       for (int e = entities.size(); i < e; ++i)
516         names.push_back(resolveSymbol(entities[i]));
517       std::string self = appliedConstraint.self;
518       if (!self.empty())
519         self = resolveSymbol(self);
520       for (; i < 4; ++i)
521         names.push_back("<unused>");
522       os.indent(4) << formatv(cmd,
523                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
524                                     names[1], names[2], names[3]));
525     }
526   }
527 }
528 
529 void PatternEmitter::collectOps(DagNode tree,
530                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
531   // Check if this tree is an operation.
532   if (tree.isOperation())
533     ops.insert(&tree.getDialectOp(opMap));
534 
535   // Recurse the arguments of the tree.
536   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
537     if (auto child = tree.getArgAsNestedDag(i))
538       collectOps(child, ops);
539 }
540 
// Emits the complete C++ `RewritePattern` subclass named `rewriteName` for the
// current pattern: a comment listing the pattern's locations, a constructor
// registering the root op name, the names of ops the result patterns may
// generate and the benefit, and a `matchAndRewrite()` override containing the
// match logic followed by the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
    collectOps(pattern.getResultPattern(i), resultOps);

  // Emit RewritePattern for Pattern.
  // The locations are printed in reverse order, separated by "instantiating".
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // Constructor: root op name, then the quoted names of all ops the result
  // patterns can generate, then the benefit. `{{` is formatv's escape for a
  // literal `{`.
  os << formatv(R"(struct {0} : public RewritePattern {
  {0}(MLIRContext *context)
      : RewritePattern("{1}", {{)",
                rewriteName, rootName);
  interleaveComma(resultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

  // Emit matchAndRewrite() function.
  os << R"(
  PatternMatchResult matchAndRewrite(Operation *op0,
                                     PatternRewriter &rewriter) const override {
)";

  os.indent(4) << "// Variables for capturing values and attributes used for "
                  "creating ops\n";
  // Create local variables for storing the arguments bound to symbols.
  // Attributes are declared with their storage type, operands as `Value *`.
  for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
    auto fieldName = arg.first();
    if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
      os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(),
                              fieldName);
    } else {
      os.indent(4) << "Value *" << fieldName << ";\n";
    }
  }
  // Create local variables for storing the ops bound to symbols.
  for (const auto &result : pattern.getSourcePatternBoundOps()) {
    os.indent(4) << formatv("Operation *{0};\n", result.getKey());
  }
  // TODO(jpienaar): capture ops with consistent numbering so that it can be
  // reused for fused loc.
  // One slot per op in the source DAG: slot 0 is the root `op0`, and nested
  // ops are stored by emitOpMatch() as they are matched.
  os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
                          pattern.getSourcePattern().getNumOps());

  os.indent(4) << "// Match\n";
  os.indent(4) << "tblgen_ops[0] = op0;\n";
  emitMatchLogic(sourceTree);
  os << "\n";

  os.indent(4) << "// Rewrite\n";
  emitRewriteLogic();

  os.indent(4) << "return matchSuccess();\n";
  os << "  };\n";
  os << "};\n";
}
605 
606 void PatternEmitter::emitRewriteLogic() {
607   const Operator &rootOp = pattern.getSourceRootOp();
608   int numExpectedResults = rootOp.getNumResults();
609   int numResultPatterns = pattern.getNumResultPatterns();
610 
611   // First register all symbols bound to ops generated in result patterns.
612   for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
613     addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
614   }
615 
616   // Only the last N static values generated are used to replace the matched
617   // root N-result op. We need to calculate the starting index (of the results
618   // of the matched op) each result pattern is to replace.
619   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
620   int replStartIndex = -1;
621   for (int i = numResultPatterns - 1; i >= 0; --i) {
622     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
623     offsets[i] = offsets[i + 1] - numValues;
624     if (offsets[i] == 0) {
625       replStartIndex = i;
626     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
627       auto error = formatv(
628           "cannot use the same multi-result op '{0}' to generate both "
629           "auxiliary values and values to be used for replacing the matched op",
630           pattern.getResultPattern(i).getSymbol());
631       PrintFatalError(loc, error);
632     }
633   }
634 
635   if (offsets.front() > 0) {
636     const char error[] = "no enough values generated to replace the matched op";
637     PrintFatalError(loc, error);
638   }
639 
640   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
641   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
642     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
643   }
644   os << "}); (void)loc;\n";
645 
646   // Collect the replacement value for each result
647   llvm::SmallVector<std::string, 2> resultValues;
648   for (int i = 0; i < numResultPatterns; ++i) {
649     DagNode resultTree = pattern.getResultPattern(i);
650     resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
651   }
652 
653   // Emit the final replaceOp() statement
654   os.indent(4) << "rewriter.replaceOp(op0, {";
655   interleave(
656       ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
657       [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
658   os << "});\n";
659 }
660 
661 std::string PatternEmitter::getUniqueValueName(const Operator *op) {
662   return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
663 }
664 
665 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
666                                                 int resultIndex, int depth) {
667   if (resultTree.isNativeCodeCall())
668     return handleReplaceWithNativeCodeCall(resultTree);
669 
670   if (resultTree.isReplaceWithValue())
671     return handleReplaceWithValue(resultTree);
672 
673   // Create the op and get the local variable for it.
674   auto results = handleOpCreation(resultTree, resultIndex, depth);
675   // We need to get all the values out of this local variable if we've created a
676   // multi-result op.
677   const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
678   return formatValuePack("{0}.getOperation()->getResult({1})", results,
679                          numResults, /*offset=*/0);
680 }
681 
682 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
683   assert(tree.isReplaceWithValue());
684 
685   if (tree.getNumArgs() != 1) {
686     PrintFatalError(
687         loc, "replaceWithValue directive must take exactly one argument");
688   }
689 
690   if (!tree.getSymbol().empty()) {
691     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
692   }
693 
694   return resolveSymbol(tree.getArgName(0));
695 }
696 
697 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
698   if (leaf.isConstantAttr()) {
699     auto constAttr = leaf.getAsConstantAttr();
700     return handleConstantAttr(constAttr.getAttribute(),
701                               constAttr.getConstantValue());
702   }
703   if (leaf.isEnumAttrCase()) {
704     auto enumCase = leaf.getAsEnumAttrCase();
705     if (enumCase.isStrCase())
706       return handleConstantAttr(enumCase, enumCase.getSymbol());
707     // This is an enum case backed by an IntegerAttr. We need to get its value
708     // to build the constant.
709     std::string val = std::to_string(enumCase.getValue());
710     return handleConstantAttr(enumCase, val);
711   }
712   pattern.ensureBoundInSourcePattern(argName);
713   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
714     return argName;
715   }
716   if (leaf.isNativeCodeCall()) {
717     return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
718   }
719   PrintFatalError(loc, "unhandled case when rewriting op");
720 }
721 
722 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
723   auto fmt = tree.getNativeCodeTemplate();
724   // TODO(b/138794486): replace formatv arguments with the exact specified args.
725   SmallVector<std::string, 8> attrs(8);
726   if (tree.getNumArgs() > 8) {
727     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
728                              Twine(tree.getNumArgs()));
729   }
730   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
731     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
732   }
733   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
734                attrs[5], attrs[6], attrs[7]);
735 }
736 
737 void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
738   if (!symbolResolver.add(symbol, numValues))
739     PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
740 }
741 
742 std::string PatternEmitter::resolveSymbol(StringRef symbol) {
743   auto subst = symbolResolver.query(symbol);
744   if (subst.empty())
745     PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
746   return subst;
747 }
748 
749 int PatternEmitter::getNodeValueCount(DagNode node) {
750   if (node.isOperation()) {
751     // First to see whether this op is bound and we just want a specific result
752     // of it with `__N` suffix in symbol.
753     int count = symbolResolver.getValueCount(node.getSymbol());
754     if (count >= 0)
755       return count;
756 
757     // No symbol. Then we are using all the results.
758     return pattern.getDialectOp(node).getNumResults();
759   }
760   // TODO(antiagainst): This considers all NativeCodeCall as returning one
761   // value. Enhance if multi-value ones are needed.
762   return 1;
763 }
764 
// Emits a `auto <name> = rewriter.create<OpClass>(loc, ...)` statement that
// builds the op described by the result-pattern DAG node `tree`, and returns
// the name of the value it defines. That name is the node's bound symbol if
// present, otherwise a freshly generated unique name.
//
// `depth` is the nesting level of `tree` within the result pattern; nested
// operand DAGs are emitted first, recursively, via handleResultPattern with
// `depth + 1`. `resultIndex` selects which source-root results this op
// replaces. Only when `depth == 0` and `resultIndex >= 0` (and none of the
// type-deriving traits below apply) are result types copied from `op0`'s
// results. Otherwise no explicit result types are passed.
//
// Aborts via PrintFatalError in these cases:
//   - the op is variadic;
//   - the pattern's argument count differs from the op definition;
//   - an attribute argument is a nested DAG that is not a NativeCodeCall;
//   - a constant/enum leaf maps to a non-attribute argument.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();

  if (resultOp.isVariadic()) {
    PrintFatalError(loc, formatv("generating op '{0}' with variadic "
                                 "operands/results is unsupported now",
                                 resultOp.getOperationName()));
  }

  if (numOpArgs != tree.getNumArgs()) {
    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
                                 "{1} in pattern vs. {2} in definition",
                                 resultOp.getOperationName(), tree.getNumArgs(),
                                 numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes. Bound child
  // nodes will additionally be tracked in `symbolResolver` so they can be
  // referenced by other patterns. Unbound child nodes will only be used once
  // to build this op.
  llvm::DenseMap<unsigned, std::string> childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them, so that we can use the results in the current node.
  // This happens in a recursive manner. Only operand positions are visited
  // here; nested DAGs in attribute positions are handled in the attribute
  // loop below.
  for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i)) {
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
    }
  }

  // Use the specified name for this op if available. Generate one otherwise.
  std::string resultValue = tree.getSymbol();
  if (resultValue.empty())
    resultValue = getUniqueValueName(&resultOp);
  // Strip the index to get the name for the value pack. This will be used to
  // name the local variable for the op.
  StringRef valuePackName = getValuePackName(resultValue);

  // Then we build the new op corresponding to this DAG node.

  // Right now we don't have general type inference in MLIR. Except a few
  // special cases listed below, we need to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.hasTrait("SameOperandsAndResultType");
  bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
  bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
  // True when the symbol carries an index suffix, i.e. only part of the op's
  // results is referenced.
  bool usePartialResults = valuePackName != resultValue;

  if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
      usePartialResults || depth > 0 || resultIndex < 0) {
    // No explicit result types are emitted on this path.
    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
                            valuePackName, resultOp.getQualCppClassName());
  } else {
    // If depth == 0 and resultIndex >= 0, it means we are replacing the values
    // generated from the source pattern root op. Then we can use the source
    // pattern's value types to determine the value type of the generated op
    // here.

    // We need to specify the types for all results. `op0` is presumably the
    // matched source root op in the generated code (TODO confirm in emit()).
    auto resultTypes =
        formatValuePack("op0->getResult({1})->getType()", valuePackName,
                        resultOp.getNumResults(), resultIndex);

    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
                            valuePackName, resultOp.getQualCppClassName())
                 << (resultTypes.empty() ? "" : ", ") << resultTypes;
  }

  // Create the builder call for the result.
  // Add operands.
  // `i` is declared outside the loop on purpose: the attribute loop below
  // resumes at the first argument index after the operands.
  int i = 0;
  for (int e = resultOp.getNumOperands(); i < e; ++i) {
    const auto &operand = resultOp.getOperand(i);

    // Start each operand on its own line.
    (os << ",\n").indent(6);

    if (!operand.name.empty())
      os << "/*" << operand.name << "=*/";

    if (tree.isNestedDagArg(i)) {
      // Filled in by the nested-DAG loop near the top of this function.
      os << childNodeNames[i];
    } else {
      DagLeaf leaf = tree.getArgAsLeaf(i);
      auto symbol = resolveSymbol(tree.getArgName(i));
      if (leaf.isNativeCodeCall()) {
        // The resolved operand symbol is bound as `$_self` in the template.
        os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
      } else {
        os << symbol;
      }
    }
    // TODO(jpienaar): verify types
  }

  // Add attributes.
  for (int e = tree.getNumArgs(); i != e; ++i) {
    // Start each attribute on its own line.
    (os << ",\n").indent(6);
    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(i);
    if (auto subTree = tree.getArgAsNestedDag(i)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      os << formatv("/*{0}=*/{1}", opArgName,
                    handleReplaceWithNativeCodeCall(subTree));
    } else {
      auto leaf = tree.getArgAsLeaf(i);
      // The argument in the result DAG pattern.
      auto patArgName = tree.getArgName(i);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO(jpienaar): Refactor out into map to avoid recomputing these.
        auto argument = resultOp.getArg(i);
        if (!argument.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
        // Constant/enum leaves are annotated with the pattern-side name, and
        // only when the pattern names them.
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
  os << "\n    );\n";

  return resultValue;
}
897 
898 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
899   emitSourceFileHeader("Rewriters", os);
900 
901   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
902   auto numPatterns = patterns.size();
903 
904   // We put the map here because it can be shared among multiple patterns.
905   RecordOperatorMap recordOpMap;
906 
907   std::vector<std::string> rewriterNames;
908   rewriterNames.reserve(numPatterns);
909 
910   std::string baseRewriterName = "GeneratedConvert";
911   int rewriterIndex = 0;
912 
913   for (Record *p : patterns) {
914     std::string name;
915     if (p->isAnonymous()) {
916       // If no name is provided, ensure unique rewriter names simply by
917       // appending unique suffix.
918       name = baseRewriterName + llvm::utostr(rewriterIndex++);
919     } else {
920       name = p->getName();
921     }
922     PatternEmitter(p, &recordOpMap, os).emit(name);
923     rewriterNames.push_back(std::move(name));
924   }
925 
926   // Emit function to add the generated matchers to the pattern list.
927   os << "void populateWithGenerated(MLIRContext *context, "
928      << "OwningRewritePatternList *patterns) {\n";
929   for (const auto &name : rewriterNames) {
930     os << "  patterns->insert<" << name << ">(context);\n";
931   }
932   os << "}\n";
933 }
934 
935 static mlir::GenRegistration
936     genRewriters("gen-rewriters", "Generate pattern rewriters",
937                  [](const RecordKeeper &records, raw_ostream &os) {
938                    emitRewriters(records, os);
939                    return false;
940                  });
941