1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Pattern.h" 19 #include "mlir/TableGen/Predicate.h" 20 #include "mlir/TableGen/Type.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::formatv; 37 using llvm::Record; 38 using llvm::RecordKeeper; 39 40 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 41 42 namespace llvm { 43 template <> 44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 46 raw_ostream &os, StringRef style) { 47 os << v.first << ":" << v.second; 48 } 49 }; 50 } // end namespace llvm 51 52 //===----------------------------------------------------------------------===// 53 // PatternEmitter 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 class PatternEmitter { 58 public: 59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 60 61 // Emits the mlir::RewritePattern struct named `rewriteName`. 62 void emit(StringRef rewriteName); 63 64 private: 65 // Emits the code for matching ops. 66 void emitMatchLogic(DagNode tree); 67 68 // Emits the code for rewriting ops. 69 void emitRewriteLogic(); 70 71 //===--------------------------------------------------------------------===// 72 // Match utilities 73 //===--------------------------------------------------------------------===// 74 75 // Emits C++ statements for matching the op constrained by the given DAG 76 // `tree`. 77 void emitOpMatch(DagNode tree, int depth); 78 79 // Emits C++ statements for matching the `argIndex`-th argument of the given 80 // DAG `tree` as an operand. 81 void emitOperandMatch(DagNode tree, int argIndex, int depth); 82 83 // Emits C++ statements for matching the `argIndex`-th argument of the given 84 // DAG `tree` as an attribute. 85 void emitAttributeMatch(DagNode tree, int argIndex, int depth); 86 87 // Emits C++ for checking a match with a corresponding match failure 88 // diagnostic. 89 void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, 90 const llvm::formatv_object_base &failureFmt); 91 92 //===--------------------------------------------------------------------===// 93 // Rewrite utilities 94 //===--------------------------------------------------------------------===// 95 96 // The entry point for handling a result pattern rooted at `resultTree`. This 97 // method dispatches to concrete handlers according to `resultTree`'s kind and 98 // returns a symbol representing the whole value pack. Callers are expected to 99 // further resolve the symbol according to the specific use case. 100 // 101 // `depth` is the nesting level of `resultTree`; 0 means top-level result 102 // pattern. For top-level result pattern, `resultIndex` indicates which result 103 // of the matched root op this pattern is intended to replace, which can be 104 // used to deduce the result type of the op generated from this result 105 // pattern. 106 std::string handleResultPattern(DagNode resultTree, int resultIndex, 107 int depth); 108 109 // Emits the C++ statement to replace the matched DAG with a value built via 110 // calling native C++ code. 111 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 112 113 // Returns the symbol of the old value serving as the replacement. 114 StringRef handleReplaceWithValue(DagNode tree); 115 116 // Returns the location value to use. 117 std::pair<bool, std::string> getLocation(DagNode tree); 118 119 // Returns the location value to use. 120 std::string handleLocationDirective(DagNode tree); 121 122 // Emits the C++ statement to build a new op out of the given DAG `tree` and 123 // returns the variable name that this op is assigned to. If the root op in 124 // DAG `tree` has a specified name, the created op will be assigned to a 125 // variable of the given name. Otherwise, a unique name will be used as the 126 // result value name. 127 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 128 129 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 130 131 // Emits a local variable for each value and attribute to be used for creating 132 // an op. 133 void createSeparateLocalVarsForOpArgs(DagNode node, 134 ChildNodeIndexNameMap &childNodeNames); 135 136 // Emits the concrete arguments used to call an op's builder. 137 void supplyValuesForOpArgs(DagNode node, 138 const ChildNodeIndexNameMap &childNodeNames); 139 140 // Emits the local variables for holding all values as a whole and all named 141 // attributes as a whole to be used for creating an op. 142 void createAggregateLocalVarsForOpArgs( 143 DagNode node, const ChildNodeIndexNameMap &childNodeNames); 144 145 // Returns the C++ expression to construct a constant attribute of the given 146 // `value` for the given attribute kind `attr`. 147 std::string handleConstantAttr(Attribute attr, StringRef value); 148 149 // Returns the C++ expression to build an argument from the given DAG `leaf`. 150 // `patArgName` is used to bound the argument to the source pattern. 151 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 152 153 //===--------------------------------------------------------------------===// 154 // General utilities 155 //===--------------------------------------------------------------------===// 156 157 // Collects all of the operations within the given dag tree. 158 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 159 160 // Returns a unique symbol for a local variable of the given `op`. 161 std::string getUniqueSymbol(const Operator *op); 162 163 //===--------------------------------------------------------------------===// 164 // Symbol utilities 165 //===--------------------------------------------------------------------===// 166 167 // Returns how many static values the given DAG `node` correspond to. 168 int getNodeValueCount(DagNode node); 169 170 private: 171 // Pattern instantiation location followed by the location of multiclass 172 // prototypes used. This is intended to be used as a whole to 173 // PrintFatalError() on errors. 174 ArrayRef<llvm::SMLoc> loc; 175 176 // Op's TableGen Record to wrapper object. 177 RecordOperatorMap *opMap; 178 179 // Handy wrapper for pattern being emitted. 180 Pattern pattern; 181 182 // Map for all bound symbols' info. 183 SymbolInfoMap symbolInfoMap; 184 185 // The next unused ID for newly created values. 186 unsigned nextValueId; 187 188 raw_indented_ostream os; 189 190 // Format contexts containing placeholder substitutions. 191 FmtContext fmtCtx; 192 193 // Number of op processed. 194 int opCounter = 0; 195 }; 196 } // end anonymous namespace 197 198 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 199 raw_ostream &os) 200 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 201 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 202 fmtCtx.withBuilder("rewriter"); 203 } 204 205 std::string PatternEmitter::handleConstantAttr(Attribute attr, 206 StringRef value) { 207 if (!attr.isConstBuildable()) 208 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 209 " does not have the 'constBuilderCall' field"); 210 211 // TODO: Verify the constants here 212 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 213 } 214 215 // Helper function to match patterns. 216 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 217 Operator &op = tree.getDialectOp(opMap); 218 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 219 << op.getOperationName() << "' at depth " << depth 220 << '\n'); 221 222 int indent = 4 + 2 * depth; 223 os.indent(indent) << formatv( 224 "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " 225 "(void)castedOp{0};\n", 226 depth, op.getQualCppClassName()); 227 // Skip the operand matching at depth 0 as the pattern rewriter already does. 228 if (depth != 0) { 229 // Skip if there is no defining operation (e.g., arguments to function). 230 os << formatv("if (!castedOp{0})\n return failure();\n", depth); 231 } 232 if (tree.getNumArgs() != op.getNumArgs()) { 233 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 234 "pattern vs. {2} in definition", 235 op.getOperationName(), tree.getNumArgs(), 236 op.getNumArgs())); 237 } 238 239 // If the operand's name is set, set to that variable. 240 auto name = tree.getSymbol(); 241 if (!name.empty()) 242 os << formatv("{0} = castedOp{1};\n", name, depth); 243 244 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 245 auto opArg = op.getArg(i); 246 247 // Handle nested DAG construct first 248 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 249 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 250 if (operand->isVariableLength()) { 251 auto error = formatv("use nested DAG construct to match op {0}'s " 252 "variadic operand #{1} unsupported now", 253 op.getOperationName(), i); 254 PrintFatalError(loc, error); 255 } 256 } 257 os << "{\n"; 258 259 os.indent() << formatv( 260 "auto *op{0} = " 261 "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 262 depth + 1, depth, i); 263 emitOpMatch(argTree, depth + 1); 264 os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 265 os.unindent() << "}\n"; 266 continue; 267 } 268 269 // Next handle DAG leaf: operand or attribute 270 if (opArg.is<NamedTypeConstraint *>()) { 271 emitOperandMatch(tree, i, depth); 272 } else if (opArg.is<NamedAttribute *>()) { 273 emitAttributeMatch(tree, i, depth); 274 } else { 275 PrintFatalError(loc, "unhandled case when matching op"); 276 } 277 } 278 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 279 << op.getOperationName() << "' at depth " << depth 280 << '\n'); 281 } 282 283 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { 284 Operator &op = tree.getDialectOp(opMap); 285 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 286 auto matcher = tree.getArgAsLeaf(argIndex); 287 288 // If a constraint is specified, we need to generate C++ statements to 289 // check the constraint. 290 if (!matcher.isUnspecified()) { 291 if (!matcher.isOperandMatcher()) { 292 PrintFatalError( 293 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 294 op.getOperationName(), argIndex + 1)); 295 } 296 297 // Only need to verify if the matcher's type is different from the one 298 // of op definition. 299 Constraint constraint = matcher.getAsConstraint(); 300 if (operand->constraint != constraint) { 301 if (operand->isVariableLength()) { 302 auto error = formatv( 303 "further constrain op {0}'s variadic operand #{1} unsupported now", 304 op.getOperationName(), argIndex); 305 PrintFatalError(loc, error); 306 } 307 auto self = 308 formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, 309 argIndex); 310 emitMatchCheck( 311 depth, 312 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 313 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 314 "'{2}'\"", 315 operand - op.operand_begin(), op.getOperationName(), 316 constraint.getDescription())); 317 } 318 } 319 320 // Capture the value 321 auto name = tree.getArgName(argIndex); 322 // `$_` is a special symbol to ignore op argument matching. 323 if (!name.empty() && name != "_") { 324 // We need to subtract the number of attributes before this operand to get 325 // the index in the operand list. 326 auto numPrevAttrs = std::count_if( 327 op.arg_begin(), op.arg_begin() + argIndex, 328 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 329 330 os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth, 331 argIndex - numPrevAttrs); 332 } 333 } 334 335 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { 336 Operator &op = tree.getDialectOp(opMap); 337 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 338 const auto &attr = namedAttr->attr; 339 340 os << "{\n"; 341 os.indent() << formatv( 342 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); " 343 "(void)tblgen_attr;\n", 344 depth, attr.getStorageType(), namedAttr->name); 345 346 // TODO: This should use getter method to avoid duplication. 347 if (attr.hasDefaultValue()) { 348 os << "if (!tblgen_attr) tblgen_attr = " 349 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 350 attr.getDefaultValue())) 351 << ";\n"; 352 } else if (attr.isOptional()) { 353 // For a missing attribute that is optional according to definition, we 354 // should just capture a mlir::Attribute() to signal the missing state. 355 // That is precisely what getAttr() returns on missing attributes. 356 } else { 357 emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), 358 formatv("\"expected op '{0}' to have attribute '{1}' " 359 "of type '{2}'\"", 360 op.getOperationName(), namedAttr->name, 361 attr.getStorageType())); 362 } 363 364 auto matcher = tree.getArgAsLeaf(argIndex); 365 if (!matcher.isUnspecified()) { 366 if (!matcher.isAttrMatcher()) { 367 PrintFatalError( 368 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 369 op.getOperationName(), argIndex + 1)); 370 } 371 372 // If a constraint is specified, we need to generate C++ statements to 373 // check the constraint. 374 emitMatchCheck( 375 depth, 376 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 377 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 378 "{2}\"", 379 op.getOperationName(), namedAttr->name, 380 matcher.getAsConstraint().getDescription())); 381 } 382 383 // Capture the value 384 auto name = tree.getArgName(argIndex); 385 // `$_` is a special symbol to ignore op argument matching. 386 if (!name.empty() && name != "_") { 387 os << formatv("{0} = tblgen_attr;\n", name); 388 } 389 390 os.unindent() << "}\n"; 391 } 392 393 void PatternEmitter::emitMatchCheck( 394 int depth, const FmtObjectBase &matchFmt, 395 const llvm::formatv_object_base &failureFmt) { 396 os << "if (!(" << matchFmt.str() << "))"; 397 os.scope("{\n", "\n}\n").os 398 << "return rewriter.notifyMatchFailure(op" << depth 399 << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str() 400 << ";\n});"; 401 } 402 403 void PatternEmitter::emitMatchLogic(DagNode tree) { 404 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 405 int depth = 0; 406 emitOpMatch(tree, depth); 407 408 for (auto &appliedConstraint : pattern.getConstraints()) { 409 auto &constraint = appliedConstraint.constraint; 410 auto &entities = appliedConstraint.entities; 411 412 auto condition = constraint.getConditionTemplate(); 413 if (isa<TypeConstraint>(constraint)) { 414 auto self = formatv("({0}.getType())", 415 symbolInfoMap.getValueAndRangeUse(entities.front())); 416 emitMatchCheck( 417 depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), 418 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 419 entities.front(), constraint.getDescription())); 420 421 } else if (isa<AttrConstraint>(constraint)) { 422 PrintFatalError( 423 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 424 } else { 425 // TODO: replace formatv arguments with the exact specified 426 // args. 427 if (entities.size() > 4) { 428 PrintFatalError(loc, "only support up to 4-entity constraints now"); 429 } 430 SmallVector<std::string, 4> names; 431 int i = 0; 432 for (int e = entities.size(); i < e; ++i) 433 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 434 std::string self = appliedConstraint.self; 435 if (!self.empty()) 436 self = symbolInfoMap.getValueAndRangeUse(self); 437 for (; i < 4; ++i) 438 names.push_back("<unused>"); 439 emitMatchCheck(depth, 440 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 441 names[1], names[2], names[3]), 442 formatv("\"entities '{0}' failed to satisfy constraint: " 443 "{1}\"", 444 llvm::join(entities, ", "), 445 constraint.getDescription())); 446 } 447 } 448 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 449 } 450 451 void PatternEmitter::collectOps(DagNode tree, 452 llvm::SmallPtrSetImpl<const Operator *> &ops) { 453 // Check if this tree is an operation. 454 if (tree.isOperation()) { 455 const Operator &op = tree.getDialectOp(opMap); 456 LLVM_DEBUG(llvm::dbgs() 457 << "found operation " << op.getOperationName() << '\n'); 458 ops.insert(&op); 459 } 460 461 // Recurse the arguments of the tree. 462 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 463 if (auto child = tree.getArgAsNestedDag(i)) 464 collectOps(child, ops); 465 } 466 467 void PatternEmitter::emit(StringRef rewriteName) { 468 // Get the DAG tree for the source pattern. 469 DagNode sourceTree = pattern.getSourcePattern(); 470 471 const Operator &rootOp = pattern.getSourceRootOp(); 472 auto rootName = rootOp.getOperationName(); 473 474 // Collect the set of result operations. 475 llvm::SmallPtrSet<const Operator *, 4> resultOps; 476 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 477 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 478 collectOps(pattern.getResultPattern(i), resultOps); 479 } 480 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 481 482 // Emit RewritePattern for Pattern. 483 auto locs = pattern.getLocation(); 484 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 485 make_range(locs.rbegin(), locs.rend())); 486 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 487 {0}(::mlir::MLIRContext *context) 488 : ::mlir::RewritePattern("{1}", {{)", 489 rewriteName, rootName); 490 // Sort result operators by name. 491 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 492 resultOps.end()); 493 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 494 return lhs->getOperationName() < rhs->getOperationName(); 495 }); 496 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 497 os << '"' << op->getOperationName() << '"'; 498 }); 499 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 500 501 // Emit matchAndRewrite() function. 502 { 503 auto classScope = os.scope(); 504 os.reindent(R"( 505 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 506 ::mlir::PatternRewriter &rewriter) const override {)") 507 << '\n'; 508 { 509 auto functionScope = os.scope(); 510 511 // Register all symbols bound in the source pattern. 512 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 513 514 LLVM_DEBUG(llvm::dbgs() 515 << "start creating local variables for capturing matches\n"); 516 os << "// Variables for capturing values and attributes used while " 517 "creating ops\n"; 518 // Create local variables for storing the arguments and results bound 519 // to symbols. 520 for (const auto &symbolInfoPair : symbolInfoMap) { 521 StringRef symbol = symbolInfoPair.getKey(); 522 auto &info = symbolInfoPair.getValue(); 523 os << info.getVarDecl(symbol); 524 } 525 // TODO: capture ops with consistent numbering so that it can be 526 // reused for fused loc. 527 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 528 pattern.getSourcePattern().getNumOps()); 529 LLVM_DEBUG(llvm::dbgs() 530 << "done creating local variables for capturing matches\n"); 531 532 os << "// Match\n"; 533 os << "tblgen_ops[0] = op0;\n"; 534 emitMatchLogic(sourceTree); 535 536 os << "\n// Rewrite\n"; 537 emitRewriteLogic(); 538 539 os << "return ::mlir::success();\n"; 540 } 541 os << "};\n"; 542 } 543 os << "};\n\n"; 544 } 545 546 void PatternEmitter::emitRewriteLogic() { 547 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 548 const Operator &rootOp = pattern.getSourceRootOp(); 549 int numExpectedResults = rootOp.getNumResults(); 550 int numResultPatterns = pattern.getNumResultPatterns(); 551 552 // First register all symbols bound to ops generated in result patterns. 553 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 554 555 // Only the last N static values generated are used to replace the matched 556 // root N-result op. We need to calculate the starting index (of the results 557 // of the matched op) each result pattern is to replace. 558 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 559 // If we don't need to replace any value at all, set the replacement starting 560 // index as the number of result patterns so we skip all of them when trying 561 // to replace the matched op's results. 562 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 563 for (int i = numResultPatterns - 1; i >= 0; --i) { 564 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 565 offsets[i] = offsets[i + 1] - numValues; 566 if (offsets[i] == 0) { 567 if (replStartIndex == -1) 568 replStartIndex = i; 569 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 570 auto error = formatv( 571 "cannot use the same multi-result op '{0}' to generate both " 572 "auxiliary values and values to be used for replacing the matched op", 573 pattern.getResultPattern(i).getSymbol()); 574 PrintFatalError(loc, error); 575 } 576 } 577 578 if (offsets.front() > 0) { 579 const char error[] = "no enough values generated to replace the matched op"; 580 PrintFatalError(loc, error); 581 } 582 583 os << "auto odsLoc = rewriter.getFusedLoc({"; 584 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 585 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 586 } 587 os << "}); (void)odsLoc;\n"; 588 589 // Process auxiliary result patterns. 590 for (int i = 0; i < replStartIndex; ++i) { 591 DagNode resultTree = pattern.getResultPattern(i); 592 auto val = handleResultPattern(resultTree, offsets[i], 0); 593 // Normal op creation will be streamed to `os` by the above call; but 594 // NativeCodeCall will only be materialized to `os` if it is used. Here 595 // we are handling auxiliary patterns so we want the side effect even if 596 // NativeCodeCall is not replacing matched root op's results. 597 if (resultTree.isNativeCodeCall()) 598 os << val << ";\n"; 599 } 600 601 if (numExpectedResults == 0) { 602 assert(replStartIndex >= numResultPatterns && 603 "invalid auxiliary vs. replacement pattern division!"); 604 // No result to replace. Just erase the op. 605 os << "rewriter.eraseOp(op0);\n"; 606 } else { 607 // Process replacement result patterns. 608 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 609 for (int i = replStartIndex; i < numResultPatterns; ++i) { 610 DagNode resultTree = pattern.getResultPattern(i); 611 auto val = handleResultPattern(resultTree, offsets[i], 0); 612 os << "\n"; 613 // Resolve each symbol for all range use so that we can loop over them. 614 // We need an explicit cast to `SmallVector` to capture the cases where 615 // `{0}` resolves to an `Operation::result_range` as well as cases that 616 // are not iterable (e.g. vector that gets wrapped in additional braces by 617 // RewriterGen). 618 // TODO: Revisit the need for materializing a vector. 619 os << symbolInfoMap.getAllRangeUse( 620 val, 621 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 622 " tblgen_repl_values.push_back(v);\n}\n", 623 "\n"); 624 } 625 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 626 } 627 628 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 629 } 630 631 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 632 return std::string( 633 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 634 } 635 636 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 637 int resultIndex, int depth) { 638 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 639 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 640 LLVM_DEBUG(llvm::dbgs() << '\n'); 641 642 if (resultTree.isLocationDirective()) { 643 PrintFatalError(loc, 644 "location directive can only be used with op creation"); 645 } 646 647 if (resultTree.isNativeCodeCall()) { 648 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 649 symbolInfoMap.bindValue(symbol); 650 return symbol; 651 } 652 653 if (resultTree.isReplaceWithValue()) 654 return handleReplaceWithValue(resultTree).str(); 655 656 // Normal op creation. 657 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 658 if (resultTree.getSymbol().empty()) { 659 // This is an op not explicitly bound to a symbol in the rewrite rule. 660 // Register the auto-generated symbol for it. 661 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 662 } 663 return symbol; 664 } 665 666 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 667 assert(tree.isReplaceWithValue()); 668 669 if (tree.getNumArgs() != 1) { 670 PrintFatalError( 671 loc, "replaceWithValue directive must take exactly one argument"); 672 } 673 674 if (!tree.getSymbol().empty()) { 675 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 676 } 677 678 return tree.getArgName(0); 679 } 680 681 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 682 assert(tree.isLocationDirective()); 683 auto lookUpArgLoc = [this, &tree](int idx) { 684 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 685 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 686 }; 687 688 if (tree.getNumArgs() == 0) 689 llvm::PrintFatalError( 690 "At least one argument to location directive required"); 691 692 if (!tree.getSymbol().empty()) 693 PrintFatalError(loc, "cannot bind symbol to location"); 694 695 if (tree.getNumArgs() == 1) { 696 DagLeaf leaf = tree.getArgAsLeaf(0); 697 if (leaf.isStringAttr()) 698 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " 699 "rewriter.getContext())", 700 leaf.getStringAttr()) 701 .str(); 702 return lookUpArgLoc(0); 703 } 704 705 std::string ret; 706 llvm::raw_string_ostream os(ret); 707 std::string strAttr; 708 os << "rewriter.getFusedLoc({"; 709 bool first = true; 710 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 711 DagLeaf leaf = tree.getArgAsLeaf(i); 712 // Handle the optional string value. 713 if (leaf.isStringAttr()) { 714 if (!strAttr.empty()) 715 llvm::PrintFatalError("Only one string attribute may be specified"); 716 strAttr = leaf.getStringAttr(); 717 continue; 718 } 719 os << (first ? "" : ", ") << lookUpArgLoc(i); 720 first = false; 721 } 722 os << "}"; 723 if (!strAttr.empty()) { 724 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 725 } 726 os << ")"; 727 return os.str(); 728 } 729 730 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 731 StringRef patArgName) { 732 if (leaf.isStringAttr()) 733 PrintFatalError(loc, "raw string not supported as argument"); 734 if (leaf.isConstantAttr()) { 735 auto constAttr = leaf.getAsConstantAttr(); 736 return handleConstantAttr(constAttr.getAttribute(), 737 constAttr.getConstantValue()); 738 } 739 if (leaf.isEnumAttrCase()) { 740 auto enumCase = leaf.getAsEnumAttrCase(); 741 if (enumCase.isStrCase()) 742 return handleConstantAttr(enumCase, enumCase.getSymbol()); 743 // This is an enum case backed by an IntegerAttr. We need to get its value 744 // to build the constant. 745 std::string val = std::to_string(enumCase.getValue()); 746 return handleConstantAttr(enumCase, val); 747 } 748 749 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 750 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 751 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 752 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 753 << "' (via symbol ref)\n"); 754 return argName; 755 } 756 if (leaf.isNativeCodeCall()) { 757 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 758 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 759 << "' (via NativeCodeCall)\n"); 760 return std::string(repl); 761 } 762 PrintFatalError(loc, "unhandled case when rewriting op"); 763 } 764 765 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 766 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 767 LLVM_DEBUG(tree.print(llvm::dbgs())); 768 LLVM_DEBUG(llvm::dbgs() << '\n'); 769 770 auto fmt = tree.getNativeCodeTemplate(); 771 // TODO: replace formatv arguments with the exact specified args. 772 SmallVector<std::string, 8> attrs(8); 773 if (tree.getNumArgs() > 8) { 774 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 775 Twine(tree.getNumArgs())); 776 } 777 bool hasLocationDirective; 778 std::string locToUse; 779 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 780 781 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 782 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 783 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 784 << " replacement: " << attrs[i] << "\n"); 785 } 786 return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], 787 attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], 788 attrs[6], attrs[7])); 789 } 790 791 int PatternEmitter::getNodeValueCount(DagNode node) { 792 if (node.isOperation()) { 793 // If the op is bound to a symbol in the rewrite rule, query its result 794 // count from the symbol info map. 795 auto symbol = node.getSymbol(); 796 if (!symbol.empty()) { 797 return symbolInfoMap.getStaticValueCount(symbol); 798 } 799 // Otherwise this is an unbound op; we will use all its results. 800 return pattern.getDialectOp(node).getNumResults(); 801 } 802 // TODO: This considers all NativeCodeCall as returning one 803 // value. Enhance if multi-value ones are needed. 804 return 1; 805 } 806 807 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 808 auto numPatArgs = tree.getNumArgs(); 809 810 if (numPatArgs != 0) { 811 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 812 if (lastArg.isLocationDirective()) { 813 return std::make_pair(true, handleLocationDirective(lastArg)); 814 } 815 } 816 817 // If no explicit location is given, use the default, all fused, location. 818 return std::make_pair(false, "odsLoc"); 819 } 820 821 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 822 int depth) { 823 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 824 LLVM_DEBUG(tree.print(llvm::dbgs())); 825 LLVM_DEBUG(llvm::dbgs() << '\n'); 826 827 Operator &resultOp = tree.getDialectOp(opMap); 828 auto numOpArgs = resultOp.getNumArgs(); 829 auto numPatArgs = tree.getNumArgs(); 830 831 bool hasLocationDirective; 832 std::string locToUse; 833 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 834 835 auto inPattern = numPatArgs - hasLocationDirective; 836 if (numOpArgs != inPattern) { 837 PrintFatalError(loc, 838 formatv("resultant op '{0}' argument number mismatch: " 839 "{1} in pattern vs. {2} in definition", 840 resultOp.getOperationName(), inPattern, numOpArgs)); 841 } 842 843 // A map to collect all nested DAG child nodes' names, with operand index as 844 // the key. This includes both bound and unbound child nodes. 845 ChildNodeIndexNameMap childNodeNames; 846 847 // First go through all the child nodes who are nested DAG constructs to 848 // create ops for them and remember the symbol names for them, so that we can 849 // use the results in the current node. This happens in a recursive manner. 850 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 851 if (auto child = tree.getArgAsNestedDag(i)) 852 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 853 } 854 855 // The name of the local variable holding this op. 856 std::string valuePackName; 857 // The symbol for holding the result of this pattern. Note that the result of 858 // this pattern is not necessarily the same as the variable created by this 859 // pattern because we can use `__N` suffix to refer only a specific result if 860 // the generated op is a multi-result op. 861 std::string resultValue; 862 if (tree.getSymbol().empty()) { 863 // No symbol is explicitly bound to this op in the pattern. Generate a 864 // unique name. 865 valuePackName = resultValue = getUniqueSymbol(&resultOp); 866 } else { 867 resultValue = std::string(tree.getSymbol()); 868 // Strip the index to get the name for the value pack and use it to name the 869 // local variable for the op. 870 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 871 } 872 873 // Create the local variable for this op. 874 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 875 valuePackName); 876 877 // Right now ODS don't have general type inference support. Except a few 878 // special cases listed below, DRR needs to supply types for all results 879 // when building an op. 880 bool isSameOperandsAndResultType = 881 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 882 bool useFirstAttr = 883 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 884 885 if (isSameOperandsAndResultType || useFirstAttr) { 886 // We know how to deduce the result type for ops with these traits and we've 887 // generated builders taking aggregate parameters. Use those builders to 888 // create the ops. 889 890 // First prepare local variables for op arguments used in builder call. 891 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 892 893 // Then create the op. 894 os.scope("", "\n}\n").os << formatv( 895 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 896 valuePackName, resultOp.getQualCppClassName(), locToUse); 897 return resultValue; 898 } 899 900 bool usePartialResults = valuePackName != resultValue; 901 902 if (usePartialResults || depth > 0 || resultIndex < 0) { 903 // For these cases (broadcastable ops, op results used both as auxiliary 904 // values and replacement values, ops in nested patterns, auxiliary ops), we 905 // still need to supply the result types when building the op. But because 906 // we don't generate a builder automatically with ODS for them, it's the 907 // developer's responsibility to make sure such a builder (with result type 908 // deduction ability) exists. We go through the separate-parameter builder 909 // here given that it's easier for developers to write compared to 910 // aggregate-parameter builders. 911 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 912 913 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 914 resultOp.getQualCppClassName(), locToUse); 915 supplyValuesForOpArgs(tree, childNodeNames); 916 os << "\n );\n}\n"; 917 return resultValue; 918 } 919 920 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 921 // generated from the source pattern root op. Then we can use the source 922 // pattern's value types to determine the value type of the generated op 923 // here. 924 925 // First prepare local variables for op arguments used in builder call. 926 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 927 928 // Then prepare the result types. We need to specify the types for all 929 // results. 930 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 931 "(void)tblgen_types;\n"); 932 int numResults = resultOp.getNumResults(); 933 if (numResults != 0) { 934 for (int i = 0; i < numResults; ++i) 935 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 936 " tblgen_types.push_back(v.getType());\n}\n", 937 resultIndex + i); 938 } 939 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 940 "tblgen_values, tblgen_attrs);\n", 941 valuePackName, resultOp.getQualCppClassName(), locToUse); 942 os.unindent() << "}\n"; 943 return resultValue; 944 } 945 946 void PatternEmitter::createSeparateLocalVarsForOpArgs( 947 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 948 Operator &resultOp = node.getDialectOp(opMap); 949 950 // Now prepare operands used for building this op: 951 // * If the operand is non-variadic, we create a `Value` local variable. 952 // * If the operand is variadic, we create a `SmallVector<Value>` local 953 // variable. 954 955 int valueIndex = 0; // An index for uniquing local variable names. 956 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 957 const auto *operand = 958 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 959 // We do not need special handling for attributes. 960 if (!operand) 961 continue; 962 963 raw_indented_ostream::DelimitedScope scope(os); 964 std::string varName; 965 if (operand->isVariadic()) { 966 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 967 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); 968 std::string range; 969 if (node.isNestedDagArg(argIndex)) { 970 range = childNodeNames[argIndex]; 971 } else { 972 range = std::string(node.getArgName(argIndex)); 973 } 974 // Resolve the symbol for all range use so that we have a uniform way of 975 // capturing the values. 976 range = symbolInfoMap.getValueAndRangeUse(range); 977 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 978 varName); 979 } else { 980 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 981 os << formatv("::mlir::Value {0} = ", varName); 982 if (node.isNestedDagArg(argIndex)) { 983 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 984 } else { 985 DagLeaf leaf = node.getArgAsLeaf(argIndex); 986 auto symbol = 987 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 988 if (leaf.isNativeCodeCall()) { 989 os << std::string( 990 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 991 } else { 992 os << symbol; 993 } 994 } 995 os << ";\n"; 996 } 997 998 // Update to use the newly created local variable for building the op later. 999 childNodeNames[argIndex] = varName; 1000 } 1001 } 1002 1003 void PatternEmitter::supplyValuesForOpArgs( 1004 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 1005 Operator &resultOp = node.getDialectOp(opMap); 1006 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1007 argIndex != numOpArgs; ++argIndex) { 1008 // Start each argument on its own line. 1009 os << ",\n "; 1010 1011 Argument opArg = resultOp.getArg(argIndex); 1012 // Handle the case of operand first. 1013 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1014 if (!operand->name.empty()) 1015 os << "/*" << operand->name << "=*/"; 1016 os << childNodeNames.lookup(argIndex); 1017 continue; 1018 } 1019 1020 // The argument in the op definition. 1021 auto opArgName = resultOp.getArgName(argIndex); 1022 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1023 if (!subTree.isNativeCodeCall()) 1024 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1025 "for creating attribute"); 1026 os << formatv("/*{0}=*/{1}", opArgName, 1027 handleReplaceWithNativeCodeCall(subTree)); 1028 } else { 1029 auto leaf = node.getArgAsLeaf(argIndex); 1030 // The argument in the result DAG pattern. 1031 auto patArgName = node.getArgName(argIndex); 1032 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1033 // TODO: Refactor out into map to avoid recomputing these. 1034 if (!opArg.is<NamedAttribute *>()) 1035 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1036 if (!patArgName.empty()) 1037 os << "/*" << patArgName << "=*/"; 1038 } else { 1039 os << "/*" << opArgName << "=*/"; 1040 } 1041 os << handleOpArgument(leaf, patArgName); 1042 } 1043 } 1044 } 1045 1046 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1047 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 1048 Operator &resultOp = node.getDialectOp(opMap); 1049 1050 auto scope = os.scope(); 1051 os << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1052 "tblgen_values; (void)tblgen_values;\n"); 1053 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1054 "tblgen_attrs; (void)tblgen_attrs;\n"); 1055 1056 const char *addAttrCmd = 1057 "if (auto tmpAttr = {1}) {\n" 1058 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " 1059 "tmpAttr);\n}\n"; 1060 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1061 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1062 // The argument in the op definition. 1063 auto opArgName = resultOp.getArgName(argIndex); 1064 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1065 if (!subTree.isNativeCodeCall()) 1066 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1067 "for creating attribute"); 1068 os << formatv(addAttrCmd, opArgName, 1069 handleReplaceWithNativeCodeCall(subTree)); 1070 } else { 1071 auto leaf = node.getArgAsLeaf(argIndex); 1072 // The argument in the result DAG pattern. 1073 auto patArgName = node.getArgName(argIndex); 1074 os << formatv(addAttrCmd, opArgName, 1075 handleOpArgument(leaf, patArgName)); 1076 } 1077 continue; 1078 } 1079 1080 const auto *operand = 1081 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1082 std::string varName; 1083 if (operand->isVariadic()) { 1084 std::string range; 1085 if (node.isNestedDagArg(argIndex)) { 1086 range = childNodeNames.lookup(argIndex); 1087 } else { 1088 range = std::string(node.getArgName(argIndex)); 1089 } 1090 // Resolve the symbol for all range use so that we have a uniform way of 1091 // capturing the values. 1092 range = symbolInfoMap.getValueAndRangeUse(range); 1093 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1094 range); 1095 } else { 1096 os << formatv("tblgen_values.push_back(", varName); 1097 if (node.isNestedDagArg(argIndex)) { 1098 os << symbolInfoMap.getValueAndRangeUse( 1099 childNodeNames.lookup(argIndex)); 1100 } else { 1101 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1102 auto symbol = 1103 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1104 if (leaf.isNativeCodeCall()) { 1105 os << std::string( 1106 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1107 } else { 1108 os << symbol; 1109 } 1110 } 1111 os << ");\n"; 1112 } 1113 } 1114 } 1115 1116 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1117 emitSourceFileHeader("Rewriters", os); 1118 1119 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1120 auto numPatterns = patterns.size(); 1121 1122 // We put the map here because it can be shared among multiple patterns. 1123 RecordOperatorMap recordOpMap; 1124 1125 std::vector<std::string> rewriterNames; 1126 rewriterNames.reserve(numPatterns); 1127 1128 std::string baseRewriterName = "GeneratedConvert"; 1129 int rewriterIndex = 0; 1130 1131 for (Record *p : patterns) { 1132 std::string name; 1133 if (p->isAnonymous()) { 1134 // If no name is provided, ensure unique rewriter names simply by 1135 // appending unique suffix. 1136 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1137 } else { 1138 name = std::string(p->getName()); 1139 } 1140 LLVM_DEBUG(llvm::dbgs() 1141 << "=== start generating pattern '" << name << "' ===\n"); 1142 PatternEmitter(p, &recordOpMap, os).emit(name); 1143 LLVM_DEBUG(llvm::dbgs() 1144 << "=== done generating pattern '" << name << "' ===\n"); 1145 rewriterNames.push_back(std::move(name)); 1146 } 1147 1148 // Emit function to add the generated matchers to the pattern list. 1149 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext " 1150 "*context, ::mlir::OwningRewritePatternList *patterns) {\n"; 1151 for (const auto &name : rewriterNames) { 1152 os << " patterns->insert<" << name << ">(context);\n"; 1153 } 1154 os << "}\n"; 1155 } 1156 1157 static mlir::GenRegistration 1158 genRewriters("gen-rewriters", "Generate pattern rewriters", 1159 [](const RecordKeeper &records, raw_ostream &os) { 1160 emitRewriters(records, os); 1161 return false; 1162 }); 1163