xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "mlir/TableGen/Pattern.h"
20 #include "mlir/TableGen/Predicate.h"
21 #include "mlir/TableGen/Type.h"
22 #include "llvm/ADT/FunctionExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatAdapters.h"
29 #include "llvm/Support/PrettyStackTrace.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Main.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35 
36 using namespace mlir;
37 using namespace mlir::tblgen;
38 
39 using llvm::formatv;
40 using llvm::Record;
41 using llvm::RecordKeeper;
42 
43 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
44 
45 namespace llvm {
46 template <>
47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
48   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
49                      raw_ostream &os, StringRef style) {
50     os << v.first << ":" << v.second;
51   }
52 };
53 } // namespace llvm
54 
55 //===----------------------------------------------------------------------===//
56 // PatternEmitter
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 
61 class StaticMatcherHelper;
62 
63 class PatternEmitter {
64 public:
65   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
66                  StaticMatcherHelper &helper);
67 
68   // Emits the mlir::RewritePattern struct named `rewriteName`.
69   void emit(StringRef rewriteName);
70 
71   // Emits the static function of DAG matcher.
72   void emitStaticMatcher(DagNode tree, std::string funcName);
73 
74 private:
75   // Emits the code for matching ops.
76   void emitMatchLogic(DagNode tree, StringRef opName);
77 
78   // Emits the code for rewriting ops.
79   void emitRewriteLogic();
80 
81   //===--------------------------------------------------------------------===//
82   // Match utilities
83   //===--------------------------------------------------------------------===//
84 
85   // Emits C++ statements for matching the DAG structure.
86   void emitMatch(DagNode tree, StringRef name, int depth);
87 
88   // Emit C++ function call to static DAG matcher.
89   void emitStaticMatchCall(DagNode tree, StringRef name);
90 
91   // Emit C++ function call to static type/attribute constraint function.
92   void emitStaticVerifierCall(StringRef funcName, StringRef opName,
93                               StringRef arg, StringRef failureStr);
94 
95   // Emits C++ statements for matching using a native code call.
96   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
97 
98   // Emits C++ statements for matching the op constrained by the given DAG
99   // `tree` returning the op's variable name.
100   void emitOpMatch(DagNode tree, StringRef opName, int depth);
101 
102   // Emits C++ statements for matching the `argIndex`-th argument of the given
103   // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
104   // bound name and the constraint of the operand respectively.
105   void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
106                         int operandIndex, DagLeaf operandMatcher,
107                         StringRef argName, int argIndex,
108                         std::optional<int> variadicSubIndex);
109 
110   // Emits C++ statements for matching the operands which can be matched in
111   // either order.
112   void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
113                               StringRef opName, int argIndex, int &operandIndex,
114                               int depth);
115 
116   // Emits C++ statements for matching a variadic operand.
117   void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
118                                 StringRef opName, int argIndex,
119                                 int &operandIndex, int depth);
120 
121   // Emits C++ statements for matching the `argIndex`-th argument of the given
122   // DAG `tree` as an attribute.
123   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
124                           int depth);
125 
126   // Emits C++ for checking a match with a corresponding match failure
127   // diagnostic.
128   void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
129                       const llvm::formatv_object_base &failureFmt);
130 
131   // Emits C++ for checking a match with a corresponding match failure
132   // diagnostics.
133   void emitMatchCheck(StringRef opName, const std::string &matchStr,
134                       const std::string &failureStr);
135 
136   //===--------------------------------------------------------------------===//
137   // Rewrite utilities
138   //===--------------------------------------------------------------------===//
139 
140   // The entry point for handling a result pattern rooted at `resultTree`. This
141   // method dispatches to concrete handlers according to `resultTree`'s kind and
142   // returns a symbol representing the whole value pack. Callers are expected to
143   // further resolve the symbol according to the specific use case.
144   //
145   // `depth` is the nesting level of `resultTree`; 0 means top-level result
146   // pattern. For top-level result pattern, `resultIndex` indicates which result
147   // of the matched root op this pattern is intended to replace, which can be
148   // used to deduce the result type of the op generated from this result
149   // pattern.
150   std::string handleResultPattern(DagNode resultTree, int resultIndex,
151                                   int depth);
152 
153   // Emits the C++ statement to replace the matched DAG with a value built via
154   // calling native C++ code.
155   std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
156 
157   // Returns the symbol of the old value serving as the replacement.
158   StringRef handleReplaceWithValue(DagNode tree);
159 
160   // Trailing directives are used at the end of DAG node argument lists to
161   // specify additional behaviour for op matchers and creators, etc.
162   struct TrailingDirectives {
163     // DAG node containing the `location` directive. Null if there is none.
164     DagNode location;
165 
166     // DAG node containing the `returnType` directive. Null if there is none.
167     DagNode returnType;
168 
169     // Number of found trailing directives.
170     int numDirectives;
171   };
172 
173   // Collect any trailing directives.
174   TrailingDirectives getTrailingDirectives(DagNode tree);
175 
176   // Returns the location value to use.
177   std::string getLocation(TrailingDirectives &tail);
178 
179   // Returns the location value to use.
180   std::string handleLocationDirective(DagNode tree);
181 
182   // Emit return type argument.
183   std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
184 
185   // Emits the C++ statement to build a new op out of the given DAG `tree` and
186   // returns the variable name that this op is assigned to. If the root op in
187   // DAG `tree` has a specified name, the created op will be assigned to a
188   // variable of the given name. Otherwise, a unique name will be used as the
189   // result value name.
190   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
191 
192   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
193 
194   // Emits a local variable for each value and attribute to be used for creating
195   // an op.
196   void createSeparateLocalVarsForOpArgs(DagNode node,
197                                         ChildNodeIndexNameMap &childNodeNames);
198 
199   // Emits the concrete arguments used to call an op's builder.
200   void supplyValuesForOpArgs(DagNode node,
201                              const ChildNodeIndexNameMap &childNodeNames,
202                              int depth);
203 
204   // Emits the local variables for holding all values as a whole and all named
205   // attributes as a whole to be used for creating an op.
206   void createAggregateLocalVarsForOpArgs(
207       DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
208 
209   // Returns the C++ expression to construct a constant attribute of the given
210   // `value` for the given attribute kind `attr`.
211   std::string handleConstantAttr(Attribute attr, const Twine &value);
212 
213   // Returns the C++ expression to build an argument from the given DAG `leaf`.
214   // `patArgName` is used to bound the argument to the source pattern.
215   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
216 
217   //===--------------------------------------------------------------------===//
218   // General utilities
219   //===--------------------------------------------------------------------===//
220 
221   // Collects all of the operations within the given dag tree.
222   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
223 
224   // Returns a unique symbol for a local variable of the given `op`.
225   std::string getUniqueSymbol(const Operator *op);
226 
227   //===--------------------------------------------------------------------===//
228   // Symbol utilities
229   //===--------------------------------------------------------------------===//
230 
231   // Returns how many static values the given DAG `node` correspond to.
232   int getNodeValueCount(DagNode node);
233 
234 private:
235   // Pattern instantiation location followed by the location of multiclass
236   // prototypes used. This is intended to be used as a whole to
237   // PrintFatalError() on errors.
238   ArrayRef<SMLoc> loc;
239 
240   // Op's TableGen Record to wrapper object.
241   RecordOperatorMap *opMap;
242 
243   // Handy wrapper for pattern being emitted.
244   Pattern pattern;
245 
246   // Map for all bound symbols' info.
247   SymbolInfoMap symbolInfoMap;
248 
249   StaticMatcherHelper &staticMatcherHelper;
250 
251   // The next unused ID for newly created values.
252   unsigned nextValueId = 0;
253 
254   raw_indented_ostream os;
255 
256   // Format contexts containing placeholder substitutions.
257   FmtContext fmtCtx;
258 };
259 
260 // Tracks DagNode's reference multiple times across patterns. Enables generating
261 // static matcher functions for DagNode's referenced multiple times rather than
262 // inlining them.
263 class StaticMatcherHelper {
264 public:
265   StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
266                       RecordOperatorMap &mapper);
267 
268   // Determine if we should inline the match logic or delegate to a static
269   // function.
270   bool useStaticMatcher(DagNode node) {
271     // either/variadic node must be associated to the parentOp, thus we can't
272     // emit a static matcher rooted at them.
273     if (node.isEither() || node.isVariadic())
274       return false;
275 
276     return refStats[node] > kStaticMatcherThreshold;
277   }
278 
279   // Get the name of the static DAG matcher function corresponding to the node.
280   std::string getMatcherName(DagNode node) {
281     assert(useStaticMatcher(node));
282     return matcherNames[node];
283   }
284 
285   // Get the name of static type/attribute verification function.
286   StringRef getVerifierName(DagLeaf leaf);
287 
288   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
289   // the duplicated DAGs.
290   void addPattern(Record *record);
291 
292   // Emit all static functions of DAG Matcher.
293   void populateStaticMatchers(raw_ostream &os);
294 
295   // Emit all static functions for Constraints.
296   void populateStaticConstraintFunctions(raw_ostream &os);
297 
298 private:
299   static constexpr unsigned kStaticMatcherThreshold = 1;
300 
301   // Consider two patterns as down below,
302   //   DagNode_Root_A    DagNode_Root_B
303   //       \                 \
304   //     DagNode_C         DagNode_C
305   //         \                 \
306   //       DagNode_D         DagNode_D
307   //
308   // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
309   // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
310   // multiple times so we'll have static matchers for both of them. When we're
311   // emitting the match logic for DagNode_C, we will check if DagNode_D has the
312   // static matcher generated. If so, then we'll generate a call to the
313   // function, inline otherwise. In this case, inlining is not what we want. As
314   // a result, generate the static matcher in topological order to ensure all
315   // the dependent static matchers are generated and we can avoid accidentally
316   // inlining.
317   //
318   // The topological order of all the DagNodes among all patterns.
319   SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
320 
321   RecordOperatorMap &opMap;
322 
323   // Records of the static function name of each DagNode
324   DenseMap<DagNode, std::string> matcherNames;
325 
326   // After collecting all the DagNode in each pattern, `refStats` records the
327   // number of users for each DagNode. We will generate the static matcher for a
328   // DagNode while the number of users exceeds a certain threshold.
329   DenseMap<DagNode, unsigned> refStats;
330 
331   // Number of static matcher generated. This is used to generate a unique name
332   // for each DagNode.
333   int staticMatcherCounter = 0;
334 
335   // The DagLeaf which contains type or attr constraint.
336   SetVector<DagLeaf> constraints;
337 
338   // Static type/attribute verification function emitter.
339   StaticVerifierFunctionEmitter staticVerifierEmitter;
340 };
341 
342 } // namespace
343 
344 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
345                                raw_ostream &os, StaticMatcherHelper &helper)
346     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
347       symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
348   fmtCtx.withBuilder("rewriter");
349 }
350 
351 std::string PatternEmitter::handleConstantAttr(Attribute attr,
352                                                const Twine &value) {
353   if (!attr.isConstBuildable())
354     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
355                              " does not have the 'constBuilderCall' field");
356 
357   // TODO: Verify the constants here
358   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
359 }
360 
361 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
362   os << formatv(
363       "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
364       "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
365       "*, 4> &tblgen_ops",
366       funcName);
367 
368   // We pass the reference of the variables that need to be captured. Hence we
369   // need to collect all the symbols in the tree first.
370   pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
371   symbolInfoMap.assignUniqueAlternativeNames();
372   for (const auto &info : symbolInfoMap)
373     os << formatv(", {0}", info.second.getArgDecl(info.first));
374 
375   os << ") {\n";
376   os.indent();
377   os << "(void)tblgen_ops;\n";
378 
379   // Note that a static matcher is considered at least one step from the match
380   // entry.
381   emitMatch(tree, "op0", /*depth=*/1);
382 
383   os << "return ::mlir::success();\n";
384   os.unindent();
385   os << "}\n\n";
386 }
387 
388 // Helper function to match patterns.
389 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
390   if (tree.isNativeCodeCall()) {
391     emitNativeCodeMatch(tree, name, depth);
392     return;
393   }
394 
395   if (tree.isOperation()) {
396     emitOpMatch(tree, name, depth);
397     return;
398   }
399 
400   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
401 }
402 
403 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
404   std::string funcName = staticMatcherHelper.getMatcherName(tree);
405   os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
406                 opName);
407 
408   // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
409   // one pass.
410 
411   // In general, bound symbol should have the unique name in the pattern but
412   // for the operand, binding same symbol to multiple operands imply a
413   // constraint at the same time. In this case, we will rename those operands
414   // with different names. As a result, we need to collect all the symbolInfos
415   // from the DagNode then get the updated name of the local variables from the
416   // global symbolInfoMap.
417 
418   // Collect all the bound symbols in the Dag
419   SymbolInfoMap localSymbolMap(loc);
420   pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
421 
422   for (const auto &info : localSymbolMap) {
423     auto name = info.first;
424     auto symboInfo = info.second;
425     auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
426     os << formatv(", {0}", ret->second.getVarName(name));
427   }
428 
429   os << "))) {\n";
430   os.scope().os << "return ::mlir::failure();\n";
431   os << "}\n";
432 }
433 
434 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
435                                             StringRef opName, StringRef arg,
436                                             StringRef failureStr) {
437   os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
438                 funcName, opName, arg, failureStr);
439   os.scope().os << "return ::mlir::failure();\n";
440   os << "}\n";
441 }
442 
443 // Helper function to match patterns.
444 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
445                                          int depth) {
446   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
447   LLVM_DEBUG(tree.print(llvm::dbgs()));
448   LLVM_DEBUG(llvm::dbgs() << '\n');
449 
450   // The order of generating static matcher follows the topological order so
451   // that for every dependent DagNode already have their static matcher
452   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
453   // is when we are generating the static matcher for a DagNode itself. In this
454   // case, we need to emit the function body rather than a function call.
455   if (staticMatcherHelper.useStaticMatcher(tree) &&
456       !staticMatcherHelper.getMatcherName(tree).empty()) {
457     emitStaticMatchCall(tree, opName);
458 
459     // NativeCodeCall will never be at depth 0 so that we don't need to catch
460     // the root operation as emitOpMatch();
461 
462     return;
463   }
464 
465   // TODO(suderman): iterate through arguments, determine their types, output
466   // names.
467   SmallVector<std::string, 8> capture;
468 
469   raw_indented_ostream::DelimitedScope scope(os);
470 
471   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
472     std::string argName = formatv("arg{0}_{1}", depth, i);
473     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
474       if (argTree.isEither())
475         PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
476       if (argTree.isVariadic())
477         PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
478 
479       os << "::mlir::Value " << argName << ";\n";
480     } else {
481       auto leaf = tree.getArgAsLeaf(i);
482       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
483         os << "::mlir::Attribute " << argName << ";\n";
484       } else {
485         os << "::mlir::Value " << argName << ";\n";
486       }
487     }
488 
489     capture.push_back(std::move(argName));
490   }
491 
492   auto tail = getTrailingDirectives(tree);
493   if (tail.returnType)
494     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
495   auto locToUse = getLocation(tail);
496 
497   auto fmt = tree.getNativeCodeTemplate();
498   if (fmt.count("$_self") != 1)
499     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
500                          "passing the defining Operation");
501 
502   auto nativeCodeCall = std::string(
503       tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
504             static_cast<ArrayRef<std::string>>(capture)));
505 
506   emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
507                  formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));
508 
509   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
510     auto name = tree.getArgName(i);
511     if (!name.empty() && name != "_") {
512       os << formatv("{0} = {1};\n", name, capture[i]);
513     }
514   }
515 
516   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
517     std::string argName = capture[i];
518 
519     // Handle nested DAG construct first
520     if (tree.getArgAsNestedDag(i)) {
521       PrintFatalError(
522           loc, formatv("Matching nested tree in NativeCodecall not support for "
523                        "{0} as arg {1}",
524                        argName, i));
525     }
526 
527     DagLeaf leaf = tree.getArgAsLeaf(i);
528 
529     // The parameter for native function doesn't bind any constraints.
530     if (leaf.isUnspecified())
531       continue;
532 
533     auto constraint = leaf.getAsConstraint();
534 
535     std::string self;
536     if (leaf.isAttrMatcher() || leaf.isConstantAttr())
537       self = argName;
538     else
539       self = formatv("{0}.getType()", argName);
540     StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
541     emitStaticVerifierCall(
542         verifier, opName, self,
543         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
544                 "constraint: "
545                 "'{2}'\"",
546                 i, tree.getNativeCodeTemplate(),
547                 escapeString(constraint.getSummary()))
548             .str());
549   }
550 
551   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
552 }
553 
554 // Helper function to match patterns.
555 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
556   Operator &op = tree.getDialectOp(opMap);
557   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
558                           << op.getOperationName() << "' at depth " << depth
559                           << '\n');
560 
561   auto getCastedName = [depth]() -> std::string {
562     return formatv("castedOp{0}", depth);
563   };
564 
565   // The order of generating static matcher follows the topological order so
566   // that for every dependent DagNode already have their static matcher
567   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
568   // is when we are generating the static matcher for a DagNode itself. In this
569   // case, we need to emit the function body rather than a function call.
570   if (staticMatcherHelper.useStaticMatcher(tree) &&
571       !staticMatcherHelper.getMatcherName(tree).empty()) {
572     emitStaticMatchCall(tree, opName);
573     // In the codegen of rewriter, we suppose that castedOp0 will capture the
574     // root operation. Manually add it if the root DagNode is a static matcher.
575     if (depth == 0)
576       os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
577                     "(void){2};\n",
578                     opName, op.getQualCppClassName(), getCastedName());
579     return;
580   }
581 
582   std::string castedName = getCastedName();
583   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
584                 "(void){0};\n",
585                 castedName, opName, op.getQualCppClassName());
586 
587   // Skip the operand matching at depth 0 as the pattern rewriter already does.
588   if (depth != 0)
589     emitMatchCheck(opName, /*matchStr=*/castedName,
590                    formatv("\"{0} is not {1} type\"", castedName,
591                            op.getQualCppClassName()));
592 
593   // If the operand's name is set, set to that variable.
594   auto name = tree.getSymbol();
595   if (!name.empty())
596     os << formatv("{0} = {1};\n", name, castedName);
597 
598   for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
599        ++i, ++opArgIdx) {
600     auto opArg = op.getArg(opArgIdx);
601     std::string argName = formatv("op{0}", depth + 1);
602 
603     // Handle nested DAG construct first
604     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
605       if (argTree.isEither()) {
606         emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
607                                depth);
608         ++opArgIdx;
609         continue;
610       }
611       if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
612         if (argTree.isVariadic()) {
613           if (!operand->isVariadic()) {
614             auto error = formatv("variadic DAG construct can't match op {0}'s "
615                                  "non-variadic operand #{1}",
616                                  op.getOperationName(), opArgIdx);
617             PrintFatalError(loc, error);
618           }
619           emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
620                                    nextOperand, depth);
621           ++nextOperand;
622           continue;
623         }
624         if (operand->isVariableLength()) {
625           auto error = formatv("use nested DAG construct to match op {0}'s "
626                                "variadic operand #{1} unsupported now",
627                                op.getOperationName(), opArgIdx);
628           PrintFatalError(loc, error);
629         }
630       }
631 
632       os << "{\n";
633 
634       // Attributes don't count for getODSOperands.
635       // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
636       os.indent() << formatv(
637           "auto *{0} = "
638           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
639           argName, castedName, nextOperand);
640       // Null check of operand's definingOp
641       emitMatchCheck(
642           castedName, /*matchStr=*/argName,
643           formatv("\"There's no operation that defines operand {0} of {1}\"",
644                   nextOperand++, castedName));
645       emitMatch(argTree, argName, depth + 1);
646       os << formatv("tblgen_ops.push_back({0});\n", argName);
647       os.unindent() << "}\n";
648       continue;
649     }
650 
651     // Next handle DAG leaf: operand or attribute
652     if (opArg.is<NamedTypeConstraint *>()) {
653       auto operandName =
654           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
655       emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
656                        /*operandMatcher=*/tree.getArgAsLeaf(i),
657                        /*argName=*/tree.getArgName(i), opArgIdx,
658                        /*variadicSubIndex=*/std::nullopt);
659       ++nextOperand;
660     } else if (opArg.is<NamedAttribute *>()) {
661       emitAttributeMatch(tree, opName, opArgIdx, depth);
662     } else {
663       PrintFatalError(loc, "unhandled case when matching op");
664     }
665   }
666   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
667                           << op.getOperationName() << "' at depth " << depth
668                           << '\n');
669 }
670 
671 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
672                                       StringRef operandName, int operandIndex,
673                                       DagLeaf operandMatcher, StringRef argName,
674                                       int argIndex,
675                                       std::optional<int> variadicSubIndex) {
676   Operator &op = tree.getDialectOp(opMap);
677   auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
678 
679   // If a constraint is specified, we need to generate C++ statements to
680   // check the constraint.
681   if (!operandMatcher.isUnspecified()) {
682     if (!operandMatcher.isOperandMatcher())
683       PrintFatalError(
684           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
685                        op.getOperationName(), argIndex + 1));
686 
687     // Only need to verify if the matcher's type is different from the one
688     // of op definition.
689     Constraint constraint = operandMatcher.getAsConstraint();
690     if (operand->constraint != constraint) {
691       if (operand->isVariableLength()) {
692         auto error = formatv(
693             "further constrain op {0}'s variadic operand #{1} unsupported now",
694             op.getOperationName(), argIndex);
695         PrintFatalError(loc, error);
696       }
697       auto self = formatv("(*{0}.begin()).getType()", operandName);
698       StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
699       emitStaticVerifierCall(
700           verifier, opName, self.str(),
701           formatv(
702               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
703               operand - op.operand_begin(), op.getOperationName(),
704               escapeString(constraint.getSummary()))
705               .str());
706     }
707   }
708 
709   // Capture the value
710   // `$_` is a special symbol to ignore op argument matching.
711   if (!argName.empty() && argName != "_") {
712     auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
713                                              variadicSubIndex);
714     if (res == symbolInfoMap.end())
715       PrintFatalError(loc, formatv("symbol not found: {0}", argName));
716 
717     os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
718   }
719 }
720 
721 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
722                                             StringRef opName, int argIndex,
723                                             int &operandIndex, int depth) {
724   constexpr int numEitherArgs = 2;
725   if (eitherArgTree.getNumArgs() != numEitherArgs)
726     PrintFatalError(loc, "`either` only supports grouping two operands");
727 
728   Operator &op = tree.getDialectOp(opMap);
729 
730   std::string codeBuffer;
731   llvm::raw_string_ostream tblgenOps(codeBuffer);
732 
733   std::string lambda = formatv("eitherLambda{0}", depth);
734   os << formatv(
735       "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
736       lambda);
737 
738   os.indent();
739 
740   for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
741     if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
742       if (argTree.isEither())
743         PrintFatalError(loc, "either cannot be nested");
744 
745       std::string argName = formatv("local_op_{0}", i).str();
746 
747       os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
748                     i);
749 
750       // Indent emitMatchCheck and emitMatch because they declare local
751       // variables.
752       os << "{\n";
753       os.indent();
754 
755       emitMatchCheck(
756           opName, /*matchStr=*/argName,
757           formatv("\"There's no operation that defines operand {0} of {1}\"",
758                   operandIndex++, opName));
759       emitMatch(argTree, argName, depth + 1);
760 
761       os.unindent() << "}\n";
762 
763       // `tblgen_ops` is used to collect the matched operations. In either, we
764       // need to queue the operation only if the matching success. Thus we emit
765       // the code at the end.
766       tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
767     } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
768       emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
769                        operandIndex,
770                        /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
771                        /*argName=*/eitherArgTree.getArgName(i), argIndex,
772                        /*variadicSubIndex=*/std::nullopt);
773       ++operandIndex;
774     } else {
775       PrintFatalError(loc, "either can only be applied on operand");
776     }
777   }
778 
779   os << tblgenOps.str();
780   os << "return ::mlir::success();\n";
781   os.unindent() << "};\n";
782 
783   os << "{\n";
784   os.indent();
785 
786   os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
787                 operandIndex - 2);
788   os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
789                 operandIndex - 1);
790 
791   os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
792                 "::mlir::failed({0}(eitherOperand1, "
793                 "eitherOperand0)))\n",
794                 lambda);
795   os.indent() << "return ::mlir::failure();\n";
796 
797   os.unindent().unindent() << "}\n";
798 }
799 
800 void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
801                                               DagNode variadicArgTree,
802                                               StringRef opName, int argIndex,
803                                               int &operandIndex, int depth) {
804   Operator &op = tree.getDialectOp(opMap);
805 
806   os << "{\n";
807   os.indent();
808 
809   os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
810                 opName, operandIndex);
811   os << formatv("if (variadic_operand_range.size() != {0}) "
812                 "return ::mlir::failure();\n",
813                 variadicArgTree.getNumArgs());
814 
815   StringRef variadicTreeName = variadicArgTree.getSymbol();
816   if (!variadicTreeName.empty()) {
817     auto res =
818         symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
819                                       /*variadicSubIndex=*/std::nullopt);
820     if (res == symbolInfoMap.end())
821       PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
822 
823     os << formatv("{0} = variadic_operand_range;\n",
824                   res->second.getVarName(variadicTreeName));
825   }
826 
827   for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
828     if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
829       if (!argTree.isOperation())
830         PrintFatalError(loc, "variadic only accepts operation sub-dags");
831 
832       os << "{\n";
833       os.indent();
834 
835       std::string argName = formatv("local_op_{0}", i).str();
836       os << formatv("auto *{0} = "
837                     "variadic_operand_range[{1}].getDefiningOp();\n",
838                     argName, i);
839       emitMatchCheck(
840           opName, /*matchStr=*/argName,
841           formatv("\"There's no operation that defines variadic operand "
842                   "{0} (variadic sub-opearnd #{1}) of {2}\"",
843                   operandIndex, i, opName));
844       emitMatch(argTree, argName, depth + 1);
845       os << formatv("tblgen_ops.push_back({0});\n", argName);
846 
847       os.unindent() << "}\n";
848     } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
849       auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
850       emitOperandMatch(tree, opName, operandName.str(), operandIndex,
851                        /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
852                        /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
853     } else {
854       PrintFatalError(loc, "variadic can only be applied on operand");
855     }
856   }
857 
858   os.unindent() << "}\n";
859 }
860 
861 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
862                                         int argIndex, int depth) {
863   Operator &op = tree.getDialectOp(opMap);
864   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
865   const auto &attr = namedAttr->attr;
866 
867   os << "{\n";
868   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
869                          "(void)tblgen_attr;\n",
870                          opName, attr.getStorageType(), namedAttr->name);
871 
872   // TODO: This should use getter method to avoid duplication.
873   if (attr.hasDefaultValue()) {
874     os << "if (!tblgen_attr) tblgen_attr = "
875        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
876                             attr.getDefaultValue()))
877        << ";\n";
878   } else if (attr.isOptional()) {
879     // For a missing attribute that is optional according to definition, we
880     // should just capture a mlir::Attribute() to signal the missing state.
881     // That is precisely what getDiscardableAttr() returns on missing
882     // attributes.
883   } else {
884     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
885                    formatv("\"expected op '{0}' to have attribute '{1}' "
886                            "of type '{2}'\"",
887                            op.getOperationName(), namedAttr->name,
888                            attr.getStorageType()));
889   }
890 
891   auto matcher = tree.getArgAsLeaf(argIndex);
892   if (!matcher.isUnspecified()) {
893     if (!matcher.isAttrMatcher()) {
894       PrintFatalError(
895           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
896                        op.getOperationName(), argIndex + 1));
897     }
898 
899     // If a constraint is specified, we need to generate function call to its
900     // static verifier.
901     StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
902     if (attr.isOptional()) {
903       // Avoid dereferencing null attribute. This is using a simple heuristic to
904       // avoid common cases of attempting to dereference null attribute. This
905       // will return where there is no check if attribute is null unless the
906       // attribute's value is not used.
907       // FIXME: This could be improved as some null dereferences could slip
908       // through.
909       if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
910           StringRef(matcher.getConditionTemplate()).contains("$_self")) {
911         os << "if (!tblgen_attr) return ::mlir::failure();\n";
912       }
913     }
914     emitStaticVerifierCall(
915         verifier, opName, "tblgen_attr",
916         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
917                 "'{2}'\"",
918                 op.getOperationName(), namedAttr->name,
919                 escapeString(matcher.getAsConstraint().getSummary()))
920             .str());
921   }
922 
923   // Capture the value
924   auto name = tree.getArgName(argIndex);
925   // `$_` is a special symbol to ignore op argument matching.
926   if (!name.empty() && name != "_") {
927     os << formatv("{0} = tblgen_attr;\n", name);
928   }
929 
930   os.unindent() << "}\n";
931 }
932 
933 void PatternEmitter::emitMatchCheck(
934     StringRef opName, const FmtObjectBase &matchFmt,
935     const llvm::formatv_object_base &failureFmt) {
936   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
937 }
938 
939 void PatternEmitter::emitMatchCheck(StringRef opName,
940                                     const std::string &matchStr,
941                                     const std::string &failureStr) {
942 
943   os << "if (!(" << matchStr << "))";
944   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
945                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
946                               << failureStr << ";\n});";
947 }
948 
949 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
950   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
951   int depth = 0;
952   emitMatch(tree, opName, depth);
953 
954   for (auto &appliedConstraint : pattern.getConstraints()) {
955     auto &constraint = appliedConstraint.constraint;
956     auto &entities = appliedConstraint.entities;
957 
958     auto condition = constraint.getConditionTemplate();
959     if (isa<TypeConstraint>(constraint)) {
960       if (entities.size() != 1)
961         PrintFatalError(loc, "type constraint requires exactly one argument");
962 
963       auto self = formatv("({0}.getType())",
964                           symbolInfoMap.getValueAndRangeUse(entities.front()));
965       emitMatchCheck(
966           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
967           formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
968                   entities.front(), escapeString(constraint.getSummary())));
969 
970     } else if (isa<AttrConstraint>(constraint)) {
971       PrintFatalError(
972           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
973     } else {
974       // TODO: replace formatv arguments with the exact specified
975       // args.
976       if (entities.size() > 4) {
977         PrintFatalError(loc, "only support up to 4-entity constraints now");
978       }
979       SmallVector<std::string, 4> names;
980       int i = 0;
981       for (int e = entities.size(); i < e; ++i)
982         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
983       std::string self = appliedConstraint.self;
984       if (!self.empty())
985         self = symbolInfoMap.getValueAndRangeUse(self);
986       for (; i < 4; ++i)
987         names.push_back("<unused>");
988       emitMatchCheck(opName,
989                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
990                            names[1], names[2], names[3]),
991                      formatv("\"entities '{0}' failed to satisfy constraint: "
992                              "'{1}'\"",
993                              llvm::join(entities, ", "),
994                              escapeString(constraint.getSummary())));
995     }
996   }
997 
998   // Some of the operands could be bound to the same symbol name, we need
999   // to enforce equality constraint on those.
1000   // TODO: we should be able to emit equality checks early
1001   // and short circuit unnecessary work if vars are not equal.
1002   for (auto symbolInfoIt = symbolInfoMap.begin();
1003        symbolInfoIt != symbolInfoMap.end();) {
1004     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
1005     auto startRange = range.first;
1006     auto endRange = range.second;
1007 
1008     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
1009     for (++startRange; startRange != endRange; ++startRange) {
1010       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
1011       emitMatchCheck(
1012           opName,
1013           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
1014           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
1015                   secondOperand));
1016     }
1017 
1018     symbolInfoIt = endRange;
1019   }
1020 
1021   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1022 }
1023 
1024 void PatternEmitter::collectOps(DagNode tree,
1025                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1026   // Check if this tree is an operation.
1027   if (tree.isOperation()) {
1028     const Operator &op = tree.getDialectOp(opMap);
1029     LLVM_DEBUG(llvm::dbgs()
1030                << "found operation " << op.getOperationName() << '\n');
1031     ops.insert(&op);
1032   }
1033 
1034   // Recurse the arguments of the tree.
1035   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1036     if (auto child = tree.getArgAsNestedDag(i))
1037       collectOps(child, ops);
1038 }
1039 
1040 void PatternEmitter::emit(StringRef rewriteName) {
1041   // Get the DAG tree for the source pattern.
1042   DagNode sourceTree = pattern.getSourcePattern();
1043 
1044   const Operator &rootOp = pattern.getSourceRootOp();
1045   auto rootName = rootOp.getOperationName();
1046 
1047   // Collect the set of result operations.
1048   llvm::SmallPtrSet<const Operator *, 4> resultOps;
1049   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
1050   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
1051     collectOps(pattern.getResultPattern(i), resultOps);
1052   }
1053   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
1054 
1055   // Emit RewritePattern for Pattern.
1056   auto locs = pattern.getLocation();
1057   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
1058                 make_range(locs.rbegin(), locs.rend()));
1059   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
1060   {0}(::mlir::MLIRContext *context)
1061       : ::mlir::RewritePattern("{1}", {2}, context, {{)",
1062                 rewriteName, rootName, pattern.getBenefit());
1063   // Sort result operators by name.
1064   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
1065                                                          resultOps.end());
1066   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
1067     return lhs->getOperationName() < rhs->getOperationName();
1068   });
1069   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
1070     os << '"' << op->getOperationName() << '"';
1071   });
1072   os << "}) {}\n";
1073 
1074   // Emit matchAndRewrite() function.
1075   {
1076     auto classScope = os.scope();
1077     os.printReindented(R"(
1078     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
1079         ::mlir::PatternRewriter &rewriter) const override {)")
1080         << '\n';
1081     {
1082       auto functionScope = os.scope();
1083 
1084       // Register all symbols bound in the source pattern.
1085       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
1086 
1087       LLVM_DEBUG(llvm::dbgs()
1088                  << "start creating local variables for capturing matches\n");
1089       os << "// Variables for capturing values and attributes used while "
1090             "creating ops\n";
1091       // Create local variables for storing the arguments and results bound
1092       // to symbols.
1093       for (const auto &symbolInfoPair : symbolInfoMap) {
1094         const auto &symbol = symbolInfoPair.first;
1095         const auto &info = symbolInfoPair.second;
1096 
1097         os << info.getVarDecl(symbol);
1098       }
1099       // TODO: capture ops with consistent numbering so that it can be
1100       // reused for fused loc.
1101       os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
1102       LLVM_DEBUG(llvm::dbgs()
1103                  << "done creating local variables for capturing matches\n");
1104 
1105       os << "// Match\n";
1106       os << "tblgen_ops.push_back(op0);\n";
1107       emitMatchLogic(sourceTree, "op0");
1108 
1109       os << "\n// Rewrite\n";
1110       emitRewriteLogic();
1111 
1112       os << "return ::mlir::success();\n";
1113     }
1114     os << "};\n";
1115   }
1116   os << "};\n\n";
1117 }
1118 
1119 void PatternEmitter::emitRewriteLogic() {
1120   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1121   const Operator &rootOp = pattern.getSourceRootOp();
1122   int numExpectedResults = rootOp.getNumResults();
1123   int numResultPatterns = pattern.getNumResultPatterns();
1124 
1125   // First register all symbols bound to ops generated in result patterns.
1126   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
1127 
1128   // Only the last N static values generated are used to replace the matched
1129   // root N-result op. We need to calculate the starting index (of the results
1130   // of the matched op) each result pattern is to replace.
1131   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1132   // If we don't need to replace any value at all, set the replacement starting
1133   // index as the number of result patterns so we skip all of them when trying
1134   // to replace the matched op's results.
1135   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1136   for (int i = numResultPatterns - 1; i >= 0; --i) {
1137     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
1138     offsets[i] = offsets[i + 1] - numValues;
1139     if (offsets[i] == 0) {
1140       if (replStartIndex == -1)
1141         replStartIndex = i;
1142     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1143       auto error = formatv(
1144           "cannot use the same multi-result op '{0}' to generate both "
1145           "auxiliary values and values to be used for replacing the matched op",
1146           pattern.getResultPattern(i).getSymbol());
1147       PrintFatalError(loc, error);
1148     }
1149   }
1150 
1151   if (offsets.front() > 0) {
1152     const char error[] =
1153         "not enough values generated to replace the matched op";
1154     PrintFatalError(loc, error);
1155   }
1156 
1157   os << "auto odsLoc = rewriter.getFusedLoc({";
1158   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1159     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1160   }
1161   os << "}); (void)odsLoc;\n";
1162 
1163   // Process auxiliary result patterns.
1164   for (int i = 0; i < replStartIndex; ++i) {
1165     DagNode resultTree = pattern.getResultPattern(i);
1166     auto val = handleResultPattern(resultTree, offsets[i], 0);
1167     // Normal op creation will be streamed to `os` by the above call; but
1168     // NativeCodeCall will only be materialized to `os` if it is used. Here
1169     // we are handling auxiliary patterns so we want the side effect even if
1170     // NativeCodeCall is not replacing matched root op's results.
1171     if (resultTree.isNativeCodeCall() &&
1172         resultTree.getNumReturnsOfNativeCode() == 0)
1173       os << val << ";\n";
1174   }
1175 
1176   if (numExpectedResults == 0) {
1177     assert(replStartIndex >= numResultPatterns &&
1178            "invalid auxiliary vs. replacement pattern division!");
1179     // No result to replace. Just erase the op.
1180     os << "rewriter.eraseOp(op0);\n";
1181   } else {
1182     // Process replacement result patterns.
1183     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1184     for (int i = replStartIndex; i < numResultPatterns; ++i) {
1185       DagNode resultTree = pattern.getResultPattern(i);
1186       auto val = handleResultPattern(resultTree, offsets[i], 0);
1187       os << "\n";
1188       // Resolve each symbol for all range use so that we can loop over them.
1189       // We need an explicit cast to `SmallVector` to capture the cases where
1190       // `{0}` resolves to an `Operation::result_range` as well as cases that
1191       // are not iterable (e.g. vector that gets wrapped in additional braces by
1192       // RewriterGen).
1193       // TODO: Revisit the need for materializing a vector.
1194       os << symbolInfoMap.getAllRangeUse(
1195           val,
1196           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1197           "  tblgen_repl_values.push_back(v);\n}\n",
1198           "\n");
1199     }
1200     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1201   }
1202 
1203   // Process supplemtal patterns.
1204   int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1205   for (int i = 0, offset = -numSupplementalPatterns;
1206        i < numSupplementalPatterns; ++i) {
1207     DagNode resultTree = pattern.getSupplementalPattern(i);
1208     auto val = handleResultPattern(resultTree, offset++, 0);
1209     if (resultTree.isNativeCodeCall() &&
1210         resultTree.getNumReturnsOfNativeCode() == 0)
1211       os << val << ";\n";
1212   }
1213 
1214   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1215 }
1216 
1217 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1218   return std::string(
1219       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1220 }
1221 
1222 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1223                                                 int resultIndex, int depth) {
1224   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1225   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1226   LLVM_DEBUG(llvm::dbgs() << '\n');
1227 
1228   if (resultTree.isLocationDirective()) {
1229     PrintFatalError(loc,
1230                     "location directive can only be used with op creation");
1231   }
1232 
1233   if (resultTree.isNativeCodeCall())
1234     return handleReplaceWithNativeCodeCall(resultTree, depth);
1235 
1236   if (resultTree.isReplaceWithValue())
1237     return handleReplaceWithValue(resultTree).str();
1238 
1239   // Normal op creation.
1240   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1241   if (resultTree.getSymbol().empty()) {
1242     // This is an op not explicitly bound to a symbol in the rewrite rule.
1243     // Register the auto-generated symbol for it.
1244     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1245   }
1246   return symbol;
1247 }
1248 
1249 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1250   assert(tree.isReplaceWithValue());
1251 
1252   if (tree.getNumArgs() != 1) {
1253     PrintFatalError(
1254         loc, "replaceWithValue directive must take exactly one argument");
1255   }
1256 
1257   if (!tree.getSymbol().empty()) {
1258     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1259   }
1260 
1261   return tree.getArgName(0);
1262 }
1263 
1264 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1265   assert(tree.isLocationDirective());
1266   auto lookUpArgLoc = [this, &tree](int idx) {
1267     const auto *const lookupFmt = "{0}.getLoc()";
1268     return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1269   };
1270 
1271   if (tree.getNumArgs() == 0)
1272     llvm::PrintFatalError(
1273         "At least one argument to location directive required");
1274 
1275   if (!tree.getSymbol().empty())
1276     PrintFatalError(loc, "cannot bind symbol to location");
1277 
1278   if (tree.getNumArgs() == 1) {
1279     DagLeaf leaf = tree.getArgAsLeaf(0);
1280     if (leaf.isStringAttr())
1281       return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1282                      leaf.getStringAttr())
1283           .str();
1284     return lookUpArgLoc(0);
1285   }
1286 
1287   std::string ret;
1288   llvm::raw_string_ostream os(ret);
1289   std::string strAttr;
1290   os << "rewriter.getFusedLoc({";
1291   bool first = true;
1292   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1293     DagLeaf leaf = tree.getArgAsLeaf(i);
1294     // Handle the optional string value.
1295     if (leaf.isStringAttr()) {
1296       if (!strAttr.empty())
1297         llvm::PrintFatalError("Only one string attribute may be specified");
1298       strAttr = leaf.getStringAttr();
1299       continue;
1300     }
1301     os << (first ? "" : ", ") << lookUpArgLoc(i);
1302     first = false;
1303   }
1304   os << "}";
1305   if (!strAttr.empty()) {
1306     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1307   }
1308   os << ")";
1309   return os.str();
1310 }
1311 
1312 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1313                                                 int depth) {
1314   // Nested NativeCodeCall.
1315   if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1316     if (!dagNode.isNativeCodeCall())
1317       PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1318                            "call");
1319     return handleReplaceWithNativeCodeCall(dagNode, depth);
1320   }
1321   // String literal.
1322   auto dagLeaf = returnType.getArgAsLeaf(i);
1323   if (dagLeaf.isStringAttr())
1324     return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1325   return tgfmt(
1326       "$0.getType()", &fmtCtx,
1327       handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1328 }
1329 
1330 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1331                                              StringRef patArgName) {
1332   if (leaf.isStringAttr())
1333     PrintFatalError(loc, "raw string not supported as argument");
1334   if (leaf.isConstantAttr()) {
1335     auto constAttr = leaf.getAsConstantAttr();
1336     return handleConstantAttr(constAttr.getAttribute(),
1337                               constAttr.getConstantValue());
1338   }
1339   if (leaf.isEnumAttrCase()) {
1340     auto enumCase = leaf.getAsEnumAttrCase();
1341     // This is an enum case backed by an IntegerAttr. We need to get its value
1342     // to build the constant.
1343     std::string val = std::to_string(enumCase.getValue());
1344     return handleConstantAttr(enumCase, val);
1345   }
1346 
1347   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1348   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1349   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1350     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1351                             << "' (via symbol ref)\n");
1352     return argName;
1353   }
1354   if (leaf.isNativeCodeCall()) {
1355     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1356     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1357                             << "' (via NativeCodeCall)\n");
1358     return std::string(repl);
1359   }
1360   PrintFatalError(loc, "unhandled case when rewriting op");
1361 }
1362 
1363 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1364                                                             int depth) {
1365   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1366   LLVM_DEBUG(tree.print(llvm::dbgs()));
1367   LLVM_DEBUG(llvm::dbgs() << '\n');
1368 
1369   auto fmt = tree.getNativeCodeTemplate();
1370 
1371   SmallVector<std::string, 16> attrs;
1372 
1373   auto tail = getTrailingDirectives(tree);
1374   if (tail.returnType)
1375     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1376   auto locToUse = getLocation(tail);
1377 
1378   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1379     if (tree.isNestedDagArg(i)) {
1380       attrs.push_back(
1381           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1382     } else {
1383       attrs.push_back(
1384           handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1385     }
1386     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1387                             << " replacement: " << attrs[i] << "\n");
1388   }
1389 
1390   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1391                              static_cast<ArrayRef<std::string>>(attrs));
1392 
1393   // In general, NativeCodeCall without naming binding don't need this. To
1394   // ensure void helper function has been correctly labeled, i.e., use
1395   // NativeCodeCallVoid, we cache the result to a local variable so that we will
1396   // get a compilation error in the auto-generated file.
1397   // Example.
1398   //   // In the td file
1399   //   Pat<(...), (NativeCodeCall<Foo> ...)>
1400   //
1401   //   ---
1402   //
1403   //   // In the auto-generated .cpp
1404   //   ...
1405   //   // Causes compilation error if Foo() returns void.
1406   //   auto nativeVar = Foo();
1407   //   ...
1408   if (tree.getNumReturnsOfNativeCode() != 0) {
1409     // Determine the local variable name for return value.
1410     std::string varName =
1411         SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1412     if (varName.empty()) {
1413       varName = formatv("nativeVar_{0}", nextValueId++);
1414       // Register the local variable for later uses.
1415       symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1416     }
1417 
1418     // Catch the return value of helper function.
1419     os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1420 
1421     if (!tree.getSymbol().empty())
1422       symbol = tree.getSymbol().str();
1423     else
1424       symbol = varName;
1425   }
1426 
1427   return symbol;
1428 }
1429 
1430 int PatternEmitter::getNodeValueCount(DagNode node) {
1431   if (node.isOperation()) {
1432     // If the op is bound to a symbol in the rewrite rule, query its result
1433     // count from the symbol info map.
1434     auto symbol = node.getSymbol();
1435     if (!symbol.empty()) {
1436       return symbolInfoMap.getStaticValueCount(symbol);
1437     }
1438     // Otherwise this is an unbound op; we will use all its results.
1439     return pattern.getDialectOp(node).getNumResults();
1440   }
1441 
1442   if (node.isNativeCodeCall())
1443     return node.getNumReturnsOfNativeCode();
1444 
1445   return 1;
1446 }
1447 
1448 PatternEmitter::TrailingDirectives
1449 PatternEmitter::getTrailingDirectives(DagNode tree) {
1450   TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1451 
1452   // Look backwards through the arguments.
1453   auto numPatArgs = tree.getNumArgs();
1454   for (int i = numPatArgs - 1; i >= 0; --i) {
1455     auto dagArg = tree.getArgAsNestedDag(i);
1456     // A leaf is not a directive. Stop looking.
1457     if (!dagArg)
1458       break;
1459 
1460     auto isLocation = dagArg.isLocationDirective();
1461     auto isReturnType = dagArg.isReturnTypeDirective();
1462     // If encountered a DAG node that isn't a trailing directive, stop looking.
1463     if (!(isLocation || isReturnType))
1464       break;
1465     // Save the directive, but error if one of the same type was already
1466     // found.
1467     ++tail.numDirectives;
1468     if (isLocation) {
1469       if (tail.location)
1470         PrintFatalError(loc, "`location` directive can only be specified "
1471                              "once");
1472       tail.location = dagArg;
1473     } else if (isReturnType) {
1474       if (tail.returnType)
1475         PrintFatalError(loc, "`returnType` directive can only be specified "
1476                              "once");
1477       tail.returnType = dagArg;
1478     }
1479   }
1480 
1481   return tail;
1482 }
1483 
1484 std::string
1485 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1486   if (tail.location)
1487     return handleLocationDirective(tail.location);
1488 
1489   // If no explicit location is given, use the default, all fused, location.
1490   return "odsLoc";
1491 }
1492 
1493 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1494                                              int depth) {
1495   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1496   LLVM_DEBUG(tree.print(llvm::dbgs()));
1497   LLVM_DEBUG(llvm::dbgs() << '\n');
1498 
1499   Operator &resultOp = tree.getDialectOp(opMap);
1500   auto numOpArgs = resultOp.getNumArgs();
1501   auto numPatArgs = tree.getNumArgs();
1502 
1503   auto tail = getTrailingDirectives(tree);
1504   auto locToUse = getLocation(tail);
1505 
1506   auto inPattern = numPatArgs - tail.numDirectives;
1507   if (numOpArgs != inPattern) {
1508     PrintFatalError(loc,
1509                     formatv("resultant op '{0}' argument number mismatch: "
1510                             "{1} in pattern vs. {2} in definition",
1511                             resultOp.getOperationName(), inPattern, numOpArgs));
1512   }
1513 
1514   // A map to collect all nested DAG child nodes' names, with operand index as
1515   // the key. This includes both bound and unbound child nodes.
1516   ChildNodeIndexNameMap childNodeNames;
1517 
1518   // First go through all the child nodes who are nested DAG constructs to
1519   // create ops for them and remember the symbol names for them, so that we can
1520   // use the results in the current node. This happens in a recursive manner.
1521   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1522     if (auto child = tree.getArgAsNestedDag(i))
1523       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1524   }
1525 
1526   // The name of the local variable holding this op.
1527   std::string valuePackName;
1528   // The symbol for holding the result of this pattern. Note that the result of
1529   // this pattern is not necessarily the same as the variable created by this
1530   // pattern because we can use `__N` suffix to refer only a specific result if
1531   // the generated op is a multi-result op.
1532   std::string resultValue;
1533   if (tree.getSymbol().empty()) {
1534     // No symbol is explicitly bound to this op in the pattern. Generate a
1535     // unique name.
1536     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1537   } else {
1538     resultValue = std::string(tree.getSymbol());
1539     // Strip the index to get the name for the value pack and use it to name the
1540     // local variable for the op.
1541     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1542   }
1543 
1544   // Create the local variable for this op.
1545   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1546                 valuePackName);
1547 
1548   // Right now ODS don't have general type inference support. Except a few
1549   // special cases listed below, DRR needs to supply types for all results
1550   // when building an op.
1551   bool isSameOperandsAndResultType =
1552       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1553   bool useFirstAttr =
1554       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1555 
1556   if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1557     // We know how to deduce the result type for ops with these traits and we've
1558     // generated builders taking aggregate parameters. Use those builders to
1559     // create the ops.
1560 
1561     // First prepare local variables for op arguments used in builder call.
1562     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1563 
1564     // Then create the op.
1565     os.scope("", "\n}\n").os << formatv(
1566         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1567         valuePackName, resultOp.getQualCppClassName(), locToUse);
1568     return resultValue;
1569   }
1570 
1571   bool usePartialResults = valuePackName != resultValue;
1572 
1573   if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1574     // For these cases (broadcastable ops, op results used both as auxiliary
1575     // values and replacement values, ops in nested patterns, auxiliary ops), we
1576     // still need to supply the result types when building the op. But because
1577     // we don't generate a builder automatically with ODS for them, it's the
1578     // developer's responsibility to make sure such a builder (with result type
1579     // deduction ability) exists. We go through the separate-parameter builder
1580     // here given that it's easier for developers to write compared to
1581     // aggregate-parameter builders.
1582     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1583 
1584     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1585                              resultOp.getQualCppClassName(), locToUse);
1586     supplyValuesForOpArgs(tree, childNodeNames, depth);
1587     os << "\n  );\n}\n";
1588     return resultValue;
1589   }
1590 
1591   // If we are provided explicit return types, use them to build the op.
1592   // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1593   // the values generated from the source pattern root op. Then we must use the
1594   // source pattern's value types to determine the value type of the generated
1595   // op here.
1596   if (depth == 0 && resultIndex >= 0 && tail.returnType)
1597     PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1598                          "return values replace the source pattern's root op");
1599 
1600   // First prepare local variables for op arguments used in builder call.
1601   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1602 
1603   // Then prepare the result types. We need to specify the types for all
1604   // results.
1605   os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1606                          "(void)tblgen_types;\n");
1607   int numResults = resultOp.getNumResults();
1608   if (tail.returnType) {
1609     auto numRetTys = tail.returnType.getNumArgs();
1610     for (int i = 0; i < numRetTys; ++i) {
1611       auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1612       os << "tblgen_types.push_back(" << varName << ");\n";
1613     }
1614   } else {
1615     if (numResults != 0) {
1616       // Copy the result types from the source pattern.
1617       for (int i = 0; i < numResults; ++i)
1618         os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1619                       "  tblgen_types.push_back(v.getType());\n}\n",
1620                       resultIndex + i);
1621     }
1622   }
1623   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1624                 "tblgen_values, tblgen_attrs);\n",
1625                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1626   os.unindent() << "}\n";
1627   return resultValue;
1628 }
1629 
1630 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1631     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1632   Operator &resultOp = node.getDialectOp(opMap);
1633 
1634   // Now prepare operands used for building this op:
1635   // * If the operand is non-variadic, we create a `Value` local variable.
1636   // * If the operand is variadic, we create a `SmallVector<Value>` local
1637   //   variable.
1638 
1639   int valueIndex = 0; // An index for uniquing local variable names.
1640   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1641     const auto *operand =
1642         llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
1643     // We do not need special handling for attributes.
1644     if (!operand)
1645       continue;
1646 
1647     raw_indented_ostream::DelimitedScope scope(os);
1648     std::string varName;
1649     if (operand->isVariadic()) {
1650       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1651       os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1652       std::string range;
1653       if (node.isNestedDagArg(argIndex)) {
1654         range = childNodeNames[argIndex];
1655       } else {
1656         range = std::string(node.getArgName(argIndex));
1657       }
1658       // Resolve the symbol for all range use so that we have a uniform way of
1659       // capturing the values.
1660       range = symbolInfoMap.getValueAndRangeUse(range);
1661       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1662                     varName);
1663     } else {
1664       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1665       os << formatv("::mlir::Value {0} = ", varName);
1666       if (node.isNestedDagArg(argIndex)) {
1667         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1668       } else {
1669         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1670         auto symbol =
1671             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1672         if (leaf.isNativeCodeCall()) {
1673           os << std::string(
1674               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1675         } else {
1676           os << symbol;
1677         }
1678       }
1679       os << ";\n";
1680     }
1681 
1682     // Update to use the newly created local variable for building the op later.
1683     childNodeNames[argIndex] = varName;
1684   }
1685 }
1686 
1687 void PatternEmitter::supplyValuesForOpArgs(
1688     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1689   Operator &resultOp = node.getDialectOp(opMap);
1690   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1691        argIndex != numOpArgs; ++argIndex) {
1692     // Start each argument on its own line.
1693     os << ",\n    ";
1694 
1695     Argument opArg = resultOp.getArg(argIndex);
1696     // Handle the case of operand first.
1697     if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
1698       if (!operand->name.empty())
1699         os << "/*" << operand->name << "=*/";
1700       os << childNodeNames.lookup(argIndex);
1701       continue;
1702     }
1703 
1704     // The argument in the op definition.
1705     auto opArgName = resultOp.getArgName(argIndex);
1706     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1707       if (!subTree.isNativeCodeCall())
1708         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1709                              "for creating attribute");
1710       os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1711     } else {
1712       auto leaf = node.getArgAsLeaf(argIndex);
1713       // The argument in the result DAG pattern.
1714       auto patArgName = node.getArgName(argIndex);
1715       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1716         // TODO: Refactor out into map to avoid recomputing these.
1717         if (!opArg.is<NamedAttribute *>())
1718           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1719         if (!patArgName.empty())
1720           os << "/*" << patArgName << "=*/";
1721       } else {
1722         os << "/*" << opArgName << "=*/";
1723       }
1724       os << handleOpArgument(leaf, patArgName);
1725     }
1726   }
1727 }
1728 
1729 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1730     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1731   Operator &resultOp = node.getDialectOp(opMap);
1732 
1733   auto scope = os.scope();
1734   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1735                 "tblgen_values; (void)tblgen_values;\n");
1736   os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1737                 "tblgen_attrs; (void)tblgen_attrs;\n");
1738 
1739   const char *addAttrCmd =
1740       "if (auto tmpAttr = {1}) {\n"
1741       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1742       "tmpAttr);\n}\n";
1743   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1744     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1745       // The argument in the op definition.
1746       auto opArgName = resultOp.getArgName(argIndex);
1747       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1748         if (!subTree.isNativeCodeCall())
1749           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1750                                "for creating attribute");
1751         os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1752       } else {
1753         auto leaf = node.getArgAsLeaf(argIndex);
1754         // The argument in the result DAG pattern.
1755         auto patArgName = node.getArgName(argIndex);
1756         os << formatv(addAttrCmd, opArgName,
1757                       handleOpArgument(leaf, patArgName));
1758       }
1759       continue;
1760     }
1761 
1762     const auto *operand =
1763         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1764     std::string varName;
1765     if (operand->isVariadic()) {
1766       std::string range;
1767       if (node.isNestedDagArg(argIndex)) {
1768         range = childNodeNames.lookup(argIndex);
1769       } else {
1770         range = std::string(node.getArgName(argIndex));
1771       }
1772       // Resolve the symbol for all range use so that we have a uniform way of
1773       // capturing the values.
1774       range = symbolInfoMap.getValueAndRangeUse(range);
1775       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1776                     range);
1777     } else {
1778       os << formatv("tblgen_values.push_back(");
1779       if (node.isNestedDagArg(argIndex)) {
1780         os << symbolInfoMap.getValueAndRangeUse(
1781             childNodeNames.lookup(argIndex));
1782       } else {
1783         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1784         if (leaf.isConstantAttr())
1785           // TODO: Use better location
1786           PrintFatalError(
1787               loc,
1788               "attribute found where value was expected, if attempting to use "
1789               "constant value, construct a constant op with given attribute "
1790               "instead");
1791 
1792         auto symbol =
1793             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1794         if (leaf.isNativeCodeCall()) {
1795           os << std::string(
1796               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1797         } else {
1798           os << symbol;
1799         }
1800       }
1801       os << ");\n";
1802     }
1803   }
1804 }
1805 
1806 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1807                                          const RecordKeeper &recordKeeper,
1808                                          RecordOperatorMap &mapper)
1809     : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1810 
1811 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1812   // PatternEmitter will use the static matcher if there's one generated. To
1813   // ensure that all the dependent static matchers are generated before emitting
1814   // the matching logic of the DagNode, we use topological order to achieve it.
1815   for (auto &dagInfo : topologicalOrder) {
1816     DagNode node = dagInfo.first;
1817     if (!useStaticMatcher(node))
1818       continue;
1819 
1820     std::string funcName =
1821         formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1822     assert(!matcherNames.contains(node));
1823     PatternEmitter(dagInfo.second, &opMap, os, *this)
1824         .emitStaticMatcher(node, funcName);
1825     matcherNames[node] = funcName;
1826   }
1827 }
1828 
1829 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1830   staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1831 }
1832 
1833 void StaticMatcherHelper::addPattern(Record *record) {
1834   Pattern pat(record, &opMap);
1835 
1836   // While generating the function body of the DAG matcher, it may depends on
1837   // other DAG matchers. To ensure the dependent matchers are ready, we compute
1838   // the topological order for all the DAGs and emit the DAG matchers in this
1839   // order.
1840   llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1841     ++refStats[node];
1842 
1843     if (refStats[node] != 1)
1844       return;
1845 
1846     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1847       if (DagNode sibling = node.getArgAsNestedDag(i))
1848         dfs(sibling);
1849       else {
1850         DagLeaf leaf = node.getArgAsLeaf(i);
1851         if (!leaf.isUnspecified())
1852           constraints.insert(leaf);
1853       }
1854 
1855     topologicalOrder.push_back(std::make_pair(node, record));
1856   };
1857 
1858   dfs(pat.getSourcePattern());
1859 }
1860 
1861 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1862   if (leaf.isAttrMatcher()) {
1863     std::optional<StringRef> constraint =
1864         staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1865     assert(constraint && "attribute constraint was not uniqued");
1866     return *constraint;
1867   }
1868   assert(leaf.isOperandMatcher());
1869   return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1870 }
1871 
1872 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1873   emitSourceFileHeader("Rewriters", os);
1874 
1875   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1876 
1877   // We put the map here because it can be shared among multiple patterns.
1878   RecordOperatorMap recordOpMap;
1879 
1880   // Exam all the patterns and generate static matcher for the duplicated
1881   // DagNode.
1882   StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1883   for (Record *p : patterns)
1884     staticMatcher.addPattern(p);
1885   staticMatcher.populateStaticConstraintFunctions(os);
1886   staticMatcher.populateStaticMatchers(os);
1887 
1888   std::vector<std::string> rewriterNames;
1889   rewriterNames.reserve(patterns.size());
1890 
1891   std::string baseRewriterName = "GeneratedConvert";
1892   int rewriterIndex = 0;
1893 
1894   for (Record *p : patterns) {
1895     std::string name;
1896     if (p->isAnonymous()) {
1897       // If no name is provided, ensure unique rewriter names simply by
1898       // appending unique suffix.
1899       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1900     } else {
1901       name = std::string(p->getName());
1902     }
1903     LLVM_DEBUG(llvm::dbgs()
1904                << "=== start generating pattern '" << name << "' ===\n");
1905     PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1906     LLVM_DEBUG(llvm::dbgs()
1907                << "=== done generating pattern '" << name << "' ===\n");
1908     rewriterNames.push_back(std::move(name));
1909   }
1910 
1911   // Emit function to add the generated matchers to the pattern list.
1912   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1913         "::mlir::RewritePatternSet &patterns) {\n";
1914   for (const auto &name : rewriterNames) {
1915     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1916   }
1917   os << "}\n";
1918 }
1919 
1920 static mlir::GenRegistration
1921     genRewriters("gen-rewriters", "Generate pattern rewriters",
1922                  [](const RecordKeeper &records, raw_ostream &os) {
1923                    emitRewriters(records, os);
1924                    return false;
1925                  });
1926