xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision c72d849eb9b416511c846e5baadb9270075acd2f)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FormatAdapters.h"
34 #include "llvm/Support/PrettyStackTrace.h"
35 #include "llvm/Support/Signals.h"
36 #include "llvm/TableGen/Error.h"
37 #include "llvm/TableGen/Main.h"
38 #include "llvm/TableGen/Record.h"
39 #include "llvm/TableGen/TableGenBackend.h"
40 
41 using namespace llvm;
42 using namespace mlir;
43 using namespace mlir::tblgen;
44 
45 namespace llvm {
46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
47   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
48                      raw_ostream &os, StringRef style) {
49     os << v.first << ":" << v.second;
50   }
51 };
52 } // end namespace llvm
53 
// Returns the bound symbol for the given op argument or op named `symbol`.
//
// Arguments and ops bound in the source pattern are grouped into a
// transient `PatternState` struct. This struct can be accessed in both
// `match()` and `rewrite()` via the local variable named as `s`, so the
// returned expression is simply `s.<symbol>`.
//
// NOTE(review): the result is a Twine built on top of `symbol` rather than an
// owning string; presumably it has to be consumed (streamed, formatted or
// converted with `.str()`) while the caller's StringRef is still alive --
// confirm against the Twine documentation before storing it anywhere.
static Twine getBoundSymbol(const StringRef &symbol) {
  return Twine("s.") + symbol;
}
62 
63 // Gets the dynamic value pack's name by removing the index suffix from
64 // `symbol`. Returns `symbol` itself if it does not contain an index.
65 //
66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic
67 // value pack bound to `name`. `name` is typically the results of an
68 // multi-result op.
69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
70   StringRef name, indexStr;
71   unsigned idx = 0;
72   std::tie(name, indexStr) = symbol.rsplit("__");
73   if (indexStr.consumeInteger(10, idx)) {
74     // The second part is not an index.
75     return symbol;
76   }
77   if (index)
78     *index = idx;
79   return name;
80 }
81 
82 // Formats all values from a dynamic value pack `symbol` according to the given
83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
84 // and `{1}` as a placeholder for the value index, which will be offsetted by
85 // `offset`. The `symbol` value pack has a total of `count` values.
86 //
87 // This extracts one value from the pack if `symbol` contains an index,
88 // otherwise it extracts all values sequentially and returns them as a
89 // comma-separated list.
90 static std::string formatValuePack(const char *fmt, StringRef symbol,
91                                    unsigned count, unsigned offset) {
92   auto getNthValue = [fmt, offset](StringRef results,
93                                    unsigned index) -> std::string {
94     return formatv(fmt, results, index + offset);
95   };
96 
97   unsigned index = 0;
98   StringRef name = getValuePackName(symbol, &index);
99   if (name != symbol) {
100     // The symbol contains an index.
101     return getNthValue(name, index);
102   }
103 
104   // The symbol does not contain an index. Treat the symbol as a whole.
105   SmallVector<std::string, 4> values;
106   values.reserve(count);
107   for (unsigned i = 0; i < count; ++i)
108     values.emplace_back(getNthValue(symbol, i));
109   return llvm::join(values, ", ");
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // PatternSymbolResolver
114 //===----------------------------------------------------------------------===//
115 
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
//               [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to the whole value pack generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th result.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
class PatternSymbolResolver {
public:
  // Creates a resolver over the symbols bound in the source pattern:
  // `srcArgs` maps symbols to matched op arguments and `srcOperations` maps
  // symbols to matched ops. Both maps are stored by reference.
  PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                        const StringMap<const Operator *> &srcOperations);

  // Marks the given `symbol` as bound to a value pack with `numValues` and
  // returns true on success. Returns false if the `symbol` is already bound.
  // Any `__N` suffix on `symbol` is dropped before it is recorded.
  bool add(StringRef symbol, int numValues);

  // Queries the substitution for the given `symbol`. Returns empty string if
  // symbol not found. If the symbol represents a value pack, returns all the
  // values separated via comma.
  std::string query(StringRef symbol) const;

  // Returns how many static values the given `symbol` corresponds to. Returns
  // a negative value if the given symbol is not bound.
  //
  // Normally a symbol would correspond to just one value; for symbols bound to
  // multi-result ops, it can be more than one.
  int getValueCount(StringRef symbol) const;

private:
  // Symbols bound to arguments in source pattern.
  const StringMap<Argument> &sourceArguments;
  // Symbols bound to ops (for their results) in source pattern.
  const StringMap<const Operator *> &sourceOps;
  // Symbols bound to ops (for their results) in result patterns.
  // Key: symbol; value: number of values inside the pack
  StringMap<int> resultOps;
};
} // end anonymous namespace
173 
// Binds the resolver to the source pattern's symbol tables. Both maps are
// stored as references rather than copied, so they must outlive the resolver.
// The table of result-pattern symbols starts out empty and is filled through
// add().
PatternSymbolResolver::PatternSymbolResolver(
    const StringMap<Argument> &srcArgs,
    const StringMap<const Operator *> &srcOperations)
    : sourceArguments(srcArgs), sourceOps(srcOperations) {}
