xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 603117b2d62b320e86d3452fa0b9f7b00d125574)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Support/STLExtras.h"
23 #include "mlir/TableGen/Attribute.h"
24 #include "mlir/TableGen/Format.h"
25 #include "mlir/TableGen/GenInfo.h"
26 #include "mlir/TableGen/Operator.h"
27 #include "mlir/TableGen/Pattern.h"
28 #include "mlir/TableGen/Predicate.h"
29 #include "mlir/TableGen/Type.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatAdapters.h"
35 #include "llvm/Support/PrettyStackTrace.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/TableGen/Error.h"
38 #include "llvm/TableGen/Main.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
41 
42 using namespace llvm;
43 using namespace mlir;
44 using namespace mlir::tblgen;
45 
46 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
47 
48 namespace llvm {
49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51                      raw_ostream &os, StringRef style) {
52     os << v.first << ":" << v.second;
53   }
54 };
55 } // end namespace llvm
56 
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 class PatternEmitter {
63 public:
64   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
65 
66   // Emits the mlir::RewritePattern struct named `rewriteName`.
67   void emit(StringRef rewriteName);
68 
69 private:
70   // Emits the code for matching ops.
71   void emitMatchLogic(DagNode tree);
72 
73   // Emits the code for rewriting ops.
74   void emitRewriteLogic();
75 
76   //===--------------------------------------------------------------------===//
77   // Match utilities
78   //===--------------------------------------------------------------------===//
79 
80   // Emits C++ statements for matching the op constrained by the given DAG
81   // `tree`.
82   void emitOpMatch(DagNode tree, int depth);
83 
84   // Emits C++ statements for matching the `index`-th argument of the given DAG
85   // `tree` as an operand.
86   void emitOperandMatch(DagNode tree, int index, int depth, int indent);
87 
88   // Emits C++ statements for matching the `index`-th argument of the given DAG
89   // `tree` as an attribute.
90   void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
91 
92   //===--------------------------------------------------------------------===//
93   // Rewrite utilities
94   //===--------------------------------------------------------------------===//
95 
96   // The entry point for handling a result pattern rooted at `resultTree`. This
97   // method dispatches to concrete handlers according to `resultTree`'s kind and
98   // returns a symbol representing the whole value pack. Callers are expected to
99   // further resolve the symbol according to the specific use case.
100   //
101   // `depth` is the nesting level of `resultTree`; 0 means top-level result
102   // pattern. For top-level result pattern, `resultIndex` indicates which result
103   // of the matched root op this pattern is intended to replace, which can be
104   // used to deduce the result type of the op generated from this result
105   // pattern.
106   std::string handleResultPattern(DagNode resultTree, int resultIndex,
107                                   int depth);
108 
109   // Emits the C++ statement to replace the matched DAG with a value built via
110   // calling native C++ code.
111   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
112 
113   // Returns the C++ expression referencing the old value serving as the
114   // replacement.
115   std::string handleReplaceWithValue(DagNode tree);
116 
117   // Emits the C++ statement to build a new op out of the given DAG `tree` and
118   // returns the variable name that this op is assigned to. If the root op in
119   // DAG `tree` has a specified name, the created op will be assigned to a
120   // variable of the given name. Otherwise, a unique name will be used as the
121   // result value name.
122   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
123 
124   // Returns the C++ expression to construct a constant attribute of the given
125   // `value` for the given attribute kind `attr`.
126   std::string handleConstantAttr(Attribute attr, StringRef value);
127 
128   // Returns the C++ expression to build an argument from the given DAG `leaf`.
129   // `patArgName` is used to bound the argument to the source pattern.
130   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
131 
132   //===--------------------------------------------------------------------===//
133   // General utilities
134   //===--------------------------------------------------------------------===//
135 
136   // Collects all of the operations within the given dag tree.
137   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
138 
139   // Returns a unique symbol for a local variable of the given `op`.
140   std::string getUniqueSymbol(const Operator *op);
141 
142   //===--------------------------------------------------------------------===//
143   // Symbol utilities
144   //===--------------------------------------------------------------------===//
145 
146   // Returns how many static values the given DAG `node` correspond to.
147   int getNodeValueCount(DagNode node);
148 
149 private:
150   // Pattern instantiation location followed by the location of multiclass
151   // prototypes used. This is intended to be used as a whole to
152   // PrintFatalError() on errors.
153   ArrayRef<llvm::SMLoc> loc;
154 
155   // Op's TableGen Record to wrapper object.
156   RecordOperatorMap *opMap;
157 
158   // Handy wrapper for pattern being emitted.
159   Pattern pattern;
160 
161   // Map for all bound symbols' info.
162   SymbolInfoMap symbolInfoMap;
163 
164   // The next unused ID for newly created values.
165   unsigned nextValueId;
166 
167   raw_ostream &os;
168 
169   // Format contexts containing placeholder substitutions.
170   FmtContext fmtCtx;
171 
172   // Number of op processed.
173   int opCounter = 0;
174 };
175 } // end anonymous namespace
176 
177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
178                                raw_ostream &os)
179     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
180       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
181   fmtCtx.withBuilder("rewriter");
182 }
183 
184 std::string PatternEmitter::handleConstantAttr(Attribute attr,
185                                                StringRef value) {
186   if (!attr.isConstBuildable())
187     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
188                              " does not have the 'constBuilderCall' field");
189 
190   // TODO(jpienaar): Verify the constants here
191   return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
192 }
193 
194 // Helper function to match patterns.
195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
196   Operator &op = tree.getDialectOp(opMap);
197   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
198                           << op.getOperationName() << "' at depth " << depth
199                           << '\n');
200 
201   int indent = 4 + 2 * depth;
202   os.indent(indent) << formatv(
203       "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
204       depth, op.getQualCppClassName());
205   // Skip the operand matching at depth 0 as the pattern rewriter already does.
206   if (depth != 0) {
207     // Skip if there is no defining operation (e.g., arguments to function).
208     os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
209                                  depth);
210   }
211   if (tree.getNumArgs() != op.getNumArgs()) {
212     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
213                                  "pattern vs. {2} in definition",
214                                  op.getOperationName(), tree.getNumArgs(),
215                                  op.getNumArgs()));
216   }
217 
218   // If the operand's name is set, set to that variable.
219   auto name = tree.getSymbol();
220   if (!name.empty())
221     os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
222 
223   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
224     auto opArg = op.getArg(i);
225 
226     // Handle nested DAG construct first
227     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
228       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
229         if (operand->isVariadic()) {
230           auto error = formatv("use nested DAG construct to match op {0}'s "
231                                "variadic operand #{1} unsupported now",
232                                op.getOperationName(), i);
233           PrintFatalError(loc, error);
234         }
235       }
236       os.indent(indent) << "{\n";
237 
238       os.indent(indent + 2) << formatv(
239           "auto *op{0} = "
240           "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
241           depth + 1, depth, i);
242       emitOpMatch(argTree, depth + 1);
243       os.indent(indent + 2)
244           << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
245       os.indent(indent) << "}\n";
246       continue;
247     }
248 
249     // Next handle DAG leaf: operand or attribute
250     if (opArg.is<NamedTypeConstraint *>()) {
251       emitOperandMatch(tree, i, depth, indent);
252     } else if (opArg.is<NamedAttribute *>()) {
253       emitAttributeMatch(tree, i, depth, indent);
254     } else {
255       PrintFatalError(loc, "unhandled case when matching op");
256     }
257   }
258   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
259                           << op.getOperationName() << "' at depth " << depth
260                           << '\n');
261 }
262 
263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
264                                       int indent) {
265   Operator &op = tree.getDialectOp(opMap);
266   auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
267   auto matcher = tree.getArgAsLeaf(index);
268 
269   // If a constraint is specified, we need to generate C++ statements to
270   // check the constraint.
271   if (!matcher.isUnspecified()) {
272     if (!matcher.isOperandMatcher()) {
273       PrintFatalError(
274           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
275                        op.getOperationName(), index + 1));
276     }
277 
278     // Only need to verify if the matcher's type is different from the one
279     // of op definition.
280     if (operand->constraint != matcher.getAsConstraint()) {
281       if (operand->isVariadic()) {
282         auto error = formatv(
283             "further constrain op {0}'s variadic operand #{1} unsupported now",
284             op.getOperationName(), index);
285         PrintFatalError(loc, error);
286       }
287       auto self =
288           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
289                   depth, index);
290       os.indent(indent) << "if (!("
291                         << tgfmt(matcher.getConditionTemplate(),
292                                  &fmtCtx.withSelf(self))
293                         << ")) return matchFailure();\n";
294     }
295   }
296 
297   // Capture the value
298   auto name = tree.getArgName(index);
299   if (!name.empty()) {
300     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
301                                  name, depth, index);
302   }
303 }
304 
305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
306                                         int indent) {
307   Operator &op = tree.getDialectOp(opMap);
308   auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
309   const auto &attr = namedAttr->attr;
310 
311   os.indent(indent) << "{\n";
312   indent += 2;
313   os.indent(indent) << formatv(
314       "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
315       attr.getStorageType(), namedAttr->name);
316 
317   // TODO(antiagainst): This should use getter method to avoid duplication.
318   if (attr.hasDefaultValueInitializer()) {
319     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
320                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
321                                attr.getDefaultValueInitializer())
322                       << ";\n";
323   } else if (attr.isOptional()) {
324     // For a missing attribute that is optional according to definition, we
325     // should just capture a mlir::Attribute() to signal the missing state.
326     // That is precisely what getAttr() returns on missing attributes.
327   } else {
328     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
329   }
330 
331   auto matcher = tree.getArgAsLeaf(index);
332   if (!matcher.isUnspecified()) {
333     if (!matcher.isAttrMatcher()) {
334       PrintFatalError(
335           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
336                        op.getOperationName(), index + 1));
337     }
338 
339     // If a constraint is specified, we need to generate C++ statements to
340     // check the constraint.
341     os.indent(indent) << "if (!("
342                       << tgfmt(matcher.getConditionTemplate(),
343                                &fmtCtx.withSelf("tblgen_attr"))
344                       << ")) return matchFailure();\n";
345   }
346 
347   // Capture the value
348   auto name = tree.getArgName(index);
349   if (!name.empty()) {
350     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
351   }
352 
353   indent -= 2;
354   os.indent(indent) << "}\n";
355 }
356 
357 void PatternEmitter::emitMatchLogic(DagNode tree) {
358   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
359   emitOpMatch(tree, 0);
360 
361   for (auto &appliedConstraint : pattern.getConstraints()) {
362     auto &constraint = appliedConstraint.constraint;
363     auto &entities = appliedConstraint.entities;
364 
365     auto condition = constraint.getConditionTemplate();
366     auto cmd = "if (!({0})) return matchFailure();\n";
367 
368     if (isa<TypeConstraint>(constraint)) {
369       auto self = formatv("({0}->getType())",
370                           symbolInfoMap.getValueAndRangeUse(entities.front()));
371       os.indent(4) << formatv(cmd,
372                               tgfmt(condition, &fmtCtx.withSelf(self.str())));
373     } else if (isa<AttrConstraint>(constraint)) {
374       PrintFatalError(
375           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
376     } else {
377       // TODO(b/138794486): replace formatv arguments with the exact specified
378       // args.
379       if (entities.size() > 4) {
380         PrintFatalError(loc, "only support up to 4-entity constraints now");
381       }
382       SmallVector<std::string, 4> names;
383       int i = 0;
384       for (int e = entities.size(); i < e; ++i)
385         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
386       std::string self = appliedConstraint.self;
387       if (!self.empty())
388         self = symbolInfoMap.getValueAndRangeUse(self);
389       for (; i < 4; ++i)
390         names.push_back("<unused>");
391       os.indent(4) << formatv(cmd,
392                               tgfmt(condition, &fmtCtx.withSelf(self), names[0],
393                                     names[1], names[2], names[3]));
394     }
395   }
396   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
397 }
398 
399 void PatternEmitter::collectOps(DagNode tree,
400                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
401   // Check if this tree is an operation.
402   if (tree.isOperation()) {
403     const Operator &op = tree.getDialectOp(opMap);
404     LLVM_DEBUG(llvm::dbgs()
405                << "found operation " << op.getOperationName() << '\n');
406     ops.insert(&op);
407   }
408 
409   // Recurse the arguments of the tree.
410   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
411     if (auto child = tree.getArgAsNestedDag(i))
412       collectOps(child, ops);
413 }
414 
415 void PatternEmitter::emit(StringRef rewriteName) {
416   // Get the DAG tree for the source pattern.
417   DagNode sourceTree = pattern.getSourcePattern();
418 
419   const Operator &rootOp = pattern.getSourceRootOp();
420   auto rootName = rootOp.getOperationName();
421 
422   // Collect the set of result operations.
423   llvm::SmallPtrSet<const Operator *, 4> resultOps;
424   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
425   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
426     collectOps(pattern.getResultPattern(i), resultOps);
427   }
428   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
429 
430   // Emit RewritePattern for Pattern.
431   auto locs = pattern.getLocation();
432   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
433                 make_range(locs.rbegin(), locs.rend()));
434   os << formatv(R"(struct {0} : public RewritePattern {
435   {0}(MLIRContext *context)
436       : RewritePattern("{1}", {{)",
437                 rewriteName, rootName);
438   // Sort result operators by name.
439   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
440                                                          resultOps.end());
441   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
442     return lhs->getOperationName() < rhs->getOperationName();
443   });
444   interleaveComma(sortedResultOps, os, [&](const Operator *op) {
445     os << '"' << op->getOperationName() << '"';
446   });
447   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
448 
449   // Emit matchAndRewrite() function.
450   os << R"(
451   PatternMatchResult matchAndRewrite(Operation *op0,
452                                      PatternRewriter &rewriter) const override {
453 )";
454 
455   // Register all symbols bound in the source pattern.
456   pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
457 
458   LLVM_DEBUG(
459       llvm::dbgs() << "start creating local variables for capturing matches\n");
460   os.indent(4) << "// Variables for capturing values and attributes used for "
461                   "creating ops\n";
462   // Create local variables for storing the arguments and results bound
463   // to symbols.
464   for (const auto &symbolInfoPair : symbolInfoMap) {
465     StringRef symbol = symbolInfoPair.getKey();
466     auto &info = symbolInfoPair.getValue();
467     os.indent(4) << info.getVarDecl(symbol);
468   }
469   // TODO(jpienaar): capture ops with consistent numbering so that it can be
470   // reused for fused loc.
471   os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
472                           pattern.getSourcePattern().getNumOps());
473   LLVM_DEBUG(
474       llvm::dbgs() << "done creating local variables for capturing matches\n");
475 
476   os.indent(4) << "// Match\n";
477   os.indent(4) << "tblgen_ops[0] = op0;\n";
478   emitMatchLogic(sourceTree);
479   os << "\n";
480 
481   os.indent(4) << "// Rewrite\n";
482   emitRewriteLogic();
483 
484   os.indent(4) << "return matchSuccess();\n";
485   os << "  };\n";
486   os << "};\n";
487 }
488 
489 void PatternEmitter::emitRewriteLogic() {
490   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
491   const Operator &rootOp = pattern.getSourceRootOp();
492   int numExpectedResults = rootOp.getNumResults();
493   int numResultPatterns = pattern.getNumResultPatterns();
494 
495   // First register all symbols bound to ops generated in result patterns.
496   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
497 
498   // Only the last N static values generated are used to replace the matched
499   // root N-result op. We need to calculate the starting index (of the results
500   // of the matched op) each result pattern is to replace.
501   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
502   // If we don't need to replace any value at all, set the replacement starting
503   // index as the number of result patterns so we skip all of them when trying
504   // to replace the matched op's results.
505   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
506   for (int i = numResultPatterns - 1; i >= 0; --i) {
507     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
508     offsets[i] = offsets[i + 1] - numValues;
509     if (offsets[i] == 0) {
510       if (replStartIndex == -1)
511         replStartIndex = i;
512     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
513       auto error = formatv(
514           "cannot use the same multi-result op '{0}' to generate both "
515           "auxiliary values and values to be used for replacing the matched op",
516           pattern.getResultPattern(i).getSymbol());
517       PrintFatalError(loc, error);
518     }
519   }
520 
521   if (offsets.front() > 0) {
522     const char error[] = "no enough values generated to replace the matched op";
523     PrintFatalError(loc, error);
524   }
525 
526   os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
527   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
528   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
529     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
530   }
531   os << "}); (void)loc;\n";
532 
533   // Process auxiliary result patterns.
534   for (int i = 0; i < replStartIndex; ++i) {
535     DagNode resultTree = pattern.getResultPattern(i);
536     auto val = handleResultPattern(resultTree, offsets[i], 0);
537     // Normal op creation will be streamed to `os` by the above call; but
538     // NativeCodeCall will only be materialized to `os` if it is used. Here
539     // we are handling auxiliary patterns so we want the side effect even if
540     // NativeCodeCall is not replacing matched root op's results.
541     if (resultTree.isNativeCodeCall())
542       os.indent(4) << val << ";\n";
543   }
544 
545   // Process replacement result patterns.
546   os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
547   for (int i = replStartIndex; i < numResultPatterns; ++i) {
548     DagNode resultTree = pattern.getResultPattern(i);
549     auto val = handleResultPattern(resultTree, offsets[i], 0);
550     os.indent(4) << "\n";
551     // Resolve each symbol for all range use so that we can loop over them.
552     os << symbolInfoMap.getAllRangeUse(
553         val, "    for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
554   }
555   os.indent(4) << "\n";
556   os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
557   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
558 }
559 
560 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
561   return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
562 }
563 
564 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
565                                                 int resultIndex, int depth) {
566   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
567   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
568   LLVM_DEBUG(llvm::dbgs() << '\n');
569 
570   if (resultTree.isNativeCodeCall()) {
571     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
572     symbolInfoMap.bindValue(symbol);
573     return symbol;
574   }
575 
576   if (resultTree.isReplaceWithValue()) {
577     return handleReplaceWithValue(resultTree);
578   }
579 
580   // Normal op creation.
581   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
582   if (resultTree.getSymbol().empty()) {
583     // This is an op not explicitly bound to a symbol in the rewrite rule.
584     // Register the auto-generated symbol for it.
585     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
586   }
587   return symbol;
588 }
589 
590 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
591   assert(tree.isReplaceWithValue());
592 
593   if (tree.getNumArgs() != 1) {
594     PrintFatalError(
595         loc, "replaceWithValue directive must take exactly one argument");
596   }
597 
598   if (!tree.getSymbol().empty()) {
599     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
600   }
601 
602   return tree.getArgName(0);
603 }
604 
605 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
606                                              StringRef patArgName) {
607   if (leaf.isConstantAttr()) {
608     auto constAttr = leaf.getAsConstantAttr();
609     return handleConstantAttr(constAttr.getAttribute(),
610                               constAttr.getConstantValue());
611   }
612   if (leaf.isEnumAttrCase()) {
613     auto enumCase = leaf.getAsEnumAttrCase();
614     if (enumCase.isStrCase())
615       return handleConstantAttr(enumCase, enumCase.getSymbol());
616     // This is an enum case backed by an IntegerAttr. We need to get its value
617     // to build the constant.
618     std::string val = std::to_string(enumCase.getValue());
619     return handleConstantAttr(enumCase, val);
620   }
621 
622   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
623   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
624   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
625     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
626                             << "' (via symbol ref)\n");
627     return argName;
628   }
629   if (leaf.isNativeCodeCall()) {
630     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
631     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
632                             << "' (via NativeCodeCall)\n");
633     return repl;
634   }
635   PrintFatalError(loc, "unhandled case when rewriting op");
636 }
637 
638 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
639   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
640   LLVM_DEBUG(tree.print(llvm::dbgs()));
641   LLVM_DEBUG(llvm::dbgs() << '\n');
642 
643   auto fmt = tree.getNativeCodeTemplate();
644   // TODO(b/138794486): replace formatv arguments with the exact specified args.
645   SmallVector<std::string, 8> attrs(8);
646   if (tree.getNumArgs() > 8) {
647     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
648                              Twine(tree.getNumArgs()));
649   }
650   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
651     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
652     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
653                             << " replacement: " << attrs[i] << "\n");
654   }
655   return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
656                attrs[5], attrs[6], attrs[7]);
657 }
658 
659 int PatternEmitter::getNodeValueCount(DagNode node) {
660   if (node.isOperation()) {
661     // If the op is bound to a symbol in the rewrite rule, query its result
662     // count from the symbol info map.
663     auto symbol = node.getSymbol();
664     if (!symbol.empty()) {
665       return symbolInfoMap.getStaticValueCount(symbol);
666     }
667     // Otherwise this is an unbound op; we will use all its results.
668     return pattern.getDialectOp(node).getNumResults();
669   }
670   // TODO(antiagainst): This considers all NativeCodeCall as returning one
671   // value. Enhance if multi-value ones are needed.
672   return 1;
673 }
674 
675 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
676                                              int depth) {
677   Operator &resultOp = tree.getDialectOp(opMap);
678   auto numOpArgs = resultOp.getNumArgs();
679 
680   if (numOpArgs != tree.getNumArgs()) {
681     PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
682                                  "{1} in pattern vs. {2} in definition",
683                                  resultOp.getOperationName(), tree.getNumArgs(),
684                                  numOpArgs));
685   }
686 
687   // A map to collect all nested DAG child nodes' names, with operand index as
688   // the key. This includes both bound and unbound child nodes.
689   llvm::DenseMap<unsigned, std::string> childNodeNames;
690 
691   // First go through all the child nodes who are nested DAG constructs to
692   // create ops for them and remember the symbol names for them, so that we can
693   // use the results in the current node. This happens in a recursive manner.
694   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
695     if (auto child = tree.getArgAsNestedDag(i)) {
696       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
697     }
698   }
699 
700   // The name of the local variable holding this op.
701   std::string valuePackName;
702   // The symbol for holding the result of this pattern. Note that the result of
703   // this pattern is not necessarily the same as the variable created by this
704   // pattern because we can use `__N` suffix to refer only a specific result if
705   // the generated op is a multi-result op.
706   std::string resultValue;
707   if (tree.getSymbol().empty()) {
708     // No symbol is explicitly bound to this op in the pattern. Generate a
709     // unique name.
710     valuePackName = resultValue = getUniqueSymbol(&resultOp);
711   } else {
712     resultValue = tree.getSymbol();
713     // Strip the index to get the name for the value pack and use it to name the
714     // local variable for the op.
715     valuePackName = SymbolInfoMap::getValuePackName(resultValue);
716   }
717 
718   // Create the local variable for this op.
719   os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
720                           valuePackName);
721   os.indent(4) << "{\n";
722 
723   // Now prepare operands used for building this op:
724   // * If the operand is non-variadic, we create a `Value*` local variable.
725   // * If the operand is variadic, we create a `SmallVector<Value*>` local
726   //   variable.
727 
728   int argIndex = 0;   // The current index to this op's ODS argument
729   int valueIndex = 0; // An index for uniquing local variable names.
730   for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
731     const auto &operand = resultOp.getOperand(argIndex);
732     std::string varName;
733     if (operand.isVariadic()) {
734       varName = formatv("tblgen_values_{0}", valueIndex++);
735       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
736       std::string range;
737       if (tree.isNestedDagArg(argIndex)) {
738         range = childNodeNames[argIndex];
739       } else {
740         range = tree.getArgName(argIndex);
741       }
742       // Resolve the symbol for all range use so that we have a uniform way of
743       // capturing the values.
744       range = symbolInfoMap.getValueAndRangeUse(range);
745       os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
746                               varName);
747     } else {
748       varName = formatv("tblgen_value_{0}", valueIndex++);
749       os.indent(6) << formatv("Value *{0} = ", varName);
750       if (tree.isNestedDagArg(argIndex)) {
751         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
752       } else {
753         DagLeaf leaf = tree.getArgAsLeaf(argIndex);
754         auto symbol =
755             symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
756         if (leaf.isNativeCodeCall()) {
757           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
758         } else {
759           os << symbol;
760         }
761       }
762       os << ";\n";
763     }
764 
765     // Update to use the newly created local variable for building the op later.
766     childNodeNames[argIndex] = varName;
767   }
768 
769   // Then we create the builder call.
770 
771   // Right now we don't have general type inference in MLIR. Except a few
772   // special cases listed below, we need to supply types for all results
773   // when building an op.
774   bool isSameOperandsAndResultType =
775       resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
776   bool isBroadcastable =
777       resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
778   bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
779   bool usePartialResults = valuePackName != resultValue;
780 
781   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
782       usePartialResults || depth > 0 || resultIndex < 0) {
783     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
784                             resultOp.getQualCppClassName());
785   } else {
786     // If depth == 0 and resultIndex >= 0, it means we are replacing the values
787     // generated from the source pattern root op. Then we can use the source
788     // pattern's value types to determine the value type of the generated op
789     // here.
790 
791     // We need to specify the types for all results.
792     int numResults = resultOp.getNumResults();
793     if (numResults != 0) {
794       os.indent(6) << "tblgen_types.clear();\n";
795       for (int i = 0; i < numResults; ++i) {
796         os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
797                                 "tblgen_types.push_back(v->getType());\n",
798                                 resultIndex + i);
799       }
800     }
801 
802     os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
803                             resultOp.getQualCppClassName());
804     if (numResults != 0)
805       os.indent(6) << ", tblgen_types";
806   }
807 
808   // Add operands for the builder call.
809   for (int i = 0; i < argIndex; ++i) {
810     const auto &operand = resultOp.getOperand(i);
811     // Start each operand on its own line.
812     (os << ",\n").indent(8);
813     if (!operand.name.empty()) {
814       os << "/*" << operand.name << "=*/";
815     }
816     os << childNodeNames[i];
817     // TODO(jpienaar): verify types
818   }
819 
820   // Add attributes for the builder call.
821   for (; argIndex != numOpArgs; ++argIndex) {
822     // Start each attribute on its own line.
823     (os << ",\n").indent(8);
824     // The argument in the op definition.
825     auto opArgName = resultOp.getArgName(argIndex);
826     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
827       if (!subTree.isNativeCodeCall())
828         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
829                              "for creating attribute");
830       os << formatv("/*{0}=*/{1}", opArgName,
831                     handleReplaceWithNativeCodeCall(subTree));
832     } else {
833       auto leaf = tree.getArgAsLeaf(argIndex);
834       // The argument in the result DAG pattern.
835       auto patArgName = tree.getArgName(argIndex);
836       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
837         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
838         auto argument = resultOp.getArg(argIndex);
839         if (!argument.is<NamedAttribute *>())
840           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
841         if (!patArgName.empty())
842           os << "/*" << patArgName << "=*/";
843       } else {
844         os << "/*" << opArgName << "=*/";
845       }
846       os << handleOpArgument(leaf, patArgName);
847     }
848   }
849   os << "\n      );\n";
850   os.indent(4) << "}\n";
851 
852   return resultValue;
853 }
854 
855 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
856   emitSourceFileHeader("Rewriters", os);
857 
858   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
859   auto numPatterns = patterns.size();
860 
861   // We put the map here because it can be shared among multiple patterns.
862   RecordOperatorMap recordOpMap;
863 
864   std::vector<std::string> rewriterNames;
865   rewriterNames.reserve(numPatterns);
866 
867   std::string baseRewriterName = "GeneratedConvert";
868   int rewriterIndex = 0;
869 
870   for (Record *p : patterns) {
871     std::string name;
872     if (p->isAnonymous()) {
873       // If no name is provided, ensure unique rewriter names simply by
874       // appending unique suffix.
875       name = baseRewriterName + llvm::utostr(rewriterIndex++);
876     } else {
877       name = p->getName();
878     }
879     LLVM_DEBUG(llvm::dbgs()
880                << "=== start generating pattern '" << name << "' ===\n");
881     PatternEmitter(p, &recordOpMap, os).emit(name);
882     LLVM_DEBUG(llvm::dbgs()
883                << "=== done generating pattern '" << name << "' ===\n");
884     rewriterNames.push_back(std::move(name));
885   }
886 
887   // Emit function to add the generated matchers to the pattern list.
888   os << "void populateWithGenerated(MLIRContext *context, "
889      << "OwningRewritePatternList *patterns) {\n";
890   for (const auto &name : rewriterNames) {
891     os << "  patterns->insert<" << name << ">(context);\n";
892   }
893   os << "}\n";
894 }
895 
896 static mlir::GenRegistration
897     genRewriters("gen-rewriters", "Generate pattern rewriters",
898                  [](const RecordKeeper &records, raw_ostream &os) {
899                    emitRewriters(records, os);
900                    return false;
901                  });
902