xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision b0746c68629c26567cd123d8f9b28e796ef26f47)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Attribute.h"
16 #include "mlir/TableGen/CodeGenHelpers.h"
17 #include "mlir/TableGen/Format.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Pattern.h"
21 #include "mlir/TableGen/Predicate.h"
22 #include "mlir/TableGen/Property.h"
23 #include "mlir/TableGen/Type.h"
24 #include "llvm/ADT/FunctionExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/FormatAdapters.h"
31 #include "llvm/Support/PrettyStackTrace.h"
32 #include "llvm/Support/Signals.h"
33 #include "llvm/TableGen/Error.h"
34 #include "llvm/TableGen/Main.h"
35 #include "llvm/TableGen/Record.h"
36 #include "llvm/TableGen/TableGenBackend.h"
37 
38 using namespace mlir;
39 using namespace mlir::tblgen;
40 
41 using llvm::formatv;
42 using llvm::Record;
43 using llvm::RecordKeeper;
44 
45 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
46 
47 namespace llvm {
48 template <>
49 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51                      raw_ostream &os, StringRef style) {
52     os << v.first << ":" << v.second;
53   }
54 };
55 } // namespace llvm
56 
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 
class StaticMatcherHelper;

// Emits the generated C++ for a single DRR pattern record: `emit` produces the
// mlir::RewritePattern struct and `emitStaticMatcher` produces a standalone
// static DAG matcher function. All output goes to the indented stream `os`.
class PatternEmitter {
public:
  PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
                 StaticMatcherHelper &helper);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits a static DAG matcher function named `funcName` that matches `tree`.
  void emitStaticMatcher(DagNode tree, std::string funcName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure.
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emit C++ function call to static DAG matcher.
  void emitStaticMatchCall(DagNode tree, StringRef name);

  // Emit C++ function call to static type/attribute constraint function.
  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
                              StringRef arg, StringRef failureStr);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `opName` names the generated variable holding the candidate op;
  // nothing is returned.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
  // bound name and the constraint of the operand respectively.
  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
                        int operandIndex, DagLeaf operandMatcher,
                        StringRef argName, int argIndex,
                        std::optional<int> variadicSubIndex);

  // Emits C++ statements for matching the operands which can be matched in
  // either order.
  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                              StringRef opName, int argIndex, int &operandIndex,
                              int depth);

  // Emits C++ statements for matching a variadic operand.
  void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
                                StringRef opName, int argIndex,
                                int &operandIndex, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
                          int depth);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostics.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to replace the matched DAG with an array of
  // matched values.
  std::string handleVariadic(DagNode tree, int depth);

  // Trailing directives are used at the end of DAG node argument lists to
  // specify additional behaviour for op matchers and creators, etc.
  struct TrailingDirectives {
    // DAG node containing the `location` directive. Null if there is none.
    DagNode location;

    // DAG node containing the `returnType` directive. Null if there is none.
    DagNode returnType;

    // Number of trailing directives found.
    int numDirectives;
  };

  // Collect any trailing directives.
  TrailingDirectives getTrailingDirectives(DagNode tree);

  // Returns the location value to use for the DAG whose trailing directives
  // are `tail` (callers such as emitNativeCodeMatch substitute it for `$_loc`).
  // Presumably honors `tail.location` when present — TODO confirm against the
  // definition.
  std::string getLocation(TrailingDirectives &tail);

  // Returns the location value described by the `location` directive DAG
  // `tree`.
  std::string handleLocationDirective(DagNode tree);

  // Emit return type argument.
  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, const Twine &value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bound the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` correspond to.
  int getNodeValueCount(DagNode node);

private:
  // NOTE: members are initialized in declaration order by the constructor's
  // initializer list; keep that in mind before reordering them.

  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // Shared helper deciding which DAGs use static matchers and naming the
  // static matcher / verifier functions.
  StaticMatcherHelper &staticMatcherHelper;

  // The next unused ID for newly created values.
  unsigned nextValueId = 0;

  // Destination of all generated code; tracks the current indentation level.
  raw_indented_ostream os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;
};
265 
266 // Tracks DagNode's reference multiple times across patterns. Enables generating
267 // static matcher functions for DagNode's referenced multiple times rather than
268 // inlining them.
269 class StaticMatcherHelper {
270 public:
271   StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records,
272                       RecordOperatorMap &mapper);
273 
274   // Determine if we should inline the match logic or delegate to a static
275   // function.
276   bool useStaticMatcher(DagNode node) {
277     // either/variadic node must be associated to the parentOp, thus we can't
278     // emit a static matcher rooted at them.
279     if (node.isEither() || node.isVariadic())
280       return false;
281 
282     return refStats[node] > kStaticMatcherThreshold;
283   }
284 
285   // Get the name of the static DAG matcher function corresponding to the node.
286   std::string getMatcherName(DagNode node) {
287     assert(useStaticMatcher(node));
288     return matcherNames[node];
289   }
290 
291   // Get the name of static type/attribute verification function.
292   StringRef getVerifierName(DagLeaf leaf);
293 
294   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
295   // the duplicated DAGs.
296   void addPattern(const Record *record);
297 
298   // Emit all static functions of DAG Matcher.
299   void populateStaticMatchers(raw_ostream &os);
300 
301   // Emit all static functions for Constraints.
302   void populateStaticConstraintFunctions(raw_ostream &os);
303 
304 private:
305   static constexpr unsigned kStaticMatcherThreshold = 1;
306 
307   // Consider two patterns as down below,
308   //   DagNode_Root_A    DagNode_Root_B
309   //       \                 \
310   //     DagNode_C         DagNode_C
311   //         \                 \
312   //       DagNode_D         DagNode_D
313   //
314   // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
315   // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
316   // multiple times so we'll have static matchers for both of them. When we're
317   // emitting the match logic for DagNode_C, we will check if DagNode_D has the
318   // static matcher generated. If so, then we'll generate a call to the
319   // function, inline otherwise. In this case, inlining is not what we want. As
320   // a result, generate the static matcher in topological order to ensure all
321   // the dependent static matchers are generated and we can avoid accidentally
322   // inlining.
323   //
324   // The topological order of all the DagNodes among all patterns.
325   SmallVector<std::pair<DagNode, const Record *>> topologicalOrder;
326 
327   RecordOperatorMap &opMap;
328 
329   // Records of the static function name of each DagNode
330   DenseMap<DagNode, std::string> matcherNames;
331 
332   // After collecting all the DagNode in each pattern, `refStats` records the
333   // number of users for each DagNode. We will generate the static matcher for a
334   // DagNode while the number of users exceeds a certain threshold.
335   DenseMap<DagNode, unsigned> refStats;
336 
337   // Number of static matcher generated. This is used to generate a unique name
338   // for each DagNode.
339   int staticMatcherCounter = 0;
340 
341   // The DagLeaf which contains type or attr constraint.
342   SetVector<DagLeaf> constraints;
343 
344   // Static type/attribute verification function emitter.
345   StaticVerifierFunctionEmitter staticVerifierEmitter;
346 };
347 
348 } // namespace
349 
350 PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper,
351                                raw_ostream &os, StaticMatcherHelper &helper)
352     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
353       symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
354   fmtCtx.withBuilder("rewriter");
355 }
356 
357 std::string PatternEmitter::handleConstantAttr(Attribute attr,
358                                                const Twine &value) {
359   if (!attr.isConstBuildable())
360     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
361                              " does not have the 'constBuilderCall' field");
362 
363   // TODO: Verify the constants here
364   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
365 }
366 
367 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
368   os << formatv(
369       "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
370       "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
371       "*, 4> &tblgen_ops",
372       funcName);
373 
374   // We pass the reference of the variables that need to be captured. Hence we
375   // need to collect all the symbols in the tree first.
376   pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
377   symbolInfoMap.assignUniqueAlternativeNames();
378   for (const auto &info : symbolInfoMap)
379     os << formatv(", {0}", info.second.getArgDecl(info.first));
380 
381   os << ") {\n";
382   os.indent();
383   os << "(void)tblgen_ops;\n";
384 
385   // Note that a static matcher is considered at least one step from the match
386   // entry.
387   emitMatch(tree, "op0", /*depth=*/1);
388 
389   os << "return ::mlir::success();\n";
390   os.unindent();
391   os << "}\n\n";
392 }
393 
394 // Helper function to match patterns.
395 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
396   if (tree.isNativeCodeCall()) {
397     emitNativeCodeMatch(tree, name, depth);
398     return;
399   }
400 
401   if (tree.isOperation()) {
402     emitOpMatch(tree, name, depth);
403     return;
404   }
405 
406   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
407 }
408 
409 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
410   std::string funcName = staticMatcherHelper.getMatcherName(tree);
411   os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
412                 opName);
413 
414   // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
415   // one pass.
416 
417   // In general, bound symbol should have the unique name in the pattern but
418   // for the operand, binding same symbol to multiple operands imply a
419   // constraint at the same time. In this case, we will rename those operands
420   // with different names. As a result, we need to collect all the symbolInfos
421   // from the DagNode then get the updated name of the local variables from the
422   // global symbolInfoMap.
423 
424   // Collect all the bound symbols in the Dag
425   SymbolInfoMap localSymbolMap(loc);
426   pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
427 
428   for (const auto &info : localSymbolMap) {
429     auto name = info.first;
430     auto symboInfo = info.second;
431     auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
432     os << formatv(", {0}", ret->second.getVarName(name));
433   }
434 
435   os << "))) {\n";
436   os.scope().os << "return ::mlir::failure();\n";
437   os << "}\n";
438 }
439 
440 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
441                                             StringRef opName, StringRef arg,
442                                             StringRef failureStr) {
443   os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
444                 funcName, opName, arg, failureStr);
445   os.scope().os << "return ::mlir::failure();\n";
446   os << "}\n";
447 }
448 
449 // Helper function to match patterns.
450 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
451                                          int depth) {
452   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
453   LLVM_DEBUG(tree.print(llvm::dbgs()));
454   LLVM_DEBUG(llvm::dbgs() << '\n');
455 
456   // The order of generating static matcher follows the topological order so
457   // that for every dependent DagNode already have their static matcher
458   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
459   // is when we are generating the static matcher for a DagNode itself. In this
460   // case, we need to emit the function body rather than a function call.
461   if (staticMatcherHelper.useStaticMatcher(tree) &&
462       !staticMatcherHelper.getMatcherName(tree).empty()) {
463     emitStaticMatchCall(tree, opName);
464 
465     // NativeCodeCall will never be at depth 0 so that we don't need to catch
466     // the root operation as emitOpMatch();
467 
468     return;
469   }
470 
471   // TODO(suderman): iterate through arguments, determine their types, output
472   // names.
473   SmallVector<std::string, 8> capture;
474 
475   raw_indented_ostream::DelimitedScope scope(os);
476 
477   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
478     std::string argName = formatv("arg{0}_{1}", depth, i);
479     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
480       if (argTree.isEither())
481         PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
482       if (argTree.isVariadic())
483         PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
484 
485       os << "::mlir::Value " << argName << ";\n";
486     } else {
487       auto leaf = tree.getArgAsLeaf(i);
488       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
489         os << "::mlir::Attribute " << argName << ";\n";
490       } else {
491         os << "::mlir::Value " << argName << ";\n";
492       }
493     }
494 
495     capture.push_back(std::move(argName));
496   }
497 
498   auto tail = getTrailingDirectives(tree);
499   if (tail.returnType)
500     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
501   auto locToUse = getLocation(tail);
502 
503   auto fmt = tree.getNativeCodeTemplate();
504   if (fmt.count("$_self") != 1)
505     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
506                          "passing the defining Operation");
507 
508   auto nativeCodeCall = std::string(
509       tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
510             static_cast<ArrayRef<std::string>>(capture)));
511 
512   emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
513                  formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));
514 
515   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
516     auto name = tree.getArgName(i);
517     if (!name.empty() && name != "_") {
518       os << formatv("{0} = {1};\n", name, capture[i]);
519     }
520   }
521 
522   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
523     std::string argName = capture[i];
524 
525     // Handle nested DAG construct first
526     if (tree.getArgAsNestedDag(i)) {
527       PrintFatalError(
528           loc, formatv("Matching nested tree in NativeCodecall not support for "
529                        "{0} as arg {1}",
530                        argName, i));
531     }
532 
533     DagLeaf leaf = tree.getArgAsLeaf(i);
534 
535     // The parameter for native function doesn't bind any constraints.
536     if (leaf.isUnspecified())
537       continue;
538 
539     auto constraint = leaf.getAsConstraint();
540 
541     std::string self;
542     if (leaf.isAttrMatcher() || leaf.isConstantAttr())
543       self = argName;
544     else
545       self = formatv("{0}.getType()", argName);
546     StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
547     emitStaticVerifierCall(
548         verifier, opName, self,
549         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
550                 "constraint: "
551                 "'{2}'\"",
552                 i, tree.getNativeCodeTemplate(),
553                 escapeString(constraint.getSummary()))
554             .str());
555   }
556 
557   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
558 }
559 
// Emits C++ that matches the operation held in the generated variable `opName`
// against the op DAG `tree` at nesting level `depth` (0 is the source pattern's
// root). If a static matcher exists for `tree`, a call to it is emitted.
// Otherwise the op is cast to its C++ class as `castedOp<depth>` and each DAG
// argument is matched in turn: nested DAGs (including `either` and `variadic`
// groups), operand leaves and attribute leaves.
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');

  // Name of the generated local holding `opName` cast to the op's C++ class.
  auto getCastedName = [depth]() -> std::string {
    return formatv("castedOp{0}", depth);
  };

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);
    // In the codegen of rewriter, we suppose that castedOp0 will capture the
    // root operation. Manually add it if the root DagNode is a static matcher.
    if (depth == 0)
      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
                    "(void){2};\n",
                    opName, op.getQualCppClassName(), getCastedName());
    return;
  }

  std::string castedName = getCastedName();
  os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                "(void){0};\n",
                castedName, opName, op.getQualCppClassName());

  // Skip the operand matching at depth 0 as the pattern rewriter already does.
  if (depth != 0)
    emitMatchCheck(opName, /*matchStr=*/castedName,
                   formatv("\"{0} is not {1} type\"", castedName,
                           op.getQualCppClassName()));

  // If the op itself is bound to a symbol, assign the cast op to it.
  auto name = tree.getSymbol();
  if (!name.empty())
    os << formatv("{0} = {1};\n", name, castedName);

  // `i` walks the DAG arguments and `opArgIdx` walks the op's ODS arguments
  // (an `either` group consumes two of them). `nextOperand` counts ODS operands
  // only, because attributes don't count for getODSOperands().
  for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
       ++i, ++opArgIdx) {
    auto opArg = op.getArg(opArgIdx);
    std::string argName = formatv("op{0}", depth + 1);

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (argTree.isEither()) {
        emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
                               depth);
        // `either` groups two op arguments: step past the second here; the
        // loop increment steps past the first.
        ++opArgIdx;
        continue;
      }
      if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
        if (argTree.isVariadic()) {
          if (!operand->isVariadic()) {
            auto error = formatv("variadic DAG construct can't match op {0}'s "
                                 "non-variadic operand #{1}",
                                 op.getOperationName(), opArgIdx);
            PrintFatalError(loc, error);
          }
          emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
                                   nextOperand, depth);
          ++nextOperand;
          continue;
        }
        if (operand->isVariableLength()) {
          auto error = formatv("use nested DAG construct to match op {0}'s "
                               "variadic operand #{1} unsupported now",
                               op.getOperationName(), opArgIdx);
          PrintFatalError(loc, error);
        }
      }

      os << "{\n";

      // Attributes don't count for getODSOperands.
      // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
      os.indent() << formatv(
          "auto *{0} = "
          "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
          argName, castedName, nextOperand);
      // Null check of operand's definingOp
      emitMatchCheck(
          castedName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  nextOperand++, castedName));
      emitMatch(argTree, argName, depth + 1);
      // Record the matched nested op in the generated `tblgen_ops` list.
      os << formatv("tblgen_ops.push_back({0});\n", argName);
      os.unindent() << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute
    if (isa<NamedTypeConstraint *>(opArg)) {
      auto operandName =
          formatv("{0}.getODSOperands({1})", castedName, nextOperand);
      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
                       /*operandMatcher=*/tree.getArgAsLeaf(i),
                       /*argName=*/tree.getArgName(i), opArgIdx,
                       /*variadicSubIndex=*/std::nullopt);
      ++nextOperand;
    } else if (isa<NamedAttribute *>(opArg)) {
      emitAttributeMatch(tree, opName, opArgIdx, depth);
    } else {
      PrintFatalError(loc, "unhandled case when matching op");
    }
  }
  LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');
}
676 
677 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
678                                       StringRef operandName, int operandIndex,
679                                       DagLeaf operandMatcher, StringRef argName,
680                                       int argIndex,
681                                       std::optional<int> variadicSubIndex) {
682   Operator &op = tree.getDialectOp(opMap);
683   auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
684 
685   // If a constraint is specified, we need to generate C++ statements to
686   // check the constraint.
687   if (!operandMatcher.isUnspecified()) {
688     if (!operandMatcher.isOperandMatcher())
689       PrintFatalError(
690           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
691                        op.getOperationName(), argIndex + 1));
692 
693     // Only need to verify if the matcher's type is different from the one
694     // of op definition.
695     Constraint constraint = operandMatcher.getAsConstraint();
696     if (operand->constraint != constraint) {
697       if (operand->isVariableLength()) {
698         auto error = formatv(
699             "further constrain op {0}'s variadic operand #{1} unsupported now",
700             op.getOperationName(), argIndex);
701         PrintFatalError(loc, error);
702       }
703       auto self = formatv("(*{0}.begin()).getType()", operandName);
704       StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
705       emitStaticVerifierCall(
706           verifier, opName, self.str(),
707           formatv(
708               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709               operand - op.operand_begin(), op.getOperationName(),
710               escapeString(constraint.getSummary()))
711               .str());
712     }
713   }
714 
715   // Capture the value
716   // `$_` is a special symbol to ignore op argument matching.
717   if (!argName.empty() && argName != "_") {
718     auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
719                                              variadicSubIndex);
720     if (res == symbolInfoMap.end())
721       PrintFatalError(loc, formatv("symbol not found: {0}", argName));
722 
723     os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
724   }
725 }
726 
// Emits the matcher for an `either` group: two operands of `tree`'s op that may
// match the two sub-patterns of `eitherArgTree` in either order.
//
// The matching code for both sub-patterns is emitted into a local lambda named
// `eitherLambda<depth>` taking the two candidate operand ranges (`v0`, `v1`).
// The lambda is then invoked with the two ODS operand groups in declaration
// order and, if that fails, once more with them swapped; the match fails only
// if both orders fail.
//
// `argIndex` is the op-argument index of the first operand in the group.
// `operandIndex` is the running ODS operand index; it is advanced once per
// sub-pattern (twice in total). `depth` keeps nested lambda/local names unique.
// Fatal errors are raised if the group does not have exactly two entries, if
// `either` is nested, or if an entry is not an operand.
void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                                            StringRef opName, int argIndex,
                                            int &operandIndex, int depth) {
  constexpr int numEitherArgs = 2;
  if (eitherArgTree.getNumArgs() != numEitherArgs)
    PrintFatalError(loc, "`either` only supports grouping two operands");

  Operator &op = tree.getDialectOp(opMap);

  // Buffer for the `tblgen_ops.push_back(...)` lines; they are flushed only
  // after all matching code so ops are recorded only on success.
  std::string codeBuffer;
  llvm::raw_string_ostream tblgenOps(codeBuffer);

  std::string lambda = formatv("eitherLambda{0}", depth);
  os << formatv(
      "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
      lambda);

  os.indent();

  for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
    if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
      if (argTree.isEither())
        PrintFatalError(loc, "either cannot be nested");

      std::string argName = formatv("local_op_{0}", i).str();

      os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
                    i);

      // Wrap emitMatchCheck and emitMatch in their own brace scope because
      // they declare local variables.
      os << "{\n";
      os.indent();

      emitMatchCheck(
          opName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  operandIndex++, opName));
      emitMatch(argTree, argName, depth + 1);

      os.unindent() << "}\n";

      // `tblgen_ops` is used to collect the matched operations. Within
      // `either`, an operation must be queued only if matching succeeds, so
      // the push_back code is emitted at the end.
      tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
    } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       operandIndex,
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
                       /*variadicSubIndex=*/std::nullopt);
      ++operandIndex;
    } else {
      PrintFatalError(loc, "either can only be applied on operand");
    }
  }

  os << tblgenOps.str();
  os << "return ::mlir::success();\n";
  os.unindent() << "};\n";

  os << "{\n";
  os.indent();

  // `operandIndex` has been advanced past both grouped operands above, so the
  // group occupies ODS indices `operandIndex - 2` and `operandIndex - 1`.
  os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 2);
  os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 1);

  // Try the declared order first, then the swapped order.
  os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
                "::mlir::failed({0}(eitherOperand1, "
                "eitherOperand0)))\n",
                lambda);
  os.indent() << "return ::mlir::failure();\n";

  os.unindent().unindent() << "}\n";
}
805 
806 void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
807                                               DagNode variadicArgTree,
808                                               StringRef opName, int argIndex,
809                                               int &operandIndex, int depth) {
810   Operator &op = tree.getDialectOp(opMap);
811 
812   os << "{\n";
813   os.indent();
814 
815   os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
816                 opName, operandIndex);
817   os << formatv("if (variadic_operand_range.size() != {0}) "
818                 "return ::mlir::failure();\n",
819                 variadicArgTree.getNumArgs());
820 
821   StringRef variadicTreeName = variadicArgTree.getSymbol();
822   if (!variadicTreeName.empty()) {
823     auto res =
824         symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
825                                       /*variadicSubIndex=*/std::nullopt);
826     if (res == symbolInfoMap.end())
827       PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
828 
829     os << formatv("{0} = variadic_operand_range;\n",
830                   res->second.getVarName(variadicTreeName));
831   }
832 
833   for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
834     if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
835       if (!argTree.isOperation())
836         PrintFatalError(loc, "variadic only accepts operation sub-dags");
837 
838       os << "{\n";
839       os.indent();
840 
841       std::string argName = formatv("local_op_{0}", i).str();
842       os << formatv("auto *{0} = "
843                     "variadic_operand_range[{1}].getDefiningOp();\n",
844                     argName, i);
845       emitMatchCheck(
846           opName, /*matchStr=*/argName,
847           formatv("\"There's no operation that defines variadic operand "
848                   "{0} (variadic sub-opearnd #{1}) of {2}\"",
849                   operandIndex, i, opName));
850       emitMatch(argTree, argName, depth + 1);
851       os << formatv("tblgen_ops.push_back({0});\n", argName);
852 
853       os.unindent() << "}\n";
854     } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
855       auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
856       emitOperandMatch(tree, opName, operandName.str(), operandIndex,
857                        /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
858                        /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
859     } else {
860       PrintFatalError(loc, "variadic can only be applied on operand");
861     }
862   }
863 
864   os.unindent() << "}\n";
865 }
866 
// Emits the matcher for the attribute at op-argument index `argIndex` of the op
// matched by `tree` (available in generated code as `opName`).
//
// The generated code works as follows:
//   * It reads the attribute into a local `tblgen_attr` via `getAttrOfType`.
//   * If the attribute definition has a default value, a missing attribute is
//     replaced by that default.
//   * If the attribute is required (no default and not optional), a missing
//     attribute is a match failure.
//   * If the pattern specifies an attribute constraint, it is checked through
//     its static verifier.
//   * If the argument is bound to a (non-`_`) name, the value is captured into
//     that variable.
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                        int argIndex, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
  const auto &attr = namedAttr->attr;

  os << "{\n";
  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
                         "(void)tblgen_attr;\n",
                         opName, attr.getStorageType(), namedAttr->name);

  // TODO: This should use getter method to avoid duplication.
  if (attr.hasDefaultValue()) {
    os << "if (!tblgen_attr) tblgen_attr = "
       << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
                            attr.getDefaultValue()))
       << ";\n";
  } else if (attr.isOptional()) {
    // For a missing attribute that is optional according to its definition,
    // we should just capture a null mlir::Attribute() to signal the missing
    // state. The `getAttrOfType` lookup emitted above already yields a null
    // attribute when the attribute is absent, so nothing extra is emitted.
  } else {
    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
                   formatv("\"expected op '{0}' to have attribute '{1}' "
                           "of type '{2}'\"",
                           op.getOperationName(), namedAttr->name,
                           attr.getStorageType()));
  }

  auto matcher = tree.getArgAsLeaf(argIndex);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), argIndex + 1));
    }

    // If a constraint is specified, we need to generate function call to its
    // static verifier.
    StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
    if (attr.isOptional()) {
      // Avoid dereferencing a null attribute. Simple heuristic: if the
      // constraint's condition template references `$_self` but never tests
      // `!$_self`, it presumably assumes a non-null attribute, so emit an early
      // failure when the optional attribute is absent. Templates that never
      // reference `$_self` do not read the value and get no guard.
      // FIXME: This could be improved as some null dereferences could slip
      // through.
      if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
          StringRef(matcher.getConditionTemplate()).contains("$_self")) {
        os << "if (!tblgen_attr) return ::mlir::failure();\n";
      }
    }
    emitStaticVerifierCall(
        verifier, opName, "tblgen_attr",
        formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                "'{2}'\"",
                op.getOperationName(), namedAttr->name,
                escapeString(matcher.getAsConstraint().getSummary()))
            .str());
  }

  // Capture the value
  auto name = tree.getArgName(argIndex);
  // `$_` is a special symbol to ignore op argument matching.
  if (!name.empty() && name != "_") {
    os << formatv("{0} = tblgen_attr;\n", name);
  }

  os.unindent() << "}\n";
}
938 
939 void PatternEmitter::emitMatchCheck(
940     StringRef opName, const FmtObjectBase &matchFmt,
941     const llvm::formatv_object_base &failureFmt) {
942   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
943 }
944 
945 void PatternEmitter::emitMatchCheck(StringRef opName,
946                                     const std::string &matchStr,
947                                     const std::string &failureStr) {
948 
949   os << "if (!(" << matchStr << "))";
950   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
951                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
952                               << failureStr << ";\n});";
953 }
954 
// Emits the complete match logic for the source pattern rooted at `opName`.
// This happens in three stages:
//   1. The structural match of `tree`.
//   2. Each additional constraint attached to the pattern: type constraints on
//      a single entity, or general constraints over up to four entities.
//      Attribute constraints are rejected here.
//   3. Equality checks between values that the source pattern binds to the
//      same symbol name.
void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
  int depth = 0;
  emitMatch(tree, opName, depth);

  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    if (isa<TypeConstraint>(constraint)) {
      if (entities.size() != 1)
        PrintFatalError(loc, "type constraint requires exactly one argument");

      // `$_self` in the condition refers to the entity's type.
      auto self = formatv("({0}.getType())",
                          symbolInfoMap.getValueAndRangeUse(entities.front()));
      emitMatchCheck(
          opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
          formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
                  entities.front(), escapeString(constraint.getSummary())));

    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // TODO: replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      // Entities become $0..$3 in the condition template; unused slots are
      // padded with a placeholder so all four substitutions exist.
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
      std::string self = appliedConstraint.self;
      if (!self.empty())
        self = symbolInfoMap.getValueAndRangeUse(self);
      for (; i < 4; ++i)
        names.push_back("<unused>");
      emitMatchCheck(opName,
                     tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                           names[1], names[2], names[3]),
                     formatv("\"entities '{0}' failed to satisfy constraint: "
                             "'{1}'\"",
                             llvm::join(entities, ", "),
                             escapeString(constraint.getSummary())));
    }
  }

  // Some of the operands could be bound to the same symbol name, we need
  // to enforce equality constraint on those. Only the first value of each
  // binding (`*x.begin()`) is compared.
  // TODO: we should be able to emit equality checks early
  // and short circuit unnecessary work if vars are not equal.
  for (auto symbolInfoIt = symbolInfoMap.begin();
       symbolInfoIt != symbolInfoMap.end();) {
    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
    auto startRange = range.first;
    auto endRange = range.second;

    // Compare every later binding of this symbol against the first one.
    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
    for (++startRange; startRange != endRange; ++startRange) {
      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
      emitMatchCheck(
          opName,
          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
                  secondOperand));
    }

    // Skip to the next distinct symbol.
    symbolInfoIt = endRange;
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
1029 
1030 void PatternEmitter::collectOps(DagNode tree,
1031                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1032   // Check if this tree is an operation.
1033   if (tree.isOperation()) {
1034     const Operator &op = tree.getDialectOp(opMap);
1035     LLVM_DEBUG(llvm::dbgs()
1036                << "found operation " << op.getOperationName() << '\n');
1037     ops.insert(&op);
1038   }
1039 
1040   // Recurse the arguments of the tree.
1041   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1042     if (auto child = tree.getArgAsNestedDag(i))
1043       collectOps(child, ops);
1044 }
1045 
// Emits the complete C++ `RewritePattern` subclass named `rewriteName` for the
// current pattern. The output contains:
//   * a comment listing the pattern's source locations;
//   * a constructor passing the root op name, the pattern benefit, and the
//     name-sorted list of ops the result patterns may create;
//   * a `matchAndRewrite` override holding capture-variable declarations, the
//     match logic, and the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The source locations are printed in reverse order, separated by
  // "instantiating" lines.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name.
  // (Sorting keeps the generated list independent of the order in which
  // SmallPtrSet iterates.)
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
    ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      // The root op is always the first entry of `tblgen_ops`.
      os << "// Match\n";
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    os << "}\n";
  }
  os << "};\n\n";
}
1124 
// Emits the rewrite half of `matchAndRewrite`. Result patterns are split into
// two groups:
//   * Auxiliary patterns: the leading ones, emitted only for their side
//     effects.
//   * Replacement patterns: the trailing ones, whose values replace the
//     matched root op's results.
// Supplemental patterns are emitted after the replacement values are
// computed. The matched op is then erased (if it has no results) or replaced.
// It is a fatal error if the result patterns produce fewer values than the
// root op has results, or if a single multi-result op would straddle the
// auxiliary/replacement boundary.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  // offsets[i] is that starting index for result pattern i; offsets[N] is a
  // sentinel equal to the number of expected results. Negative offsets mark
  // auxiliary values.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  if (offsets.front() > 0) {
    const char error[] =
        "not enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  // Default location for created ops: the fused locations of all matched ops
  // collected in `tblgen_ops`.
  os << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall() &&
        resultTree.getNumReturnsOfNativeCode() == 0)
      os << val << ";\n";
  }

  // Supplemental patterns are given negative result indices
  // (-numSupplementalPatterns .. -1); presumably so they are never treated as
  // replacement values (NOTE(review): confirm against handleOpCreation).
  auto processSupplementalPatterns = [&]() {
    int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
    for (int i = 0, offset = -numSupplementalPatterns;
         i < numSupplementalPatterns; ++i) {
      DagNode resultTree = pattern.getSupplementalPattern(i);
      auto val = handleResultPattern(resultTree, offset++, 0);
      if (resultTree.isNativeCodeCall() &&
          resultTree.getNumReturnsOfNativeCode() == 0)
        os << val << ";\n";
    }
  };

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    processSupplementalPatterns();
    // No result to replace. Just erase the op.
    os << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
          "  tblgen_repl_values.push_back(v);\n}\n",
          "\n");
    }
    processSupplementalPatterns();
    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
1225 
1226 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1227   return std::string(
1228       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1229 }
1230 
1231 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1232                                                 int resultIndex, int depth) {
1233   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1234   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1235   LLVM_DEBUG(llvm::dbgs() << '\n');
1236 
1237   if (resultTree.isLocationDirective()) {
1238     PrintFatalError(loc,
1239                     "location directive can only be used with op creation");
1240   }
1241 
1242   if (resultTree.isNativeCodeCall())
1243     return handleReplaceWithNativeCodeCall(resultTree, depth);
1244 
1245   if (resultTree.isReplaceWithValue())
1246     return handleReplaceWithValue(resultTree).str();
1247 
1248   if (resultTree.isVariadic())
1249     return handleVariadic(resultTree, depth);
1250 
1251   // Normal op creation.
1252   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1253   if (resultTree.getSymbol().empty()) {
1254     // This is an op not explicitly bound to a symbol in the rewrite rule.
1255     // Register the auto-generated symbol for it.
1256     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1257   }
1258   return symbol;
1259 }
1260 
1261 std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
1262   assert(tree.isVariadic());
1263 
1264   std::string output;
1265   llvm::raw_string_ostream oss(output);
1266   auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
1267   symbolInfoMap.bindValue(name);
1268   oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
1269   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1270     if (auto child = tree.getArgAsNestedDag(i)) {
1271       oss << name << ".push_back(" << handleResultPattern(child, i, depth + 1)
1272           << ");\n";
1273     } else {
1274       oss << name << ".push_back("
1275           << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))
1276           << ");\n";
1277     }
1278   }
1279 
1280   os << oss.str();
1281   return name;
1282 }
1283 
1284 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1285   assert(tree.isReplaceWithValue());
1286 
1287   if (tree.getNumArgs() != 1) {
1288     PrintFatalError(
1289         loc, "replaceWithValue directive must take exactly one argument");
1290   }
1291 
1292   if (!tree.getSymbol().empty()) {
1293     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1294   }
1295 
1296   return tree.getArgName(0);
1297 }
1298 
1299 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1300   assert(tree.isLocationDirective());
1301   auto lookUpArgLoc = [this, &tree](int idx) {
1302     const auto *const lookupFmt = "{0}.getLoc()";
1303     return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1304   };
1305 
1306   if (tree.getNumArgs() == 0)
1307     llvm::PrintFatalError(
1308         "At least one argument to location directive required");
1309 
1310   if (!tree.getSymbol().empty())
1311     PrintFatalError(loc, "cannot bind symbol to location");
1312 
1313   if (tree.getNumArgs() == 1) {
1314     DagLeaf leaf = tree.getArgAsLeaf(0);
1315     if (leaf.isStringAttr())
1316       return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1317                      leaf.getStringAttr())
1318           .str();
1319     return lookUpArgLoc(0);
1320   }
1321 
1322   std::string ret;
1323   llvm::raw_string_ostream os(ret);
1324   std::string strAttr;
1325   os << "rewriter.getFusedLoc({";
1326   bool first = true;
1327   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1328     DagLeaf leaf = tree.getArgAsLeaf(i);
1329     // Handle the optional string value.
1330     if (leaf.isStringAttr()) {
1331       if (!strAttr.empty())
1332         llvm::PrintFatalError("Only one string attribute may be specified");
1333       strAttr = leaf.getStringAttr();
1334       continue;
1335     }
1336     os << (first ? "" : ", ") << lookUpArgLoc(i);
1337     first = false;
1338   }
1339   os << "}";
1340   if (!strAttr.empty()) {
1341     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1342   }
1343   os << ")";
1344   return os.str();
1345 }
1346 
1347 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1348                                                 int depth) {
1349   // Nested NativeCodeCall.
1350   if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1351     if (!dagNode.isNativeCodeCall())
1352       PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1353                            "call");
1354     return handleReplaceWithNativeCodeCall(dagNode, depth);
1355   }
1356   // String literal.
1357   auto dagLeaf = returnType.getArgAsLeaf(i);
1358   if (dagLeaf.isStringAttr())
1359     return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1360   return tgfmt(
1361       "$0.getType()", &fmtCtx,
1362       handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1363 }
1364 
1365 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1366                                              StringRef patArgName) {
1367   if (leaf.isStringAttr())
1368     PrintFatalError(loc, "raw string not supported as argument");
1369   if (leaf.isConstantAttr()) {
1370     auto constAttr = leaf.getAsConstantAttr();
1371     return handleConstantAttr(constAttr.getAttribute(),
1372                               constAttr.getConstantValue());
1373   }
1374   if (leaf.isEnumAttrCase()) {
1375     auto enumCase = leaf.getAsEnumAttrCase();
1376     // This is an enum case backed by an IntegerAttr. We need to get its value
1377     // to build the constant.
1378     std::string val = std::to_string(enumCase.getValue());
1379     return handleConstantAttr(enumCase, val);
1380   }
1381 
1382   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1383   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1384   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1385     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1386                             << "' (via symbol ref)\n");
1387     return argName;
1388   }
1389   if (leaf.isNativeCodeCall()) {
1390     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1391     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1392                             << "' (via NativeCodeCall)\n");
1393     return std::string(repl);
1394   }
1395   PrintFatalError(loc, "unhandled case when rewriting op");
1396 }
1397 
1398 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1399                                                             int depth) {
1400   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1401   LLVM_DEBUG(tree.print(llvm::dbgs()));
1402   LLVM_DEBUG(llvm::dbgs() << '\n');
1403 
1404   auto fmt = tree.getNativeCodeTemplate();
1405 
1406   SmallVector<std::string, 16> attrs;
1407 
1408   auto tail = getTrailingDirectives(tree);
1409   if (tail.returnType)
1410     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1411   auto locToUse = getLocation(tail);
1412 
1413   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1414     if (tree.isNestedDagArg(i)) {
1415       attrs.push_back(
1416           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1417     } else {
1418       attrs.push_back(
1419           handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1420     }
1421     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1422                             << " replacement: " << attrs[i] << "\n");
1423   }
1424 
1425   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1426                              static_cast<ArrayRef<std::string>>(attrs));
1427 
1428   // In general, NativeCodeCall without naming binding don't need this. To
1429   // ensure void helper function has been correctly labeled, i.e., use
1430   // NativeCodeCallVoid, we cache the result to a local variable so that we will
1431   // get a compilation error in the auto-generated file.
1432   // Example.
1433   //   // In the td file
1434   //   Pat<(...), (NativeCodeCall<Foo> ...)>
1435   //
1436   //   ---
1437   //
1438   //   // In the auto-generated .cpp
1439   //   ...
1440   //   // Causes compilation error if Foo() returns void.
1441   //   auto nativeVar = Foo();
1442   //   ...
1443   if (tree.getNumReturnsOfNativeCode() != 0) {
1444     // Determine the local variable name for return value.
1445     std::string varName =
1446         SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1447     if (varName.empty()) {
1448       varName = formatv("nativeVar_{0}", nextValueId++);
1449       // Register the local variable for later uses.
1450       symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1451     }
1452 
1453     // Catch the return value of helper function.
1454     os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1455 
1456     if (!tree.getSymbol().empty())
1457       symbol = tree.getSymbol().str();
1458     else
1459       symbol = varName;
1460   }
1461 
1462   return symbol;
1463 }
1464 
1465 int PatternEmitter::getNodeValueCount(DagNode node) {
1466   if (node.isOperation()) {
1467     // If the op is bound to a symbol in the rewrite rule, query its result
1468     // count from the symbol info map.
1469     auto symbol = node.getSymbol();
1470     if (!symbol.empty()) {
1471       return symbolInfoMap.getStaticValueCount(symbol);
1472     }
1473     // Otherwise this is an unbound op; we will use all its results.
1474     return pattern.getDialectOp(node).getNumResults();
1475   }
1476 
1477   if (node.isNativeCodeCall())
1478     return node.getNumReturnsOfNativeCode();
1479 
1480   return 1;
1481 }
1482 
1483 PatternEmitter::TrailingDirectives
1484 PatternEmitter::getTrailingDirectives(DagNode tree) {
1485   TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1486 
1487   // Look backwards through the arguments.
1488   auto numPatArgs = tree.getNumArgs();
1489   for (int i = numPatArgs - 1; i >= 0; --i) {
1490     auto dagArg = tree.getArgAsNestedDag(i);
1491     // A leaf is not a directive. Stop looking.
1492     if (!dagArg)
1493       break;
1494 
1495     auto isLocation = dagArg.isLocationDirective();
1496     auto isReturnType = dagArg.isReturnTypeDirective();
1497     // If encountered a DAG node that isn't a trailing directive, stop looking.
1498     if (!(isLocation || isReturnType))
1499       break;
1500     // Save the directive, but error if one of the same type was already
1501     // found.
1502     ++tail.numDirectives;
1503     if (isLocation) {
1504       if (tail.location)
1505         PrintFatalError(loc, "`location` directive can only be specified "
1506                              "once");
1507       tail.location = dagArg;
1508     } else if (isReturnType) {
1509       if (tail.returnType)
1510         PrintFatalError(loc, "`returnType` directive can only be specified "
1511                              "once");
1512       tail.returnType = dagArg;
1513     }
1514   }
1515 
1516   return tail;
1517 }
1518 
1519 std::string
1520 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1521   if (tail.location)
1522     return handleLocationDirective(tail.location);
1523 
1524   // If no explicit location is given, use the default, all fused, location.
1525   return "odsLoc";
1526 }
1527 
1528 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1529                                              int depth) {
1530   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1531   LLVM_DEBUG(tree.print(llvm::dbgs()));
1532   LLVM_DEBUG(llvm::dbgs() << '\n');
1533 
1534   Operator &resultOp = tree.getDialectOp(opMap);
1535   auto numOpArgs = resultOp.getNumArgs();
1536   auto numPatArgs = tree.getNumArgs();
1537 
1538   auto tail = getTrailingDirectives(tree);
1539   auto locToUse = getLocation(tail);
1540 
1541   auto inPattern = numPatArgs - tail.numDirectives;
1542   if (numOpArgs != inPattern) {
1543     PrintFatalError(loc,
1544                     formatv("resultant op '{0}' argument number mismatch: "
1545                             "{1} in pattern vs. {2} in definition",
1546                             resultOp.getOperationName(), inPattern, numOpArgs));
1547   }
1548 
1549   // A map to collect all nested DAG child nodes' names, with operand index as
1550   // the key. This includes both bound and unbound child nodes.
1551   ChildNodeIndexNameMap childNodeNames;
1552 
1553   // If the argument is a type constraint, then its an operand. Check if the
1554   // op's argument is variadic that the argument in the pattern is too.
1555   auto checkIfMatchedVariadic = [&](int i) {
1556     // FIXME: This does not yet check for variable/leaf case.
1557     // FIXME: Change so that native code call can be handled.
1558     const auto *operand =
1559         llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
1560     if (!operand || !operand->isVariadic())
1561       return;
1562 
1563     auto child = tree.getArgAsNestedDag(i);
1564     if (!child)
1565       return;
1566 
1567     // Skip over replaceWithValues.
1568     while (child.isReplaceWithValue()) {
1569       if (!(child = child.getArgAsNestedDag(0)))
1570         return;
1571     }
1572     if (!child.isNativeCodeCall() && !child.isVariadic())
1573       PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while "
1574                                    "provided is non-variadic",
1575                                    resultOp.getArgName(i)));
1576   };
1577 
1578   // First go through all the child nodes who are nested DAG constructs to
1579   // create ops for them and remember the symbol names for them, so that we can
1580   // use the results in the current node. This happens in a recursive manner.
1581   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1582     checkIfMatchedVariadic(i);
1583     if (auto child = tree.getArgAsNestedDag(i))
1584       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1585   }
1586 
1587   // The name of the local variable holding this op.
1588   std::string valuePackName;
1589   // The symbol for holding the result of this pattern. Note that the result of
1590   // this pattern is not necessarily the same as the variable created by this
1591   // pattern because we can use `__N` suffix to refer only a specific result if
1592   // the generated op is a multi-result op.
1593   std::string resultValue;
1594   if (tree.getSymbol().empty()) {
1595     // No symbol is explicitly bound to this op in the pattern. Generate a
1596     // unique name.
1597     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1598   } else {
1599     resultValue = std::string(tree.getSymbol());
1600     // Strip the index to get the name for the value pack and use it to name the
1601     // local variable for the op.
1602     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1603   }
1604 
1605   // Create the local variable for this op.
1606   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1607                 valuePackName);
1608 
1609   // Right now ODS don't have general type inference support. Except a few
1610   // special cases listed below, DRR needs to supply types for all results
1611   // when building an op.
1612   bool isSameOperandsAndResultType =
1613       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1614   bool useFirstAttr =
1615       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1616 
1617   if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1618     // We know how to deduce the result type for ops with these traits and we've
1619     // generated builders taking aggregate parameters. Use those builders to
1620     // create the ops.
1621 
1622     // First prepare local variables for op arguments used in builder call.
1623     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1624 
1625     // Then create the op.
1626     os.scope("", "\n}\n").os << formatv(
1627         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1628         valuePackName, resultOp.getQualCppClassName(), locToUse);
1629     return resultValue;
1630   }
1631 
1632   bool usePartialResults = valuePackName != resultValue;
1633 
1634   if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1635     // For these cases (broadcastable ops, op results used both as auxiliary
1636     // values and replacement values, ops in nested patterns, auxiliary ops), we
1637     // still need to supply the result types when building the op. But because
1638     // we don't generate a builder automatically with ODS for them, it's the
1639     // developer's responsibility to make sure such a builder (with result type
1640     // deduction ability) exists. We go through the separate-parameter builder
1641     // here given that it's easier for developers to write compared to
1642     // aggregate-parameter builders.
1643     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1644 
1645     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1646                              resultOp.getQualCppClassName(), locToUse);
1647     supplyValuesForOpArgs(tree, childNodeNames, depth);
1648     os << "\n  );\n}\n";
1649     return resultValue;
1650   }
1651 
1652   // If we are provided explicit return types, use them to build the op.
1653   // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1654   // the values generated from the source pattern root op. Then we must use the
1655   // source pattern's value types to determine the value type of the generated
1656   // op here.
1657   if (depth == 0 && resultIndex >= 0 && tail.returnType)
1658     PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1659                          "return values replace the source pattern's root op");
1660 
1661   // First prepare local variables for op arguments used in builder call.
1662   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1663 
1664   // Then prepare the result types. We need to specify the types for all
1665   // results.
1666   os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1667                          "(void)tblgen_types;\n");
1668   int numResults = resultOp.getNumResults();
1669   if (tail.returnType) {
1670     auto numRetTys = tail.returnType.getNumArgs();
1671     for (int i = 0; i < numRetTys; ++i) {
1672       auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1673       os << "tblgen_types.push_back(" << varName << ");\n";
1674     }
1675   } else {
1676     if (numResults != 0) {
1677       // Copy the result types from the source pattern.
1678       for (int i = 0; i < numResults; ++i)
1679         os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1680                       "  tblgen_types.push_back(v.getType());\n}\n",
1681                       resultIndex + i);
1682     }
1683   }
1684   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1685                 "tblgen_values, tblgen_attrs);\n",
1686                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1687   os.unindent() << "}\n";
1688   return resultValue;
1689 }
1690 
1691 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1692     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1693   Operator &resultOp = node.getDialectOp(opMap);
1694 
1695   // Now prepare operands used for building this op:
1696   // * If the operand is non-variadic, we create a `Value` local variable.
1697   // * If the operand is variadic, we create a `SmallVector<Value>` local
1698   //   variable.
1699 
1700   int valueIndex = 0; // An index for uniquing local variable names.
1701   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1702     const auto *operand =
1703         llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
1704     // We do not need special handling for attributes.
1705     if (!operand)
1706       continue;
1707 
1708     raw_indented_ostream::DelimitedScope scope(os);
1709     std::string varName;
1710     if (operand->isVariadic()) {
1711       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1712       os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1713       std::string range;
1714       if (node.isNestedDagArg(argIndex)) {
1715         range = childNodeNames[argIndex];
1716       } else {
1717         range = std::string(node.getArgName(argIndex));
1718       }
1719       // Resolve the symbol for all range use so that we have a uniform way of
1720       // capturing the values.
1721       range = symbolInfoMap.getValueAndRangeUse(range);
1722       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1723                     varName);
1724     } else {
1725       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1726       os << formatv("::mlir::Value {0} = ", varName);
1727       if (node.isNestedDagArg(argIndex)) {
1728         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1729       } else {
1730         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1731         auto symbol =
1732             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1733         if (leaf.isNativeCodeCall()) {
1734           os << std::string(
1735               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1736         } else {
1737           os << symbol;
1738         }
1739       }
1740       os << ";\n";
1741     }
1742 
1743     // Update to use the newly created local variable for building the op later.
1744     childNodeNames[argIndex] = varName;
1745   }
1746 }
1747 
1748 void PatternEmitter::supplyValuesForOpArgs(
1749     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1750   Operator &resultOp = node.getDialectOp(opMap);
1751   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1752        argIndex != numOpArgs; ++argIndex) {
1753     // Start each argument on its own line.
1754     os << ",\n    ";
1755 
1756     Argument opArg = resultOp.getArg(argIndex);
1757     // Handle the case of operand first.
1758     if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
1759       if (!operand->name.empty())
1760         os << "/*" << operand->name << "=*/";
1761       os << childNodeNames.lookup(argIndex);
1762       continue;
1763     }
1764 
1765     // The argument in the op definition.
1766     auto opArgName = resultOp.getArgName(argIndex);
1767     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1768       if (!subTree.isNativeCodeCall())
1769         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1770                              "for creating attribute");
1771       os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1772     } else {
1773       auto leaf = node.getArgAsLeaf(argIndex);
1774       // The argument in the result DAG pattern.
1775       auto patArgName = node.getArgName(argIndex);
1776       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1777         // TODO: Refactor out into map to avoid recomputing these.
1778         if (!isa<NamedAttribute *>(opArg))
1779           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1780         if (!patArgName.empty())
1781           os << "/*" << patArgName << "=*/";
1782       } else {
1783         os << "/*" << opArgName << "=*/";
1784       }
1785       os << handleOpArgument(leaf, patArgName);
1786     }
1787   }
1788 }
1789 
1790 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1791     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1792   Operator &resultOp = node.getDialectOp(opMap);
1793 
1794   auto scope = os.scope();
1795   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1796                 "tblgen_values; (void)tblgen_values;\n");
1797   os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1798                 "tblgen_attrs; (void)tblgen_attrs;\n");
1799 
1800   const char *addAttrCmd =
1801       "if (auto tmpAttr = {1}) {\n"
1802       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1803       "tmpAttr);\n}\n";
1804   int numVariadic = 0;
1805   bool hasOperandSegmentSizes = false;
1806   std::vector<std::string> sizes;
1807   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1808     if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
1809       // The argument in the op definition.
1810       auto opArgName = resultOp.getArgName(argIndex);
1811       hasOperandSegmentSizes =
1812           hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1813       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1814         if (!subTree.isNativeCodeCall())
1815           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1816                                "for creating attribute");
1817         os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1818       } else {
1819         auto leaf = node.getArgAsLeaf(argIndex);
1820         // The argument in the result DAG pattern.
1821         auto patArgName = node.getArgName(argIndex);
1822         os << formatv(addAttrCmd, opArgName,
1823                       handleOpArgument(leaf, patArgName));
1824       }
1825       continue;
1826     }
1827 
1828     const auto *operand =
1829         cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
1830     std::string varName;
1831     if (operand->isVariadic()) {
1832       ++numVariadic;
1833       std::string range;
1834       if (node.isNestedDagArg(argIndex)) {
1835         range = childNodeNames.lookup(argIndex);
1836       } else {
1837         range = std::string(node.getArgName(argIndex));
1838       }
1839       // Resolve the symbol for all range use so that we have a uniform way of
1840       // capturing the values.
1841       range = symbolInfoMap.getValueAndRangeUse(range);
1842       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1843                     range);
1844       sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
1845     } else {
1846       sizes.emplace_back("1");
1847       os << formatv("tblgen_values.push_back(");
1848       if (node.isNestedDagArg(argIndex)) {
1849         os << symbolInfoMap.getValueAndRangeUse(
1850             childNodeNames.lookup(argIndex));
1851       } else {
1852         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1853         if (leaf.isConstantAttr())
1854           // TODO: Use better location
1855           PrintFatalError(
1856               loc,
1857               "attribute found where value was expected, if attempting to use "
1858               "constant value, construct a constant op with given attribute "
1859               "instead");
1860 
1861         auto symbol =
1862             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1863         if (leaf.isNativeCodeCall()) {
1864           os << std::string(
1865               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1866         } else {
1867           os << symbol;
1868         }
1869       }
1870       os << ");\n";
1871     }
1872   }
1873 
1874   if (numVariadic > 1 && !hasOperandSegmentSizes) {
1875     // Only set size if it can't be computed.
1876     const auto *sameVariadicSize =
1877         resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
1878     if (!sameVariadicSize) {
1879       const char *setSizes = R"(
1880         tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1881           rewriter.getDenseI32ArrayAttr({{ {0} }));
1882           )";
1883       os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1884     }
1885   }
1886 }
1887 
1888 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1889                                          const RecordKeeper &records,
1890                                          RecordOperatorMap &mapper)
1891     : opMap(mapper), staticVerifierEmitter(os, records) {}
1892 
1893 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1894   // PatternEmitter will use the static matcher if there's one generated. To
1895   // ensure that all the dependent static matchers are generated before emitting
1896   // the matching logic of the DagNode, we use topological order to achieve it.
1897   for (auto &dagInfo : topologicalOrder) {
1898     DagNode node = dagInfo.first;
1899     if (!useStaticMatcher(node))
1900       continue;
1901 
1902     std::string funcName =
1903         formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1904     assert(!matcherNames.contains(node));
1905     PatternEmitter(dagInfo.second, &opMap, os, *this)
1906         .emitStaticMatcher(node, funcName);
1907     matcherNames[node] = funcName;
1908   }
1909 }
1910 
1911 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1912   staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1913 }
1914 
1915 void StaticMatcherHelper::addPattern(const Record *record) {
1916   Pattern pat(record, &opMap);
1917 
1918   // While generating the function body of the DAG matcher, it may depends on
1919   // other DAG matchers. To ensure the dependent matchers are ready, we compute
1920   // the topological order for all the DAGs and emit the DAG matchers in this
1921   // order.
1922   llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1923     ++refStats[node];
1924 
1925     if (refStats[node] != 1)
1926       return;
1927 
1928     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1929       if (DagNode sibling = node.getArgAsNestedDag(i))
1930         dfs(sibling);
1931       else {
1932         DagLeaf leaf = node.getArgAsLeaf(i);
1933         if (!leaf.isUnspecified())
1934           constraints.insert(leaf);
1935       }
1936 
1937     topologicalOrder.push_back(std::make_pair(node, record));
1938   };
1939 
1940   dfs(pat.getSourcePattern());
1941 }
1942 
1943 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1944   if (leaf.isAttrMatcher()) {
1945     std::optional<StringRef> constraint =
1946         staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1947     assert(constraint && "attribute constraint was not uniqued");
1948     return *constraint;
1949   }
1950   assert(leaf.isOperandMatcher());
1951   return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1952 }
1953 
1954 static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
1955   emitSourceFileHeader("Rewriters", os, records);
1956 
1957   auto patterns = records.getAllDerivedDefinitions("Pattern");
1958 
1959   // We put the map here because it can be shared among multiple patterns.
1960   RecordOperatorMap recordOpMap;
1961 
1962   // Exam all the patterns and generate static matcher for the duplicated
1963   // DagNode.
1964   StaticMatcherHelper staticMatcher(os, records, recordOpMap);
1965   for (const Record *p : patterns)
1966     staticMatcher.addPattern(p);
1967   staticMatcher.populateStaticConstraintFunctions(os);
1968   staticMatcher.populateStaticMatchers(os);
1969 
1970   std::vector<std::string> rewriterNames;
1971   rewriterNames.reserve(patterns.size());
1972 
1973   std::string baseRewriterName = "GeneratedConvert";
1974   int rewriterIndex = 0;
1975 
1976   for (const Record *p : patterns) {
1977     std::string name;
1978     if (p->isAnonymous()) {
1979       // If no name is provided, ensure unique rewriter names simply by
1980       // appending unique suffix.
1981       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1982     } else {
1983       name = std::string(p->getName());
1984     }
1985     LLVM_DEBUG(llvm::dbgs()
1986                << "=== start generating pattern '" << name << "' ===\n");
1987     PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1988     LLVM_DEBUG(llvm::dbgs()
1989                << "=== done generating pattern '" << name << "' ===\n");
1990     rewriterNames.push_back(std::move(name));
1991   }
1992 
1993   // Emit function to add the generated matchers to the pattern list.
1994   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1995         "::mlir::RewritePatternSet &patterns) {\n";
1996   for (const auto &name : rewriterNames) {
1997     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1998   }
1999   os << "}\n";
2000 }
2001 
2002 static mlir::GenRegistration
2003     genRewriters("gen-rewriters", "Generate pattern rewriters",
2004                  [](const RecordKeeper &records, raw_ostream &os) {
2005                    emitRewriters(records, os);
2006                    return false;
2007                  });
2008