178 
179 bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
180   StringRef name = getValuePackName(symbol);
181   return resultOps.try_emplace(name, numValues).second;
182 }
183 
184 std::string PatternSymbolResolver::query(StringRef symbol) const {
185   StringRef name = getValuePackName(symbol);
186   // Handle symbols bound to generated ops
187   auto resOpIt = resultOps.find(name);
188   if (resOpIt != resultOps.end())
189     return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
190                            resOpIt->second, /*offset=*/0);
191 
192   // Handle symbols bound to matched op arguments
193   auto srcArgIt = sourceArguments.find(symbol);
194   if (srcArgIt != sourceArguments.end())
195     return getBoundSymbol(symbol).str();
196 
197   // Handle symbols bound to matched op results
198   auto srcOpIt = sourceOps.find(name);
199   if (srcOpIt != sourceOps.end())
200     return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(),
201                            srcOpIt->second->getNumResults(), /*offset=*/0);
202   return {};
203 }
204 
205 int PatternSymbolResolver::getValueCount(StringRef symbol) const {
206   StringRef name = getValuePackName(symbol);
207   // Handle symbols bound to generated ops
208   auto resOpIt = resultOps.find(name);
209   if (resOpIt != resultOps.end())
210     return name == symbol ? resOpIt->second : 1;
211 
212   // Handle symbols bound to matched op arguments
213   if (sourceArguments.count(symbol))
214     return 1;
215 
216   // Handle symbols bound to matched op results
217   auto srcOpIt = sourceOps.find(name);
218   if (srcOpIt != sourceOps.end())
219     return name == symbol ? srcOpIt->second->getNumResults() : 1;
220   return -1;
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // PatternEmitter
225 //===----------------------------------------------------------------------===//
226 
namespace {
// Generates the C++ `RewritePattern` struct (with its `match()` and
// `rewrite()` methods) for a single pattern record and writes it to a stream.
class PatternEmitter {
public:
  // Emits the rewrite pattern named `rewriteName` for record `p` to `os`,
  // using `mapper` to look up the ops referenced by the pattern.
  static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
                   raw_ostream &os);

private:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits the match() method.
  void emitMatchMethod(DagNode tree);

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Emits the rewrite() method.
  void emitRewriteMethod();

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` inside the source pattern
  // (0 for the root).
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  // Returns a unique name for a value of the given `op`.
  std::string getUniqueValueName(const Operator *op);

  // Entry point for handling a rewrite pattern rooted at `resultTree` and
  // dispatches to concrete handlers. The given tree is the `resultIndex`-th
  // argument of the enclosing DAG.
  std::string handleRewritePattern(DagNode resultTree, int resultIndex,
                                   int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string emitReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string emitOpCreate(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is the name the argument is bound to in the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  // Marks `symbol` as bound to a value pack of `numValues` values. Aborts if
  // the symbol is already bound.
  void addSymbol(StringRef symbol, int numValues);

  // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
  std::string resolveSymbol(StringRef symbol);

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;
  // Op's TableGen Record to wrapper object
  RecordOperatorMap *opMap;
  // Handy wrapper for pattern being emitted
  Pattern pattern;
  // Resolves symbols bound in the source and result patterns. It refers to
  // maps owned by `pattern`, hence it is declared after it.
  PatternSymbolResolver symbolResolver;
  // The next unused ID for newly created values
  unsigned nextValueId;
  // The stream the generated code is written to.
  raw_ostream &os;

  // Format contexts containing placeholder substitutations for match().
  FmtContext matchCtx;
  // Format contexts containing placeholder substitutations for rewrite().
  FmtContext rewriteCtx;

  // Number of ops processed.
  int opCounter = 0;
};
} // end anonymous namespace
325 
326 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
327                                raw_ostream &os)
328     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
329       symbolResolver(pattern.getSourcePatternBoundArgs(),
330                      pattern.getSourcePatternBoundOps()),
331       nextValueId(0), os(os) {
332   matchCtx.withBuilder("mlir::Builder(ctx)");
333   rewriteCtx.withBuilder("rewriter");
334 }
335 
336 std::string PatternEmitter::handleConstantAttr(Attribute attr,
337                                                StringRef value) {
338   if (!attr.isConstBuildable())
339     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
340                              " does not have the 'constBuilderCall' field");
341 
342   // TODO(jpienaar): Verify the constants here
343   return tgfmt(attr.getConstBuilderTemplate(),
344                &rewriteCtx.withBuilder("rewriter"), value);
345 }
346 
347 // Helper function to match patterns.
348 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
349   Operator &op = tree.getDialectOp(opMap);
350   if (op.isVariadic()) {
351     PrintFatalError(loc, formatv("matching op '{0}' with variadic "
352                                  "operands/results is unsupported right now",
353                                  op.getOperationName()));
354   }
355 
356   int indent = 4 + 2 * depth;
357   // Skip the operand matching at depth 0 as the pattern rewriter already does.
358   if (depth != 0) {
359     // Skip if there is no defining operation (e.g., arguments to function).
360     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
361     os.indent(indent) << formatv(
362         "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
363         op.getQualCppClassName());
364   }
365   if (tree.getNumArgs() != op.getNumArgs()) {
366     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
367                                  "pattern vs. {2} in definition",
368                                  op.getOperationName(), tree.getNumArgs(),
369                                  op.getNumArgs()));
370   }
371 
372   // If the operand's name is set, set to that variable.
373   auto name = tree.getSymbol();
374   if (!name.empty())
375     os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
376 
377   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
378     auto opArg = op.getArg(i);
379 
380     // Handle nested DAG construct first
381     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
382       os.indent(indent) << "{\n";
383       os.indent(indent + 2)
384           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
385                      depth + 1, depth, i);
386       emitOpMatch(argTree, depth + 1);
387       os.indent(indent + 2)
388           << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
389                      ++opCounter, depth + 1);
390       os.indent(indent) << "}\n";
391       continue;
392     }
393 
394     // Next handle DAG leaf: operand or attribute
395     if (opArg.is<NamedTypeConstraint *>()) {
396       emitOperandMatch(tree, i, depth, indent);
397     } else if (opArg.is<NamedAttribute *>()) {
398       emitAttributeMatch(tree, i, depth, indent);
399     } else {
400       PrintFatalError(loc, "unhandled case when matching op");
401     }
402   }
403 }
404 
405 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
406                                       int indent) {
407   Operator &op = tree.getDialectOp(opMap);
408   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
409   auto matcher = tree.getArgAsLeaf(index);
410 
411   // If a constraint is specified, we need to generate C++ statements to
412   // check the constraint.
413   if (!matcher.isUnspecified()) {
414     if (!matcher.isOperandMatcher()) {
415       PrintFatalError(
416           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
417                        op.getOperationName(), index + 1));
418     }
419 
420     // Only need to verify if the matcher's type is different from the one
421     // of op definition.
422     if (operand->constraint != matcher.getAsConstraint()) {
423       auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
424       os.indent(indent) << "if (!("
425                         << tgfmt(matcher.getConditionTemplate(),
426                                  &matchCtx.withSelf(self))
427                         << ")) return matchFailure();\n";
428     }
429   }
430 
431   // Capture the value
432   auto name = tree.getArgName(index);
433   if (!name.empty()) {
434     os.indent(indent) << getBoundSymbol(name) << " = op" << depth
435                       << "->getOperand(" << index << ");\n";
436   }
437 }
438 
439 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
440                                         int indent) {
441   Operator &op = tree.getDialectOp(opMap);
442   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
443   const auto &attr = namedAttr->attr;
444 
445   os.indent(indent) << "{\n";
446   indent += 2;
447   os.indent(indent) << formatv(
448       "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
449       attr.getStorageType(), namedAttr->name);
450 
451   // TODO(antiagainst): This should use getter method to avoid duplication.
452   if (attr.hasDefaultValueInitializer()) {
453     os.indent(indent) << "if (!attr) attr = "
454                       << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
455                                attr.getDefaultValueInitializer())
456                       << ";\n";
457   } else if (attr.isOptional()) {
458     // For a missing attribut that is optional according to definition, we
459     // should just capature a mlir::Attribute() to signal the missing state.
460     // That is precisely what getAttr() returns on missing attributes.
461   } else {
462     os.indent(indent) << "if (!attr) return matchFailure();\n";
463   }
464 
465   auto matcher = tree.getArgAsLeaf(index);
466   if (!matcher.isUnspecified()) {
467     if (!matcher.isAttrMatcher()) {
468       PrintFatalError(
469           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
470                        op.getOperationName(), index + 1));
471     }
472 
473     // If a constraint is specified, we need to generate C++ statements to
474     // check the constraint.
475     os.indent(indent) << "if (!("
476                       << tgfmt(matcher.getConditionTemplate(),
477                                &matchCtx.withSelf("attr"))
478                       << ")) return matchFailure();\n";
479   }
480 
481   // Capture the value
482   auto name = tree.getArgName(index);
483   if (!name.empty()) {
484     os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
485   }
486 
487   indent -= 2;
488   os.indent(indent) << "}\n";
489 }
490 
// Emits the `match()` override for the source pattern rooted at `tree`.
//
// The generated method allocates a `MatchedState`, stores the root op in slot
// 0 of `autogeneratedRewritePatternOps`, runs the structural match emitted by
// emitOpMatch(), checks every additional constraint attached to the pattern
// and finally returns `matchSuccess` carrying the state.
void PatternEmitter::emitMatchMethod(DagNode tree) {
  // Emit the heading.
  os << R"(
  PatternMatchResult match(Operation *op0) const override {
    auto ctx = op0->getContext(); (void)ctx;
    auto state = llvm::make_unique<MatchedState>();
    auto &s = *state;
    s.autogeneratedRewritePatternOps[0] = op0;
)";

  // Structural match of the op tree, starting at the root (depth 0).
  emitOpMatch(tree, 0);

  // Emit one check per additional constraint of the pattern.
  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    auto cmd = "if (!({0})) return matchFailure();\n";

    if (isa<TypeConstraint>(constraint)) {
      // A type constraint is checked against the type of the first entity.
      auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
      os.indent(4) << formatv(cmd,
                              tgfmt(condition, &matchCtx.withSelf(self.str())));
    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // TODO(b/138794486): replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      // Resolve the entities actually given, then pad the list up to the four
      // positional arguments always handed to tgfmt() below.
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(resolveSymbol(entities[i]));
      // `self` is optional; it is only resolved when the constraint names one.
      std::string self = appliedConstraint.self;
      if (!self.empty())
        self = resolveSymbol(self);
      for (; i < 4; ++i)
        names.push_back("<unused>");
      os.indent(4) << formatv(cmd,
                              tgfmt(condition, &matchCtx.withSelf(self),
                                    names[0], names[1], names[2], names[3]));
    }
  }

  os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
}
540 
541 void PatternEmitter::collectOps(DagNode tree,
542                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
543   // Check if this tree is an operation.
544   if (tree.isOperation())
545     ops.insert(&tree.getDialectOp(opMap));
546 
547   // Recurse the arguments of the tree.
548   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
549     if (auto child = tree.getArgAsNestedDag(i))
550       collectOps(child, ops);
551 }
552 
// Emits the complete `RewritePattern` struct named `rewriteName`: a comment
// listing the pattern's locations, the constructor, the `MatchedState` struct
// and the match() and rewrite() methods.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern
  DagNode tree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> results;
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
    collectOps(pattern.getResultPattern(i), results);

  // Emit RewritePattern for Pattern.
  auto locs = pattern.getLocation();
  // The locations are printed in reverse order, joined by " instantiating".
  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // Constructor: root op name, names of the collected result ops, benefit.
  os << formatv(R"(struct {0} : public RewritePattern {
  {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
                rewriteName, rootName);
  interleaveComma(results, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

  // Emit matched state.
  os << "  struct MatchedState : public PatternState {\n";
  // One field per bound argument: attributes use their storage type, anything
  // else is held as a `Value *`.
  for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
    auto fieldName = arg.first();
    if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
      os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
                   << ";\n";
    } else {
      os.indent(4) << "Value *" << fieldName << ";\n";
    }
  }
  // One `Operation *` field per bound op.
  for (const auto &result : pattern.getSourcePatternBoundOps()) {
    os.indent(4) << "Operation *" << result.getKey() << ";\n";
  }
  // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
  // numbering so that it can be reused for fused loc.
  os.indent(4) << "Operation* autogeneratedRewritePatternOps["
               << pattern.getSourcePattern().getNumOps() << "];\n";
  os << "  };\n";

  emitMatchMethod(tree);
  emitRewriteMethod();

  os << "};\n";
}
602 
603 void PatternEmitter::emitRewriteMethod() {
604   const Operator &rootOp = pattern.getSourceRootOp();
605   int numExpectedResults = rootOp.getNumResults();
606   int numResultPatterns = pattern.getNumResultPatterns();
607 
608   // First register all symbols bound to ops generated in result patterns.
609   for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
610     addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
611   }
612 
613   // Only the last N static values generated are used to replace the matched
614   // root N-result op. We need to calculate the starting index (of the results
615   // of the matched op) each result pattern is to replace.
616   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
617   int replStartIndex = -1;
618   for (int i = numResultPatterns - 1; i >= 0; --i) {
619     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
620     offsets[i] = offsets[i + 1] - numValues;
621     if (offsets[i] == 0) {
622       replStartIndex = i;
623     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
624       auto error = formatv(
625           "cannot use the same multi-result op '{0}' to generate both "
626           "auxiliary values and values to be used for replacing the matched op",
627           pattern.getResultPattern(i).getSymbol());
628       PrintFatalError(loc, error);
629     }
630   }
631 
632   if (offsets.front() > 0) {
633     const char error[] = "no enough values generated to replace the matched op";
634     PrintFatalError(loc, error);
635   }
636 
637   os << R"(
638   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
639                PatternRewriter &rewriter) const override {
640     auto& s = *static_cast<MatchedState *>(state.get());
641     auto loc = rewriter.getFusedLoc({)";
642   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
643     os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
644        << "]->getLoc()";
645   }
646   os << "}); (void)loc;\n";
647 
648   // Collect the replacement value for each result
649   llvm::SmallVector<std::string, 2> resultValues;
650   for (int i = 0; i < numResultPatterns; ++i) {
651     DagNode resultTree = pattern.getResultPattern(i);
652     resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0));
653   }
654 
655   // Emit the final replaceOp() statement
656   os.indent(4) << "rewriter.replaceOp(op, {";
657   interleave(
658       ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
659       [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
660   os << "});\n  }\n";
661 }
662 
663 std::string PatternEmitter::getUniqueValueName(const Operator *op) {
664   return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
665 }
666 
667 std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
668                                                  int resultIndex, int depth) {
669   if (resultTree.isNativeCodeCall())
670     return emitReplaceWithNativeCodeCall(resultTree);
671 
672   if (resultTree.isReplaceWithValue())
673     return handleReplaceWithValue(resultTree);
674 
675   // Create the op and get the local variable for it.
676   auto results = emitOpCreate(resultTree, resultIndex, depth);
677   // We need to get all the values out of this local variable if we've created a
678   // multi-result op.
679   const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
680   return formatValuePack("{0}.getOperation()->getResult({1})", results,
681                          numResults, /*offset=*/0);
682 }
683 
684 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
685   assert(tree.isReplaceWithValue());
686 
687   if (tree.getNumArgs() != 1) {
688     PrintFatalError(
689         loc, "replaceWithValue directive must take exactly one argument");
690   }
691 
692   if (!tree.getSymbol().empty()) {
693     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
694   }
695 
696   return resolveSymbol(tree.getArgName(0));
697 }
698 
699 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
700   if (leaf.isConstantAttr()) {
701     auto constAttr = leaf.getAsConstantAttr();
702     return handleConstantAttr(constAttr.getAttribute(),
703                               constAttr.getConstantValue());
704   }
705   if (leaf.isEnumAttrCase()) {
706     auto enumCase = leaf.getAsEnumAttrCase();
707     if (enumCase.isStrCase())
708       return handleConstantAttr(enumCase, enumCase.getSymbol());
709     // This is an enum case backed by an IntegerAttr. We need to get its value
710     // to build the constant.
711     std::string val = std::to_string(enumCase.getValue());
712     return handleConstantAttr(enumCase, val);
713   }
714   pattern.ensureBoundInSourcePattern(argName);
715   std::string result = getBoundSymbol(argName).str();
716   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
717     return result;
718   }
719   if (leaf.isNativeCodeCall()) {
720     return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result));
721   }
722   PrintFatalError(loc, "unhandled case when rewriting op");
723 }
724 
725 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
726   auto fmt = tree.getNativeCodeTemplate();
727   // TODO(b/138794486): replace formatv arguments with the exact specified args.
728   SmallVector<std::string, 8> attrs(8);
729   if (tree.getNumArgs() > 8) {
730     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
731                              Twine(tree.getNumArgs()));
732   }
733   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
734     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
735   }
736   return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
737                attrs[4], attrs[5], attrs[6], attrs[7]);
738 }
739 
740 void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
741   if (!symbolResolver.add(symbol, numValues))
742     PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
743 }
744 
745 std::string PatternEmitter::resolveSymbol(StringRef symbol) {
746   auto subst = symbolResolver.query(symbol);
747   if (subst.empty())
748     PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
749   return subst;
750 }
751 
752 int PatternEmitter::getNodeValueCount(DagNode node) {
753   if (node.isOperation()) {
754     // First to see whether this op is bound and we just want a specific result
755     // of it with `__N` suffix in symbol.
756     int count = symbolResolver.getValueCount(node.getSymbol());
757     if (count >= 0)
758       return count;
759 
760     // No symbol. Then we are using all the results.
761     return pattern.getDialectOp(node).getNumResults();
762   }
763   // TODO(antiagainst): This considers all NativeCodeCall as returning one
764   // value. Enhance if multi-value ones are needed.
765   return 1;
766 }
767 
768 std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
769                                          int depth) {
770   Operator &resultOp = tree.getDialectOp(opMap);
771   auto numOpArgs = resultOp.getNumArgs();
772 
773   if (resultOp.isVariadic()) {
774     PrintFatalError(loc, formatv("generating op '{0}' with variadic "
775                                  "operands/results is unsupported now",
776                                  resultOp.getOperationName()));
777   }
778 
779   if (numOpArgs != tree.getNumArgs()) {
780     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
781                                  "{1} in pattern vs. {2} in definition",
782                                  resultOp.getOperationName(), tree.getNumArgs(),
783                                  numOpArgs));
784   }
785 
786   // A map to collect all nested DAG child nodes' names, with operand index as
787   // the key. This includes both bound and unbound child nodes. Bound child
788   // nodes will additionally be tracked in `symbolResolver` so they can be
789   // referenced by other patterns. Unbound child nodes will only be used once
790   // to build this op.
791   llvm::DenseMap<unsigned, std::string> childNodeNames;
792 
793   // First go through all the child nodes who are nested DAG constructs to
794   // create ops for them, so that we can use the results in the current node.
795   // This happens in a recursive manner.
796   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
797     if (auto child = tree.getArgAsNestedDag(i)) {
798       childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
799     }
800   }
801 
802   // Use the specified name for this op if available. Generate one otherwise.
803   std::string resultValue = tree.getSymbol();
804   if (resultValue.empty())
805     resultValue = getUniqueValueName(&resultOp);
806   // Strip the index to get the name for the value pack. This will be used to
807   // name the local variable for the op.
808   StringRef valuePackName = getValuePackName(resultValue);
809 
810   // Then we build the new op corresponding to this DAG node.
811 
812   // Right now we don't have general type inference in MLIR. Except a few
813   // special cases listed below, we need to supply types for all results
814   // when building an op.
815   bool isSameOperandsAndResultType =
816       resultOp.hasTrait("SameOperandsAndResultType");
817   bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
818   bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
819   bool usePartialResults = valuePackName != resultValue;
820 
821   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
822       usePartialResults || depth > 0 || resultIndex < 0) {
823     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
824                             valuePackName, resultOp.getQualCppClassName());
825   } else {
826     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
827     // generated from the source pattern root op. Then we can use the source
828     // pattern's value types to determine the value type of the generated op
829     // here.
830 
831     // We need to specify the types for all results.
832     auto resultTypes =
833         formatValuePack("op->getResult({1})->getType()", valuePackName,
834                         resultOp.getNumResults(), resultIndex);
835 
836     os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
837                             valuePackName, resultOp.getQualCppClassName())
838                  << (resultTypes.empty() ? "" : ", ") << resultTypes;
839   }
840 
841   // Create the builder call for the result.
842   // Add operands.
843   int i = 0;
844   for (int e = resultOp.getNumOperands(); i < e; ++i) {
845     const auto &operand = resultOp.getOperand(i);
846 
847     // Start each operand on its own line.
848     (os << ",\n").indent(6);
849 
850     if (!operand.name.empty())
851       os << "/*" << operand.name << "=*/";
852 
853     if (tree.isNestedDagArg(i)) {
854       os << childNodeNames[i];
855     } else {
856       DagLeaf leaf = tree.getArgAsLeaf(i);
857       auto symbol = resolveSymbol(tree.getArgName(i));
858       if (leaf.isNativeCodeCall()) {
859         os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
860       } else {
861         os << symbol;
862       }
863     }
864     // TODO(jpienaar): verify types
865   }
866 
867   // Add attributes.
868   for (int e = tree.getNumArgs(); i != e; ++i) {
869     // Start each attribute on its own line.
870     (os << ",\n").indent(6);
871     // The argument in the op definition.
872     auto opArgName = resultOp.getArgName(i);
873     if (auto subTree = tree.getArgAsNestedDag(i)) {
874       if (!subTree.isNativeCodeCall())
875         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
876                              "for creating attribute");
877       os << formatv("/*{0}=*/{1}", opArgName,
878                     emitReplaceWithNativeCodeCall(subTree));
879     } else {
880       auto leaf = tree.getArgAsLeaf(i);
881       // The argument in the result DAG pattern.
882       auto patArgName = tree.getArgName(i);
883       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
884         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
885         auto argument = resultOp.getArg(i);
886         if (!argument.is<NamedAttribute *>())
887           PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
888         if (!patArgName.empty())
889           os << "/*" << patArgName << "=*/";
890       } else {
891         os << "/*" << opArgName << "=*/";
892       }
893       os << handleOpArgument(leaf, patArgName);
894     }
895   }
896   os << "\n    );\n";
897 
898   return resultValue;
899 }
900 
901 void PatternEmitter::emit(StringRef rewriteName, Record *p,
902                           RecordOperatorMap *mapper, raw_ostream &os) {
903   PatternEmitter(p, mapper, os).emit(rewriteName);
904 }
905 
906 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
907   emitSourceFileHeader("Rewriters", os);
908 
909   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
910   auto numPatterns = patterns.size();
911 
912   // We put the map here because it can be shared among multiple patterns.
913   RecordOperatorMap recordOpMap;
914 
915   std::vector<std::string> rewriterNames;
916   rewriterNames.reserve(numPatterns);
917 
918   std::string baseRewriterName = "GeneratedConvert";
919   int rewriterIndex = 0;
920 
921   for (Record *p : patterns) {
922     std::string name;
923     if (p->isAnonymous()) {
924       // If no name is provided, ensure unique rewriter names simply by
925       // appending unique suffix.
926       name = baseRewriterName + llvm::utostr(rewriterIndex++);
927     } else {
928       name = p->getName();
929     }
930     PatternEmitter::emit(name, p, &recordOpMap, os);
931     rewriterNames.push_back(std::move(name));
932   }
933 
934   // Emit function to add the generated matchers to the pattern list.
935   os << "void populateWithGenerated(MLIRContext *context, "
936      << "OwningRewritePatternList *patterns) {\n";
937   for (const auto &name : rewriterNames) {
938     os << "  patterns->push_back(llvm::make_unique<" << name
939        << ">(context));\n";
940   }
941   os << "}\n";
942 }
943 
944 static mlir::GenRegistration
945     genRewriters("gen-rewriters", "Generate pattern rewriters",
946                  [](const RecordKeeper &records, raw_ostream &os) {
947                    emitRewriters(records, os);
948                    return false;
949                  });
950