xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 020f9eb68c8808c096228e20a40e5a217f4d2aab)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace llvm;
43 using namespace mlir;
44 using namespace mlir::tblgen;
45 
46 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
47 
48 namespace llvm {
49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51                      raw_ostream &os, StringRef style) {
52     os << v.first << ":" << v.second;
53   }
54 };
55 } // end namespace llvm
56 
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 class PatternEmitter {
63 public:
64   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
65 
66   // Emits the mlir::RewritePattern struct named `rewriteName`.
67   void emit(StringRef rewriteName);
68 
69 private:
70   // Emits the code for matching ops.
71   void emitMatchLogic(DagNode tree);
72 
73   // Emits the code for rewriting ops.
74   void emitRewriteLogic();
75 
76   //===--------------------------------------------------------------------===//
77   // Match utilities
78   //===--------------------------------------------------------------------===//
79 
80   // Emits C++ statements for matching the op constrained by the given DAG
81   // `tree`.
82   void emitOpMatch(DagNode tree, int depth);
83 
84   // Emits C++ statements for matching the `argIndex`-th argument of the given
85   // DAG `tree` as an operand.
86   void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
87 
88   // Emits C++ statements for matching the `argIndex`-th argument of the given
89   // DAG `tree` as an attribute.
90   void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
91 
92   //===--------------------------------------------------------------------===//
93   // Rewrite utilities
94   //===--------------------------------------------------------------------===//
95 
96   // The entry point for handling a result pattern rooted at `resultTree`. This
97   // method dispatches to concrete handlers according to `resultTree`'s kind and
98   // returns a symbol representing the whole value pack. Callers are expected to
99   // further resolve the symbol according to the specific use case.
100   //
101   // `depth` is the nesting level of `resultTree`; 0 means top-level result
102   // pattern. For top-level result pattern, `resultIndex` indicates which result
103   // of the matched root op this pattern is intended to replace, which can be
104   // used to deduce the result type of the op generated from this result
105   // pattern.
106   std::string handleResultPattern(DagNode resultTree, int resultIndex,
107                                   int depth);
108 
109   // Emits the C++ statement to replace the matched DAG with a value built via
110   // calling native C++ code.
111   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
112 
113   // Returns the C++ expression referencing the old value serving as the
114   // replacement.
115   std::string handleReplaceWithValue(DagNode tree);
116 
117   // Emits the C++ statement to build a new op out of the given DAG `tree` and
118   // returns the variable name that this op is assigned to. If the root op in
119   // DAG `tree` has a specified name, the created op will be assigned to a
120   // variable of the given name. Otherwise, a unique name will be used as the
121   // result value name.
122   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
123 
124   // Returns the C++ expression to construct a constant attribute of the given
125   // `value` for the given attribute kind `attr`.
126   std::string handleConstantAttr(Attribute attr, StringRef value);
127 
128   // Returns the C++ expression to build an argument from the given DAG `leaf`.
129   // `patArgName` is used to bound the argument to the source pattern.
130   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
131 
132   //===--------------------------------------------------------------------===//
133   // General utilities
134   //===--------------------------------------------------------------------===//
135 
136   // Collects all of the operations within the given dag tree.
137   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
138 
139   // Returns a unique symbol for a local variable of the given `op`.
140   std::string getUniqueSymbol(const Operator *op);
141 
142   //===--------------------------------------------------------------------===//
143   // Symbol utilities
144   //===--------------------------------------------------------------------===//
145 
146   // Returns how many static values the given DAG `node` correspond to.
147   int getNodeValueCount(DagNode node);
148 
149 private:
150   // Pattern instantiation location followed by the location of multiclass
151   // prototypes used. This is intended to be used as a whole to
152   // PrintFatalError() on errors.
153   ArrayRef<llvm::SMLoc> loc;
154 
155   // Op's TableGen Record to wrapper object.
156   RecordOperatorMap *opMap;
157 
158   // Handy wrapper for pattern being emitted.
159   Pattern pattern;
160 
161   // Map for all bound symbols' info.
162   SymbolInfoMap symbolInfoMap;
163 
164   // The next unused ID for newly created values.
165   unsigned nextValueId;
166 
167   raw_ostream &os;
168 
169   // Format contexts containing placeholder substitutions.
170   FmtContext fmtCtx;
171 
172   // Number of op processed.
173   int opCounter = 0;
174 };
175 } // end anonymous namespace
176 
177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
178                                raw_ostream &os)
179     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
180       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
181   fmtCtx.withBuilder("rewriter");
182 }
183 
184 std::string PatternEmitter::handleConstantAttr(Attribute attr,
185                                                StringRef value) {
186   if (!attr.isConstBuildable())
187     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
188                              " does not have the 'constBuilderCall' field");
189 
190   // TODO(jpienaar): Verify the constants here
191   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
192 }
193 
194 // Helper function to match patterns.
195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
196   Operator &op = tree.getDialectOp(opMap);
197   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
198                           << op.getOperationName() << "' at depth " << depth
199                           << '\n');
200 
201   int indent = 4 + 2 * depth;
202   os.indent(indent) << formatv(
203       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
204       depth, op.getQualCppClassName());
205   // Skip the operand matching at depth 0 as the pattern rewriter already does.
206   if (depth != 0) {
207     // Skip if there is no defining operation (e.g., arguments to function).
208     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
209                                  depth);
210   }
211   if (tree.getNumArgs() != op.getNumArgs()) {
212     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
213                                  "pattern vs. {2} in definition",
214                                  op.getOperationName(), tree.getNumArgs(),
215                                  op.getNumArgs()));
216   }
217 
218   // If the operand's name is set, set to that variable.
219   auto name = tree.getSymbol();
220   if (!name.empty())
221     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
222 
223   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
224     auto opArg = op.getArg(i);
225 
226     // Handle nested DAG construct first
227     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
228       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
229         if (operand->isVariadic()) {
230           auto error = formatv("use nested DAG construct to match op {0}'s "
231                                "variadic operand #{1} unsupported now",
232                                op.getOperationName(), i);
233           PrintFatalError(loc, error);
234         }
235       }
236       os.indent(indent) << "{\n";
237 
238       os.indent(indent + 2) << formatv(
239           "auto *op{0} = "
240           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
241           depth + 1, depth, i);
242       emitOpMatch(argTree, depth + 1);
243       os.indent(indent + 2)
244           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
245       os.indent(indent) << "}\n";
246       continue;
247     }
248 
249     // Next handle DAG leaf: operand or attribute
250     if (opArg.is<NamedTypeConstraint *>()) {
251       emitOperandMatch(tree, i, depth, indent);
252     } else if (opArg.is<NamedAttribute *>()) {
253       emitAttributeMatch(tree, i, depth, indent);
254     } else {
255       PrintFatalError(loc, "unhandled case when matching op");
256     }
257   }
258   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
259                           << op.getOperationName() << "' at depth " << depth
260                           << '\n');
261 }
262 
263 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
264                                       int indent) {
265   Operator &op = tree.getDialectOp(opMap);
266   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
267   auto matcher = tree.getArgAsLeaf(argIndex);
268 
269   // If a constraint is specified, we need to generate C++ statements to
270   // check the constraint.
271   if (!matcher.isUnspecified()) {
272     if (!matcher.isOperandMatcher()) {
273       PrintFatalError(
274           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
275                        op.getOperationName(), argIndex + 1));
276     }
277 
278     // Only need to verify if the matcher's type is different from the one
279     // of op definition.
280     if (operand->constraint != matcher.getAsConstraint()) {
281       if (operand->isVariadic()) {
282         auto error = formatv(
283             "further constrain op {0}'s variadic operand #{1} unsupported now",
284             op.getOperationName(), argIndex);
285         PrintFatalError(loc, error);
286       }
287       auto self =
288           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
289                   depth, argIndex);
290       os.indent(indent) << "if (!("
291                         << tgfmt(matcher.getConditionTemplate(),
292                                  &fmtCtx.withSelf(self))
293                         << ")) return matchFailure();\n";
294     }
295   }
296 
297   // Capture the value
298   auto name = tree.getArgName(argIndex);
299   if (!name.empty()) {
300     // We need to subtract the number of attributes before this operand to get
301     // the index in the operand list.
302     auto numPrevAttrs = std::count_if(
303         op.arg_begin(), op.arg_begin() + argIndex,
304         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
305 
306     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
307                                  name, depth, argIndex - numPrevAttrs);
308   }
309 }
310 
311 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
312                                         int indent) {
313   Operator &op = tree.getDialectOp(opMap);
314   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
315   const auto &attr = namedAttr->attr;
316 
317   os.indent(indent) << "{\n";
318   indent += 2;
319   os.indent(indent) << formatv(
320       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
321       attr.getStorageType(), namedAttr->name);
322 
323   // TODO(antiagainst): This should use getter method to avoid duplication.
324   if (attr.hasDefaultValueInitializer()) {
325     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
326                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
327                                attr.getDefaultValueInitializer())
328                       << ";\n";
329   } else if (attr.isOptional()) {
330     // For a missing attribute that is optional according to definition, we
331     // should just capture a mlir::Attribute() to signal the missing state.
332     // That is precisely what getAttr() returns on missing attributes.
333   } else {
334     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
335   }
336 
337   auto matcher = tree.getArgAsLeaf(argIndex);
338   if (!matcher.isUnspecified()) {
339     if (!matcher.isAttrMatcher()) {
340       PrintFatalError(
341           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
342                        op.getOperationName(), argIndex + 1));
343     }
344 
345     // If a constraint is specified, we need to generate C++ statements to
346     // check the constraint.
347     os.indent(indent) << "if (!("
348                       << tgfmt(matcher.getConditionTemplate(),
349                                &fmtCtx.withSelf("tblgen_attr"))
350                       << ")) return matchFailure();\n";
351   }
352 
353   // Capture the value
354   auto name = tree.getArgName(argIndex);
355   if (!name.empty()) {
356     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
357   }
358 
359   indent -= 2;
360   os.indent(indent) << "}\n";
361 }
362 
363 void PatternEmitter::emitMatchLogic(DagNode tree) {
364   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
365   emitOpMatch(tree, 0);
366 
367   for (auto &appliedConstraint : pattern.getConstraints()) {
368     auto &constraint = appliedConstraint.constraint;
369     auto &entities = appliedConstraint.entities;
370 
371     auto condition = constraint.getConditionTemplate();
372     auto cmd = "if (!({0})) return matchFailure();\n";
373 
374     if (isa<TypeConstraint>(constraint)) {
375       auto self = formatv("({0}->getType())",
376                           symbolInfoMap.getValueAndRangeUse(entities.front()));
377       os.indent(4) << formatv(cmd,
378                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
379     } else if (isa<AttrConstraint>(constraint)) {
380       PrintFatalError(
381           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
382     } else {
383       // TODO(b/138794486): replace formatv arguments with the exact specified
384       // args.
385       if (entities.size() > 4) {
386         PrintFatalError(loc, "only support up to 4-entity constraints now");
387       }
388       SmallVector<std::string, 4> names;
389       int i = 0;
390       for (int e = entities.size(); i < e; ++i)
391         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
392       std::string self = appliedConstraint.self;
393       if (!self.empty())
394         self = symbolInfoMap.getValueAndRangeUse(self);
395       for (; i < 4; ++i)
396         names.push_back("<unused>");
397       os.indent(4) << formatv(cmd,
398                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
399                                     names[1], names[2], names[3]));
400     }
401   }
402   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
403 }
404 
405 void PatternEmitter::collectOps(DagNode tree,
406                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
407   // Check if this tree is an operation.
408   if (tree.isOperation()) {
409     const Operator &op = tree.getDialectOp(opMap);
410     LLVM_DEBUG(llvm::dbgs()
411                << "found operation " << op.getOperationName() << '\n');
412     ops.insert(&op);
413   }
414 
415   // Recurse the arguments of the tree.
416   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
417     if (auto child = tree.getArgAsNestedDag(i))
418       collectOps(child, ops);
419 }
420 
421 void PatternEmitter::emit(StringRef rewriteName) {
422   // Get the DAG tree for the source pattern.
423   DagNode sourceTree = pattern.getSourcePattern();
424 
425   const Operator &rootOp = pattern.getSourceRootOp();
426   auto rootName = rootOp.getOperationName();
427 
428   // Collect the set of result operations.
429   llvm::SmallPtrSet<const Operator *, 4> resultOps;
430   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
431   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
432     collectOps(pattern.getResultPattern(i), resultOps);
433   }
434   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
435 
436   // Emit RewritePattern for Pattern.
437   auto locs = pattern.getLocation();
438   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
439                 make_range(locs.rbegin(), locs.rend()));
440   os << formatv(R"(struct {0} : public RewritePattern {
441   {0}(MLIRContext *context)
442       : RewritePattern("{1}", {{)",
443                 rewriteName, rootName);
444   // Sort result operators by name.
445   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
446                                                          resultOps.end());
447   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
448     return lhs->getOperationName() < rhs->getOperationName();
449   });
450   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
451     os << '"' << op->getOperationName() << '"';
452   });
453   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
454 
455   // Emit matchAndRewrite() function.
456   os << R"(
457   PatternMatchResult matchAndRewrite(Operation *op0,
458                                      PatternRewriter &rewriter) const override {
459 )";
460 
461   // Register all symbols bound in the source pattern.
462   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
463 
464   LLVM_DEBUG(
465       llvm::dbgs() << "start creating local variables for capturing matches\n");
466   os.indent(4) << "// Variables for capturing values and attributes used for "
467                   "creating ops\n";
468   // Create local variables for storing the arguments and results bound
469   // to symbols.
470   for (const auto &symbolInfoPair : symbolInfoMap) {
471     StringRef symbol = symbolInfoPair.getKey();
472     auto &info = symbolInfoPair.getValue();
473     os.indent(4) << info.getVarDecl(symbol);
474   }
475   // TODO(jpienaar): capture ops with consistent numbering so that it can be
476   // reused for fused loc.
477   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
478                           pattern.getSourcePattern().getNumOps());
479   LLVM_DEBUG(
480       llvm::dbgs() << "done creating local variables for capturing matches\n");
481 
482   os.indent(4) << "// Match\n";
483   os.indent(4) << "tblgen_ops[0] = op0;\n";
484   emitMatchLogic(sourceTree);
485   os << "\n";
486 
487   os.indent(4) << "// Rewrite\n";
488   emitRewriteLogic();
489 
490   os.indent(4) << "return matchSuccess();\n";
491   os << "  };\n";
492   os << "};\n";
493 }
494 
495 void PatternEmitter::emitRewriteLogic() {
496   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
497   const Operator &rootOp = pattern.getSourceRootOp();
498   int numExpectedResults = rootOp.getNumResults();
499   int numResultPatterns = pattern.getNumResultPatterns();
500 
501   // First register all symbols bound to ops generated in result patterns.
502   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
503 
504   // Only the last N static values generated are used to replace the matched
505   // root N-result op. We need to calculate the starting index (of the results
506   // of the matched op) each result pattern is to replace.
507   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
508   // If we don't need to replace any value at all, set the replacement starting
509   // index as the number of result patterns so we skip all of them when trying
510   // to replace the matched op's results.
511   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
512   for (int i = numResultPatterns - 1; i >= 0; --i) {
513     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
514     offsets[i] = offsets[i + 1] - numValues;
515     if (offsets[i] == 0) {
516       if (replStartIndex == -1)
517         replStartIndex = i;
518     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
519       auto error = formatv(
520           "cannot use the same multi-result op '{0}' to generate both "
521           "auxiliary values and values to be used for replacing the matched op",
522           pattern.getResultPattern(i).getSymbol());
523       PrintFatalError(loc, error);
524     }
525   }
526 
527   if (offsets.front() > 0) {
528     const char error[] = "no enough values generated to replace the matched op";
529     PrintFatalError(loc, error);
530   }
531 
532   os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
533   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
534   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
535     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
536   }
537   os << "}); (void)loc;\n";
538 
539   // Process auxiliary result patterns.
540   for (int i = 0; i < replStartIndex; ++i) {
541     DagNode resultTree = pattern.getResultPattern(i);
542     auto val = handleResultPattern(resultTree, offsets[i], 0);
543     // Normal op creation will be streamed to `os` by the above call; but
544     // NativeCodeCall will only be materialized to `os` if it is used. Here
545     // we are handling auxiliary patterns so we want the side effect even if
546     // NativeCodeCall is not replacing matched root op's results.
547     if (resultTree.isNativeCodeCall())
548       os.indent(4) << val << ";\n";
549   }
550 
551   if (numExpectedResults == 0) {
552     assert(replStartIndex >= numResultPatterns &&
553            "invalid auxiliary vs. replacement pattern division!");
554     // No result to replace. Just erase the op.
555     os.indent(4) << "rewriter.eraseOp(op0);\n";
556   } else {
557     // Process replacement result patterns.
558     os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
559     for (int i = replStartIndex; i < numResultPatterns; ++i) {
560       DagNode resultTree = pattern.getResultPattern(i);
561       auto val = handleResultPattern(resultTree, offsets[i], 0);
562       os.indent(4) << "\n";
563       // Resolve each symbol for all range use so that we can loop over them.
564       os << symbolInfoMap.getAllRangeUse(
565           val, "    for (auto *v : {0}) {{ tblgen_values.push_back(v); }",
566           "\n");
567     }
568     os.indent(4) << "\n";
569     os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
570   }
571 
572   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
573 }
574 
575 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
576   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
577 }
578 
579 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
580                                                 int resultIndex, int depth) {
581   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
582   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
583   LLVM_DEBUG(llvm::dbgs() << '\n');
584 
585   if (resultTree.isNativeCodeCall()) {
586     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
587     symbolInfoMap.bindValue(symbol);
588     return symbol;
589   }
590 
591   if (resultTree.isReplaceWithValue()) {
592     return handleReplaceWithValue(resultTree);
593   }
594 
595   // Normal op creation.
596   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
597   if (resultTree.getSymbol().empty()) {
598     // This is an op not explicitly bound to a symbol in the rewrite rule.
599     // Register the auto-generated symbol for it.
600     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
601   }
602   return symbol;
603 }
604 
605 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
606   assert(tree.isReplaceWithValue());
607 
608   if (tree.getNumArgs() != 1) {
609     PrintFatalError(
610         loc, "replaceWithValue directive must take exactly one argument");
611   }
612 
613   if (!tree.getSymbol().empty()) {
614     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
615   }
616 
617   return tree.getArgName(0);
618 }
619 
620 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
621                                              StringRef patArgName) {
622   if (leaf.isConstantAttr()) {
623     auto constAttr = leaf.getAsConstantAttr();
624     return handleConstantAttr(constAttr.getAttribute(),
625                               constAttr.getConstantValue());
626   }
627   if (leaf.isEnumAttrCase()) {
628     auto enumCase = leaf.getAsEnumAttrCase();
629     if (enumCase.isStrCase())
630       return handleConstantAttr(enumCase, enumCase.getSymbol());
631     // This is an enum case backed by an IntegerAttr. We need to get its value
632     // to build the constant.
633     std::string val = std::to_string(enumCase.getValue());
634     return handleConstantAttr(enumCase, val);
635   }
636 
637   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
638   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
639   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
640     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
641                             << "' (via symbol ref)\n");
642     return argName;
643   }
644   if (leaf.isNativeCodeCall()) {
645     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
646     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
647                             << "' (via NativeCodeCall)\n");
648     return repl;
649   }
650   PrintFatalError(loc, "unhandled case when rewriting op");
651 }
652 
653 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
654   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
655   LLVM_DEBUG(tree.print(llvm::dbgs()));
656   LLVM_DEBUG(llvm::dbgs() << '\n');
657 
658   auto fmt = tree.getNativeCodeTemplate();
659   // TODO(b/138794486): replace formatv arguments with the exact specified args.
660   SmallVector<std::string, 8> attrs(8);
661   if (tree.getNumArgs() > 8) {
662     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
663                              Twine(tree.getNumArgs()));
664   }
665   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
666     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
667     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
668                             << " replacement: " << attrs[i] << "\n");
669   }
670   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
671                attrs[5], attrs[6], attrs[7]);
672 }
673 
674 int PatternEmitter::getNodeValueCount(DagNode node) {
675   if (node.isOperation()) {
676     // If the op is bound to a symbol in the rewrite rule, query its result
677     // count from the symbol info map.
678     auto symbol = node.getSymbol();
679     if (!symbol.empty()) {
680       return symbolInfoMap.getStaticValueCount(symbol);
681     }
682     // Otherwise this is an unbound op; we will use all its results.
683     return pattern.getDialectOp(node).getNumResults();
684   }
685   // TODO(antiagainst): This considers all NativeCodeCall as returning one
686   // value. Enhance if multi-value ones are needed.
687   return 1;
688 }
689 
690 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
691                                              int depth) {
692   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
693   LLVM_DEBUG(tree.print(llvm::dbgs()));
694   LLVM_DEBUG(llvm::dbgs() << '\n');
695 
696   Operator &resultOp = tree.getDialectOp(opMap);
697   auto numOpArgs = resultOp.getNumArgs();
698 
699   if (numOpArgs != tree.getNumArgs()) {
700     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
701                                  "{1} in pattern vs. {2} in definition",
702                                  resultOp.getOperationName(), tree.getNumArgs(),
703                                  numOpArgs));
704   }
705 
706   // A map to collect all nested DAG child nodes' names, with operand index as
707   // the key. This includes both bound and unbound child nodes.
708   llvm::DenseMap<unsigned, std::string> childNodeNames;
709 
710   // First go through all the child nodes who are nested DAG constructs to
711   // create ops for them and remember the symbol names for them, so that we can
712   // use the results in the current node. This happens in a recursive manner.
713   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
714     if (auto child = tree.getArgAsNestedDag(i)) {
715       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
716     }
717   }
718 
719   // The name of the local variable holding this op.
720   std::string valuePackName;
721   // The symbol for holding the result of this pattern. Note that the result of
722   // this pattern is not necessarily the same as the variable created by this
723   // pattern because we can use `__N` suffix to refer only a specific result if
724   // the generated op is a multi-result op.
725   std::string resultValue;
726   if (tree.getSymbol().empty()) {
727     // No symbol is explicitly bound to this op in the pattern. Generate a
728     // unique name.
729     valuePackName = resultValue = getUniqueSymbol(&resultOp);
730   } else {
731     resultValue = tree.getSymbol();
732     // Strip the index to get the name for the value pack and use it to name the
733     // local variable for the op.
734     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
735   }
736 
737   // Create the local variable for this op.
738   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
739                           valuePackName);
740   os.indent(4) << "{\n";
741 
742   // Now prepare operands used for building this op:
743   // * If the operand is non-variadic, we create a `Value*` local variable.
744   // * If the operand is variadic, we create a `SmallVector<Value*>` local
745   //   variable.
746 
747   int valueIndex = 0; // An index for uniquing local variable names.
748   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
749     const auto *operand =
750         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
751     if (!operand) {
752       // We do not need special handling for attributes.
753       continue;
754     }
755     std::string varName;
756     if (operand->isVariadic()) {
757       varName = formatv("tblgen_values_{0}", valueIndex++);
758       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
759       std::string range;
760       if (tree.isNestedDagArg(argIndex)) {
761         range = childNodeNames[argIndex];
762       } else {
763         range = tree.getArgName(argIndex);
764       }
765       // Resolve the symbol for all range use so that we have a uniform way of
766       // capturing the values.
767       range = symbolInfoMap.getValueAndRangeUse(range);
768       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
769                               varName);
770     } else {
771       varName = formatv("tblgen_value_{0}", valueIndex++);
772       os.indent(6) << formatv("Value *{0} = ", varName);
773       if (tree.isNestedDagArg(argIndex)) {
774         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
775       } else {
776         DagLeaf leaf = tree.getArgAsLeaf(argIndex);
777         auto symbol =
778             symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
779         if (leaf.isNativeCodeCall()) {
780           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
781         } else {
782           os << symbol;
783         }
784       }
785       os << ";\n";
786     }
787 
788     // Update to use the newly created local variable for building the op later.
789     childNodeNames[argIndex] = varName;
790   }
791 
792   // Then we create the builder call.
793 
794   // Right now we don't have general type inference in MLIR. Except a few
795   // special cases listed below, we need to supply types for all results
796   // when building an op.
797   bool isSameOperandsAndResultType =
798       resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
799   bool isBroadcastable =
800       resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
801   bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
802   bool usePartialResults = valuePackName != resultValue;
803 
804   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
805       usePartialResults || depth > 0 || resultIndex < 0) {
806     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
807                             resultOp.getQualCppClassName());
808   } else {
809     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
810     // generated from the source pattern root op. Then we can use the source
811     // pattern's value types to determine the value type of the generated op
812     // here.
813 
814     // We need to specify the types for all results.
815     int numResults = resultOp.getNumResults();
816     if (numResults != 0) {
817       os.indent(6) << "tblgen_types.clear();\n";
818       for (int i = 0; i < numResults; ++i) {
819         os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
820                                 "tblgen_types.push_back(v->getType());\n",
821                                 resultIndex + i);
822       }
823     }
824 
825     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
826                             resultOp.getQualCppClassName());
827     if (numResults != 0)
828       os.indent(6) << ", tblgen_types";
829   }
830 
831   // Add arguments for the builder call.
832   for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
833     // Start each argment on its own line.
834     (os << ",\n").indent(8);
835 
836     Argument opArg = resultOp.getArg(argIndex);
837     // Handle the case of operand first.
838     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
839       if (!operand->name.empty()) {
840         os << "/*" << operand->name << "=*/";
841       }
842       os << childNodeNames[argIndex];
843       // TODO(jpienaar): verify types
844       continue;
845     }
846 
847     // The argument in the op definition.
848     auto opArgName = resultOp.getArgName(argIndex);
849     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
850       if (!subTree.isNativeCodeCall())
851         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
852                              "for creating attribute");
853       os << formatv("/*{0}=*/{1}", opArgName,
854                     handleReplaceWithNativeCodeCall(subTree));
855     } else {
856       auto leaf = tree.getArgAsLeaf(argIndex);
857       // The argument in the result DAG pattern.
858       auto patArgName = tree.getArgName(argIndex);
859       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
860         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
861         if (!opArg.is<NamedAttribute *>())
862           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
863         if (!patArgName.empty())
864           os << "/*" << patArgName << "=*/";
865       } else {
866         os << "/*" << opArgName << "=*/";
867       }
868       os << handleOpArgument(leaf, patArgName);
869     }
870   }
871   os << "\n      );\n";
872   os.indent(4) << "}\n";
873 
874   return resultValue;
875 }
876 
877 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
878   emitSourceFileHeader("Rewriters", os);
879 
880   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
881   auto numPatterns = patterns.size();
882 
883   // We put the map here because it can be shared among multiple patterns.
884   RecordOperatorMap recordOpMap;
885 
886   std::vector<std::string> rewriterNames;
887   rewriterNames.reserve(numPatterns);
888 
889   std::string baseRewriterName = "GeneratedConvert";
890   int rewriterIndex = 0;
891 
892   for (Record *p : patterns) {
893     std::string name;
894     if (p->isAnonymous()) {
895       // If no name is provided, ensure unique rewriter names simply by
896       // appending unique suffix.
897       name = baseRewriterName + llvm::utostr(rewriterIndex++);
898     } else {
899       name = p->getName();
900     }
901     LLVM_DEBUG(llvm::dbgs()
902                << "=== start generating pattern '" << name << "' ===\n");
903     PatternEmitter(p, &recordOpMap, os).emit(name);
904     LLVM_DEBUG(llvm::dbgs()
905                << "=== done generating pattern '" << name << "' ===\n");
906     rewriterNames.push_back(std::move(name));
907   }
908 
909   // Emit function to add the generated matchers to the pattern list.
910   os << "void populateWithGenerated(MLIRContext *context, "
911      << "OwningRewritePatternList *patterns) {\n";
912   for (const auto &name : rewriterNames) {
913     os << "  patterns->insert<" << name << ">(context);\n";
914   }
915   os << "}\n";
916 }
917 
918 static mlir::GenRegistration
919     genRewriters("gen-rewriters", "Generate pattern rewriters",
920                  [](const RecordKeeper &records, raw_ostream &os) {
921                    emitRewriters(records, os);
922                    return false;
923                  });
924