1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/GenInfo.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/TableGen/Pattern.h" 18 #include "mlir/TableGen/Predicate.h" 19 #include "mlir/TableGen/Type.h" 20 #include "llvm/ADT/StringExtras.h" 21 #include "llvm/ADT/StringSet.h" 22 #include "llvm/Support/CommandLine.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/FormatAdapters.h" 25 #include "llvm/Support/PrettyStackTrace.h" 26 #include "llvm/Support/Signals.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Main.h" 29 #include "llvm/TableGen/Record.h" 30 #include "llvm/TableGen/TableGenBackend.h" 31 32 using namespace mlir; 33 using namespace mlir::tblgen; 34 35 using llvm::formatv; 36 using llvm::Record; 37 using llvm::RecordKeeper; 38 39 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 40 41 namespace llvm { 42 template <> 43 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 44 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 45 raw_ostream &os, StringRef style) { 46 os << v.first << ":" << v.second; 47 } 48 }; 49 } // end namespace llvm 50 51 //===----------------------------------------------------------------------===// 52 // PatternEmitter 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 class PatternEmitter { 57 public: 58 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 59 60 // Emits the mlir::RewritePattern struct named `rewriteName`. 61 void emit(StringRef rewriteName); 62 63 private: 64 // Emits the code for matching ops. 65 void emitMatchLogic(DagNode tree); 66 67 // Emits the code for rewriting ops. 68 void emitRewriteLogic(); 69 70 //===--------------------------------------------------------------------===// 71 // Match utilities 72 //===--------------------------------------------------------------------===// 73 74 // Emits C++ statements for matching the op constrained by the given DAG 75 // `tree`. 76 void emitOpMatch(DagNode tree, int depth); 77 78 // Emits C++ statements for matching the `argIndex`-th argument of the given 79 // DAG `tree` as an operand. 80 void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent); 81 82 // Emits C++ statements for matching the `argIndex`-th argument of the given 83 // DAG `tree` as an attribute. 84 void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent); 85 86 // Emits C++ for checking a match with a corresponding match failure 87 // diagnostic. 88 void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, 89 const llvm::formatv_object_base &failureFmt); 90 91 //===--------------------------------------------------------------------===// 92 // Rewrite utilities 93 //===--------------------------------------------------------------------===// 94 95 // The entry point for handling a result pattern rooted at `resultTree`. This 96 // method dispatches to concrete handlers according to `resultTree`'s kind and 97 // returns a symbol representing the whole value pack. Callers are expected to 98 // further resolve the symbol according to the specific use case. 99 // 100 // `depth` is the nesting level of `resultTree`; 0 means top-level result 101 // pattern. For top-level result pattern, `resultIndex` indicates which result 102 // of the matched root op this pattern is intended to replace, which can be 103 // used to deduce the result type of the op generated from this result 104 // pattern. 105 std::string handleResultPattern(DagNode resultTree, int resultIndex, 106 int depth); 107 108 // Emits the C++ statement to replace the matched DAG with a value built via 109 // calling native C++ code. 110 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 111 112 // Returns the symbol of the old value serving as the replacement. 113 StringRef handleReplaceWithValue(DagNode tree); 114 115 // Returns the location value to use. 116 std::pair<bool, std::string> getLocation(DagNode tree); 117 118 // Returns the location value to use. 119 std::string handleLocationDirective(DagNode tree); 120 121 // Emits the C++ statement to build a new op out of the given DAG `tree` and 122 // returns the variable name that this op is assigned to. If the root op in 123 // DAG `tree` has a specified name, the created op will be assigned to a 124 // variable of the given name. Otherwise, a unique name will be used as the 125 // result value name. 126 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 127 128 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 129 130 // Emits a local variable for each value and attribute to be used for creating 131 // an op. 132 void createSeparateLocalVarsForOpArgs(DagNode node, 133 ChildNodeIndexNameMap &childNodeNames); 134 135 // Emits the concrete arguments used to call an op's builder. 136 void supplyValuesForOpArgs(DagNode node, 137 const ChildNodeIndexNameMap &childNodeNames); 138 139 // Emits the local variables for holding all values as a whole and all named 140 // attributes as a whole to be used for creating an op. 141 void createAggregateLocalVarsForOpArgs( 142 DagNode node, const ChildNodeIndexNameMap &childNodeNames); 143 144 // Returns the C++ expression to construct a constant attribute of the given 145 // `value` for the given attribute kind `attr`. 146 std::string handleConstantAttr(Attribute attr, StringRef value); 147 148 // Returns the C++ expression to build an argument from the given DAG `leaf`. 149 // `patArgName` is used to bound the argument to the source pattern. 150 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 151 152 //===--------------------------------------------------------------------===// 153 // General utilities 154 //===--------------------------------------------------------------------===// 155 156 // Collects all of the operations within the given dag tree. 157 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 158 159 // Returns a unique symbol for a local variable of the given `op`. 160 std::string getUniqueSymbol(const Operator *op); 161 162 //===--------------------------------------------------------------------===// 163 // Symbol utilities 164 //===--------------------------------------------------------------------===// 165 166 // Returns how many static values the given DAG `node` correspond to. 167 int getNodeValueCount(DagNode node); 168 169 private: 170 // Pattern instantiation location followed by the location of multiclass 171 // prototypes used. This is intended to be used as a whole to 172 // PrintFatalError() on errors. 173 ArrayRef<llvm::SMLoc> loc; 174 175 // Op's TableGen Record to wrapper object. 176 RecordOperatorMap *opMap; 177 178 // Handy wrapper for pattern being emitted. 179 Pattern pattern; 180 181 // Map for all bound symbols' info. 182 SymbolInfoMap symbolInfoMap; 183 184 // The next unused ID for newly created values. 185 unsigned nextValueId; 186 187 raw_ostream &os; 188 189 // Format contexts containing placeholder substitutions. 190 FmtContext fmtCtx; 191 192 // Number of op processed. 193 int opCounter = 0; 194 }; 195 } // end anonymous namespace 196 197 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 198 raw_ostream &os) 199 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 200 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 201 fmtCtx.withBuilder("rewriter"); 202 } 203 204 std::string PatternEmitter::handleConstantAttr(Attribute attr, 205 StringRef value) { 206 if (!attr.isConstBuildable()) 207 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 208 " does not have the 'constBuilderCall' field"); 209 210 // TODO: Verify the constants here 211 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 212 } 213 214 // Helper function to match patterns. 215 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 216 Operator &op = tree.getDialectOp(opMap); 217 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 218 << op.getOperationName() << "' at depth " << depth 219 << '\n'); 220 221 int indent = 4 + 2 * depth; 222 os.indent(indent) << formatv( 223 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 224 depth, op.getQualCppClassName()); 225 // Skip the operand matching at depth 0 as the pattern rewriter already does. 226 if (depth != 0) { 227 // Skip if there is no defining operation (e.g., arguments to function). 228 os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n", 229 depth); 230 } 231 if (tree.getNumArgs() != op.getNumArgs()) { 232 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 233 "pattern vs. {2} in definition", 234 op.getOperationName(), tree.getNumArgs(), 235 op.getNumArgs())); 236 } 237 238 // If the operand's name is set, set to that variable. 239 auto name = tree.getSymbol(); 240 if (!name.empty()) 241 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 242 243 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 244 auto opArg = op.getArg(i); 245 246 // Handle nested DAG construct first 247 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 248 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 249 if (operand->isVariableLength()) { 250 auto error = formatv("use nested DAG construct to match op {0}'s " 251 "variadic operand #{1} unsupported now", 252 op.getOperationName(), i); 253 PrintFatalError(loc, error); 254 } 255 } 256 os.indent(indent) << "{\n"; 257 258 os.indent(indent + 2) << formatv( 259 "auto *op{0} = " 260 "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 261 depth + 1, depth, i); 262 emitOpMatch(argTree, depth + 1); 263 os.indent(indent + 2) 264 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 265 os.indent(indent) << "}\n"; 266 continue; 267 } 268 269 // Next handle DAG leaf: operand or attribute 270 if (opArg.is<NamedTypeConstraint *>()) { 271 emitOperandMatch(tree, i, depth, indent); 272 } else if (opArg.is<NamedAttribute *>()) { 273 emitAttributeMatch(tree, i, depth, indent); 274 } else { 275 PrintFatalError(loc, "unhandled case when matching op"); 276 } 277 } 278 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 279 << op.getOperationName() << "' at depth " << depth 280 << '\n'); 281 } 282 283 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, 284 int indent) { 285 Operator &op = tree.getDialectOp(opMap); 286 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 287 auto matcher = tree.getArgAsLeaf(argIndex); 288 289 // If a constraint is specified, we need to generate C++ statements to 290 // check the constraint. 291 if (!matcher.isUnspecified()) { 292 if (!matcher.isOperandMatcher()) { 293 PrintFatalError( 294 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 295 op.getOperationName(), argIndex + 1)); 296 } 297 298 // Only need to verify if the matcher's type is different from the one 299 // of op definition. 300 Constraint constraint = matcher.getAsConstraint(); 301 if (operand->constraint != constraint) { 302 if (operand->isVariableLength()) { 303 auto error = formatv( 304 "further constrain op {0}'s variadic operand #{1} unsupported now", 305 op.getOperationName(), argIndex); 306 PrintFatalError(loc, error); 307 } 308 auto self = 309 formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, 310 argIndex); 311 emitMatchCheck( 312 depth, 313 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 314 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 315 "'{2}'\"", 316 operand - op.operand_begin(), op.getOperationName(), 317 constraint.getDescription())); 318 } 319 } 320 321 // Capture the value 322 auto name = tree.getArgName(argIndex); 323 // `$_` is a special symbol to ignore op argument matching. 324 if (!name.empty() && name != "_") { 325 // We need to subtract the number of attributes before this operand to get 326 // the index in the operand list. 327 auto numPrevAttrs = std::count_if( 328 op.arg_begin(), op.arg_begin() + argIndex, 329 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 330 331 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 332 name, depth, argIndex - numPrevAttrs); 333 } 334 } 335 336 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, 337 int indent) { 338 Operator &op = tree.getDialectOp(opMap); 339 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 340 const auto &attr = namedAttr->attr; 341 342 os.indent(indent) << "{\n"; 343 indent += 2; 344 os.indent(indent) << formatv( 345 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");" 346 "(void)tblgen_attr;\n", 347 depth, attr.getStorageType(), namedAttr->name); 348 349 // TODO: This should use getter method to avoid duplication. 350 if (attr.hasDefaultValue()) { 351 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 352 << std::string(tgfmt(attr.getConstBuilderTemplate(), 353 &fmtCtx, attr.getDefaultValue())) 354 << ";\n"; 355 } else if (attr.isOptional()) { 356 // For a missing attribute that is optional according to definition, we 357 // should just capture a mlir::Attribute() to signal the missing state. 358 // That is precisely what getAttr() returns on missing attributes. 359 } else { 360 emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), 361 formatv("\"expected op '{0}' to have attribute '{1}' " 362 "of type '{2}'\"", 363 op.getOperationName(), namedAttr->name, 364 attr.getStorageType())); 365 } 366 367 auto matcher = tree.getArgAsLeaf(argIndex); 368 if (!matcher.isUnspecified()) { 369 if (!matcher.isAttrMatcher()) { 370 PrintFatalError( 371 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 372 op.getOperationName(), argIndex + 1)); 373 } 374 375 // If a constraint is specified, we need to generate C++ statements to 376 // check the constraint. 377 emitMatchCheck( 378 depth, 379 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 380 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 381 "{2}\"", 382 op.getOperationName(), namedAttr->name, 383 matcher.getAsConstraint().getDescription())); 384 } 385 386 // Capture the value 387 auto name = tree.getArgName(argIndex); 388 // `$_` is a special symbol to ignore op argument matching. 389 if (!name.empty() && name != "_") { 390 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 391 } 392 393 indent -= 2; 394 os.indent(indent) << "}\n"; 395 } 396 397 void PatternEmitter::emitMatchCheck( 398 int depth, const FmtObjectBase &matchFmt, 399 const llvm::formatv_object_base &failureFmt) { 400 // {0} The match depth (used to get the operation that failed to match). 401 // {1} The format for the match string. 402 // {2} The format for the failure string. 403 const char *matchStr = R"( 404 if (!({1})) { 405 return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) { 406 diag << {2}; 407 }); 408 })"; 409 os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str()) 410 << "\n"; 411 } 412 413 void PatternEmitter::emitMatchLogic(DagNode tree) { 414 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 415 int depth = 0; 416 emitOpMatch(tree, depth); 417 418 for (auto &appliedConstraint : pattern.getConstraints()) { 419 auto &constraint = appliedConstraint.constraint; 420 auto &entities = appliedConstraint.entities; 421 422 auto condition = constraint.getConditionTemplate(); 423 if (isa<TypeConstraint>(constraint)) { 424 auto self = formatv("({0}.getType())", 425 symbolInfoMap.getValueAndRangeUse(entities.front())); 426 emitMatchCheck( 427 depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), 428 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 429 entities.front(), constraint.getDescription())); 430 431 } else if (isa<AttrConstraint>(constraint)) { 432 PrintFatalError( 433 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 434 } else { 435 // TODO: replace formatv arguments with the exact specified 436 // args. 437 if (entities.size() > 4) { 438 PrintFatalError(loc, "only support up to 4-entity constraints now"); 439 } 440 SmallVector<std::string, 4> names; 441 int i = 0; 442 for (int e = entities.size(); i < e; ++i) 443 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 444 std::string self = appliedConstraint.self; 445 if (!self.empty()) 446 self = symbolInfoMap.getValueAndRangeUse(self); 447 for (; i < 4; ++i) 448 names.push_back("<unused>"); 449 emitMatchCheck(depth, 450 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 451 names[1], names[2], names[3]), 452 formatv("\"entities '{0}' failed to satisfy constraint: " 453 "{1}\"", 454 llvm::join(entities, ", "), 455 constraint.getDescription())); 456 } 457 } 458 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 459 } 460 461 void PatternEmitter::collectOps(DagNode tree, 462 llvm::SmallPtrSetImpl<const Operator *> &ops) { 463 // Check if this tree is an operation. 464 if (tree.isOperation()) { 465 const Operator &op = tree.getDialectOp(opMap); 466 LLVM_DEBUG(llvm::dbgs() 467 << "found operation " << op.getOperationName() << '\n'); 468 ops.insert(&op); 469 } 470 471 // Recurse the arguments of the tree. 472 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 473 if (auto child = tree.getArgAsNestedDag(i)) 474 collectOps(child, ops); 475 } 476 477 void PatternEmitter::emit(StringRef rewriteName) { 478 // Get the DAG tree for the source pattern. 479 DagNode sourceTree = pattern.getSourcePattern(); 480 481 const Operator &rootOp = pattern.getSourceRootOp(); 482 auto rootName = rootOp.getOperationName(); 483 484 // Collect the set of result operations. 485 llvm::SmallPtrSet<const Operator *, 4> resultOps; 486 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 487 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 488 collectOps(pattern.getResultPattern(i), resultOps); 489 } 490 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 491 492 // Emit RewritePattern for Pattern. 493 auto locs = pattern.getLocation(); 494 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 495 make_range(locs.rbegin(), locs.rend())); 496 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 497 {0}(::mlir::MLIRContext *context) 498 : ::mlir::RewritePattern("{1}", {{)", 499 rewriteName, rootName); 500 // Sort result operators by name. 501 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 502 resultOps.end()); 503 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 504 return lhs->getOperationName() < rhs->getOperationName(); 505 }); 506 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 507 os << '"' << op->getOperationName() << '"'; 508 }); 509 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 510 511 // Emit matchAndRewrite() function. 512 os << R"( 513 ::mlir::LogicalResult 514 matchAndRewrite(::mlir::Operation *op0, 515 ::mlir::PatternRewriter &rewriter) const override { 516 )"; 517 518 // Register all symbols bound in the source pattern. 519 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 520 521 LLVM_DEBUG( 522 llvm::dbgs() << "start creating local variables for capturing matches\n"); 523 os.indent(4) << "// Variables for capturing values and attributes used for " 524 "creating ops\n"; 525 // Create local variables for storing the arguments and results bound 526 // to symbols. 527 for (const auto &symbolInfoPair : symbolInfoMap) { 528 StringRef symbol = symbolInfoPair.getKey(); 529 auto &info = symbolInfoPair.getValue(); 530 os.indent(4) << info.getVarDecl(symbol); 531 } 532 // TODO: capture ops with consistent numbering so that it can be 533 // reused for fused loc. 534 os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 535 pattern.getSourcePattern().getNumOps()); 536 LLVM_DEBUG( 537 llvm::dbgs() << "done creating local variables for capturing matches\n"); 538 539 os.indent(4) << "// Match\n"; 540 os.indent(4) << "tblgen_ops[0] = op0;\n"; 541 emitMatchLogic(sourceTree); 542 os << "\n"; 543 544 os.indent(4) << "// Rewrite\n"; 545 emitRewriteLogic(); 546 547 os.indent(4) << "return success();\n"; 548 os << " };\n"; 549 os << "};\n"; 550 } 551 552 void PatternEmitter::emitRewriteLogic() { 553 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 554 const Operator &rootOp = pattern.getSourceRootOp(); 555 int numExpectedResults = rootOp.getNumResults(); 556 int numResultPatterns = pattern.getNumResultPatterns(); 557 558 // First register all symbols bound to ops generated in result patterns. 559 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 560 561 // Only the last N static values generated are used to replace the matched 562 // root N-result op. We need to calculate the starting index (of the results 563 // of the matched op) each result pattern is to replace. 564 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 565 // If we don't need to replace any value at all, set the replacement starting 566 // index as the number of result patterns so we skip all of them when trying 567 // to replace the matched op's results. 568 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 569 for (int i = numResultPatterns - 1; i >= 0; --i) { 570 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 571 offsets[i] = offsets[i + 1] - numValues; 572 if (offsets[i] == 0) { 573 if (replStartIndex == -1) 574 replStartIndex = i; 575 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 576 auto error = formatv( 577 "cannot use the same multi-result op '{0}' to generate both " 578 "auxiliary values and values to be used for replacing the matched op", 579 pattern.getResultPattern(i).getSymbol()); 580 PrintFatalError(loc, error); 581 } 582 } 583 584 if (offsets.front() > 0) { 585 const char error[] = "no enough values generated to replace the matched op"; 586 PrintFatalError(loc, error); 587 } 588 589 os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({"; 590 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 591 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 592 } 593 os << "}); (void)odsLoc;\n"; 594 595 // Process auxiliary result patterns. 596 for (int i = 0; i < replStartIndex; ++i) { 597 DagNode resultTree = pattern.getResultPattern(i); 598 auto val = handleResultPattern(resultTree, offsets[i], 0); 599 // Normal op creation will be streamed to `os` by the above call; but 600 // NativeCodeCall will only be materialized to `os` if it is used. Here 601 // we are handling auxiliary patterns so we want the side effect even if 602 // NativeCodeCall is not replacing matched root op's results. 603 if (resultTree.isNativeCodeCall()) 604 os.indent(4) << val << ";\n"; 605 } 606 607 if (numExpectedResults == 0) { 608 assert(replStartIndex >= numResultPatterns && 609 "invalid auxiliary vs. replacement pattern division!"); 610 // No result to replace. Just erase the op. 611 os.indent(4) << "rewriter.eraseOp(op0);\n"; 612 } else { 613 // Process replacement result patterns. 614 os.indent(4) 615 << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 616 for (int i = replStartIndex; i < numResultPatterns; ++i) { 617 DagNode resultTree = pattern.getResultPattern(i); 618 auto val = handleResultPattern(resultTree, offsets[i], 0); 619 os.indent(4) << "\n"; 620 // Resolve each symbol for all range use so that we can loop over them. 621 // We need an explicit cast to `SmallVector` to capture the cases where 622 // `{0}` resolves to an `Operation::result_range` as well as cases that 623 // are not iterable (e.g. vector that gets wrapped in additional braces by 624 // RewriterGen). 625 // TODO: Revisit the need for materializing a vector. 626 os << symbolInfoMap.getAllRangeUse( 627 val, 628 " for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ " 629 "tblgen_repl_values.push_back(v); }", 630 "\n"); 631 } 632 os.indent(4) << "\n"; 633 os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n"; 634 } 635 636 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 637 } 638 639 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 640 return std::string( 641 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 642 } 643 644 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 645 int resultIndex, int depth) { 646 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 647 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 648 LLVM_DEBUG(llvm::dbgs() << '\n'); 649 650 if (resultTree.isLocationDirective()) { 651 PrintFatalError(loc, 652 "location directive can only be used with op creation"); 653 } 654 655 if (resultTree.isNativeCodeCall()) { 656 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 657 symbolInfoMap.bindValue(symbol); 658 return symbol; 659 } 660 661 if (resultTree.isReplaceWithValue()) 662 return handleReplaceWithValue(resultTree).str(); 663 664 // Normal op creation. 665 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 666 if (resultTree.getSymbol().empty()) { 667 // This is an op not explicitly bound to a symbol in the rewrite rule. 668 // Register the auto-generated symbol for it. 669 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 670 } 671 return symbol; 672 } 673 674 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 675 assert(tree.isReplaceWithValue()); 676 677 if (tree.getNumArgs() != 1) { 678 PrintFatalError( 679 loc, "replaceWithValue directive must take exactly one argument"); 680 } 681 682 if (!tree.getSymbol().empty()) { 683 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 684 } 685 686 return tree.getArgName(0); 687 } 688 689 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 690 assert(tree.isLocationDirective()); 691 auto lookUpArgLoc = [this, &tree](int idx) { 692 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 693 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 694 }; 695 696 if (tree.getNumArgs() == 0) 697 llvm::PrintFatalError( 698 "At least one argument to location directive required"); 699 700 if (!tree.getSymbol().empty()) 701 PrintFatalError(loc, "cannot bind symbol to location"); 702 703 if (tree.getNumArgs() == 1) { 704 DagLeaf leaf = tree.getArgAsLeaf(0); 705 if (leaf.isStringAttr()) 706 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " 707 "rewriter.getContext())", 708 leaf.getStringAttr()) 709 .str(); 710 return lookUpArgLoc(0); 711 } 712 713 std::string ret; 714 llvm::raw_string_ostream os(ret); 715 std::string strAttr; 716 os << "rewriter.getFusedLoc({"; 717 bool first = true; 718 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 719 DagLeaf leaf = tree.getArgAsLeaf(i); 720 // Handle the optional string value. 721 if (leaf.isStringAttr()) { 722 if (!strAttr.empty()) 723 llvm::PrintFatalError("Only one string attribute may be specified"); 724 strAttr = leaf.getStringAttr(); 725 continue; 726 } 727 os << (first ? "" : ", ") << lookUpArgLoc(i); 728 first = false; 729 } 730 os << "}"; 731 if (!strAttr.empty()) { 732 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 733 } 734 os << ")"; 735 return os.str(); 736 } 737 738 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 739 StringRef patArgName) { 740 if (leaf.isStringAttr()) 741 PrintFatalError(loc, "raw string not supported as argument"); 742 if (leaf.isConstantAttr()) { 743 auto constAttr = leaf.getAsConstantAttr(); 744 return handleConstantAttr(constAttr.getAttribute(), 745 constAttr.getConstantValue()); 746 } 747 if (leaf.isEnumAttrCase()) { 748 auto enumCase = leaf.getAsEnumAttrCase(); 749 if (enumCase.isStrCase()) 750 return handleConstantAttr(enumCase, enumCase.getSymbol()); 751 // This is an enum case backed by an IntegerAttr. We need to get its value 752 // to build the constant. 753 std::string val = std::to_string(enumCase.getValue()); 754 return handleConstantAttr(enumCase, val); 755 } 756 757 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 758 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 759 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 760 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 761 << "' (via symbol ref)\n"); 762 return argName; 763 } 764 if (leaf.isNativeCodeCall()) { 765 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 766 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 767 << "' (via NativeCodeCall)\n"); 768 return std::string(repl); 769 } 770 PrintFatalError(loc, "unhandled case when rewriting op"); 771 } 772 773 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 774 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 775 LLVM_DEBUG(tree.print(llvm::dbgs())); 776 LLVM_DEBUG(llvm::dbgs() << '\n'); 777 778 auto fmt = tree.getNativeCodeTemplate(); 779 // TODO: replace formatv arguments with the exact specified args. 780 SmallVector<std::string, 8> attrs(8); 781 if (tree.getNumArgs() > 8) { 782 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 783 Twine(tree.getNumArgs())); 784 } 785 bool hasLocationDirective; 786 std::string locToUse; 787 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 788 789 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 790 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 791 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 792 << " replacement: " << attrs[i] << "\n"); 793 } 794 return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], 795 attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], 796 attrs[6], attrs[7])); 797 } 798 799 int PatternEmitter::getNodeValueCount(DagNode node) { 800 if (node.isOperation()) { 801 // If the op is bound to a symbol in the rewrite rule, query its result 802 // count from the symbol info map. 803 auto symbol = node.getSymbol(); 804 if (!symbol.empty()) { 805 return symbolInfoMap.getStaticValueCount(symbol); 806 } 807 // Otherwise this is an unbound op; we will use all its results. 808 return pattern.getDialectOp(node).getNumResults(); 809 } 810 // TODO: This considers all NativeCodeCall as returning one 811 // value. Enhance if multi-value ones are needed. 812 return 1; 813 } 814 815 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 816 auto numPatArgs = tree.getNumArgs(); 817 818 if (numPatArgs != 0) { 819 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 820 if (lastArg.isLocationDirective()) { 821 return std::make_pair(true, handleLocationDirective(lastArg)); 822 } 823 } 824 825 // If no explicit location is given, use the default, all fused, location. 826 return std::make_pair(false, "odsLoc"); 827 } 828 829 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 830 int depth) { 831 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 832 LLVM_DEBUG(tree.print(llvm::dbgs())); 833 LLVM_DEBUG(llvm::dbgs() << '\n'); 834 835 Operator &resultOp = tree.getDialectOp(opMap); 836 auto numOpArgs = resultOp.getNumArgs(); 837 auto numPatArgs = tree.getNumArgs(); 838 839 bool hasLocationDirective; 840 std::string locToUse; 841 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 842 843 auto inPattern = numPatArgs - hasLocationDirective; 844 if (numOpArgs != inPattern) { 845 PrintFatalError(loc, 846 formatv("resultant op '{0}' argument number mismatch: " 847 "{1} in pattern vs. {2} in definition", 848 resultOp.getOperationName(), inPattern, numOpArgs)); 849 } 850 851 // A map to collect all nested DAG child nodes' names, with operand index as 852 // the key. This includes both bound and unbound child nodes. 853 ChildNodeIndexNameMap childNodeNames; 854 855 // First go through all the child nodes who are nested DAG constructs to 856 // create ops for them and remember the symbol names for them, so that we can 857 // use the results in the current node. This happens in a recursive manner. 858 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 859 if (auto child = tree.getArgAsNestedDag(i)) 860 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 861 } 862 863 // The name of the local variable holding this op. 864 std::string valuePackName; 865 // The symbol for holding the result of this pattern. Note that the result of 866 // this pattern is not necessarily the same as the variable created by this 867 // pattern because we can use `__N` suffix to refer only a specific result if 868 // the generated op is a multi-result op. 869 std::string resultValue; 870 if (tree.getSymbol().empty()) { 871 // No symbol is explicitly bound to this op in the pattern. Generate a 872 // unique name. 873 valuePackName = resultValue = getUniqueSymbol(&resultOp); 874 } else { 875 resultValue = std::string(tree.getSymbol()); 876 // Strip the index to get the name for the value pack and use it to name the 877 // local variable for the op. 878 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 879 } 880 881 // Create the local variable for this op. 882 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 883 valuePackName); 884 os.indent(4) << "{\n"; 885 886 // Right now ODS don't have general type inference support. Except a few 887 // special cases listed below, DRR needs to supply types for all results 888 // when building an op. 889 bool isSameOperandsAndResultType = 890 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 891 bool useFirstAttr = 892 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 893 894 if (isSameOperandsAndResultType || useFirstAttr) { 895 // We know how to deduce the result type for ops with these traits and we've 896 // generated builders taking aggregate parameters. Use those builders to 897 // create the ops. 898 899 // First prepare local variables for op arguments used in builder call. 900 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 901 902 // Then create the op. 903 os.indent(6) << formatv( 904 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n", 905 valuePackName, resultOp.getQualCppClassName(), locToUse); 906 os.indent(4) << "}\n"; 907 return resultValue; 908 } 909 910 bool usePartialResults = valuePackName != resultValue; 911 912 if (usePartialResults || depth > 0 || resultIndex < 0) { 913 // For these cases (broadcastable ops, op results used both as auxiliary 914 // values and replacement values, ops in nested patterns, auxiliary ops), we 915 // still need to supply the result types when building the op. But because 916 // we don't generate a builder automatically with ODS for them, it's the 917 // developer's responsibility to make sure such a builder (with result type 918 // deduction ability) exists. We go through the separate-parameter builder 919 // here given that it's easier for developers to write compared to 920 // aggregate-parameter builders. 921 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 922 923 os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 924 resultOp.getQualCppClassName(), locToUse); 925 supplyValuesForOpArgs(tree, childNodeNames); 926 os << "\n );\n"; 927 os.indent(4) << "}\n"; 928 return resultValue; 929 } 930 931 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 932 // generated from the source pattern root op. Then we can use the source 933 // pattern's value types to determine the value type of the generated op 934 // here. 935 936 // First prepare local variables for op arguments used in builder call. 937 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 938 939 // Then prepare the result types. We need to specify the types for all 940 // results. 941 os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 942 "(void)tblgen_types;\n"); 943 int numResults = resultOp.getNumResults(); 944 if (numResults != 0) { 945 for (int i = 0; i < numResults; ++i) 946 os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{" 947 "tblgen_types.push_back(v.getType()); }\n", 948 resultIndex + i); 949 } 950 os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 951 "tblgen_values, tblgen_attrs);\n", 952 valuePackName, resultOp.getQualCppClassName(), 953 locToUse); 954 os.indent(4) << "}\n"; 955 return resultValue; 956 } 957 958 void PatternEmitter::createSeparateLocalVarsForOpArgs( 959 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 960 Operator &resultOp = node.getDialectOp(opMap); 961 962 // Now prepare operands used for building this op: 963 // * If the operand is non-variadic, we create a `Value` local variable. 964 // * If the operand is variadic, we create a `SmallVector<Value>` local 965 // variable. 966 967 int valueIndex = 0; // An index for uniquing local variable names. 968 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 969 const auto *operand = 970 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 971 if (!operand) { 972 // We do not need special handling for attributes. 973 continue; 974 } 975 976 std::string varName; 977 if (operand->isVariadic()) { 978 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 979 os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", 980 varName); 981 std::string range; 982 if (node.isNestedDagArg(argIndex)) { 983 range = childNodeNames[argIndex]; 984 } else { 985 range = std::string(node.getArgName(argIndex)); 986 } 987 // Resolve the symbol for all range use so that we have a uniform way of 988 // capturing the values. 989 range = symbolInfoMap.getValueAndRangeUse(range); 990 os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range, 991 varName); 992 } else { 993 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 994 os.indent(6) << formatv("::mlir::Value {0} = ", varName); 995 if (node.isNestedDagArg(argIndex)) { 996 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 997 } else { 998 DagLeaf leaf = node.getArgAsLeaf(argIndex); 999 auto symbol = 1000 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1001 if (leaf.isNativeCodeCall()) { 1002 os << std::string( 1003 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1004 } else { 1005 os << symbol; 1006 } 1007 } 1008 os << ";\n"; 1009 } 1010 1011 // Update to use the newly created local variable for building the op later. 1012 childNodeNames[argIndex] = varName; 1013 } 1014 } 1015 1016 void PatternEmitter::supplyValuesForOpArgs( 1017 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 1018 Operator &resultOp = node.getDialectOp(opMap); 1019 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1020 argIndex != numOpArgs; ++argIndex) { 1021 // Start each argument on its own line. 1022 (os << ",\n").indent(8); 1023 1024 Argument opArg = resultOp.getArg(argIndex); 1025 // Handle the case of operand first. 1026 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1027 if (!operand->name.empty()) 1028 os << "/*" << operand->name << "=*/"; 1029 os << childNodeNames.lookup(argIndex); 1030 continue; 1031 } 1032 1033 // The argument in the op definition. 1034 auto opArgName = resultOp.getArgName(argIndex); 1035 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1036 if (!subTree.isNativeCodeCall()) 1037 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1038 "for creating attribute"); 1039 os << formatv("/*{0}=*/{1}", opArgName, 1040 handleReplaceWithNativeCodeCall(subTree)); 1041 } else { 1042 auto leaf = node.getArgAsLeaf(argIndex); 1043 // The argument in the result DAG pattern. 1044 auto patArgName = node.getArgName(argIndex); 1045 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1046 // TODO: Refactor out into map to avoid recomputing these. 1047 if (!opArg.is<NamedAttribute *>()) 1048 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1049 if (!patArgName.empty()) 1050 os << "/*" << patArgName << "=*/"; 1051 } else { 1052 os << "/*" << opArgName << "=*/"; 1053 } 1054 os << handleOpArgument(leaf, patArgName); 1055 } 1056 } 1057 } 1058 1059 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1060 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 1061 Operator &resultOp = node.getDialectOp(opMap); 1062 1063 os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1064 "tblgen_values; (void)tblgen_values;\n"); 1065 os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1066 "tblgen_attrs; (void)tblgen_attrs;\n"); 1067 1068 const char *addAttrCmd = 1069 "if (auto tmpAttr = {1}) " 1070 "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n"; 1071 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1072 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1073 // The argument in the op definition. 1074 auto opArgName = resultOp.getArgName(argIndex); 1075 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1076 if (!subTree.isNativeCodeCall()) 1077 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1078 "for creating attribute"); 1079 os.indent(6) << formatv(addAttrCmd, opArgName, 1080 handleReplaceWithNativeCodeCall(subTree)); 1081 } else { 1082 auto leaf = node.getArgAsLeaf(argIndex); 1083 // The argument in the result DAG pattern. 1084 auto patArgName = node.getArgName(argIndex); 1085 os.indent(6) << formatv(addAttrCmd, opArgName, 1086 handleOpArgument(leaf, patArgName)); 1087 } 1088 continue; 1089 } 1090 1091 const auto *operand = 1092 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1093 std::string varName; 1094 if (operand->isVariadic()) { 1095 std::string range; 1096 if (node.isNestedDagArg(argIndex)) { 1097 range = childNodeNames.lookup(argIndex); 1098 } else { 1099 range = std::string(node.getArgName(argIndex)); 1100 } 1101 // Resolve the symbol for all range use so that we have a uniform way of 1102 // capturing the values. 1103 range = symbolInfoMap.getValueAndRangeUse(range); 1104 os.indent(6) << formatv( 1105 "for (auto v : {0}) tblgen_values.push_back(v);\n", range); 1106 } else { 1107 os.indent(6) << formatv("tblgen_values.push_back(", varName); 1108 if (node.isNestedDagArg(argIndex)) { 1109 os << symbolInfoMap.getValueAndRangeUse( 1110 childNodeNames.lookup(argIndex)); 1111 } else { 1112 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1113 auto symbol = 1114 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1115 if (leaf.isNativeCodeCall()) { 1116 os << std::string( 1117 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1118 } else { 1119 os << symbol; 1120 } 1121 } 1122 os << ");\n"; 1123 } 1124 } 1125 } 1126 1127 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1128 emitSourceFileHeader("Rewriters", os); 1129 1130 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1131 auto numPatterns = patterns.size(); 1132 1133 // We put the map here because it can be shared among multiple patterns. 1134 RecordOperatorMap recordOpMap; 1135 1136 std::vector<std::string> rewriterNames; 1137 rewriterNames.reserve(numPatterns); 1138 1139 std::string baseRewriterName = "GeneratedConvert"; 1140 int rewriterIndex = 0; 1141 1142 for (Record *p : patterns) { 1143 std::string name; 1144 if (p->isAnonymous()) { 1145 // If no name is provided, ensure unique rewriter names simply by 1146 // appending unique suffix. 1147 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1148 } else { 1149 name = std::string(p->getName()); 1150 } 1151 LLVM_DEBUG(llvm::dbgs() 1152 << "=== start generating pattern '" << name << "' ===\n"); 1153 PatternEmitter(p, &recordOpMap, os).emit(name); 1154 LLVM_DEBUG(llvm::dbgs() 1155 << "=== done generating pattern '" << name << "' ===\n"); 1156 rewriterNames.push_back(std::move(name)); 1157 } 1158 1159 // Emit function to add the generated matchers to the pattern list. 1160 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(MLIRContext " 1161 "*context, OwningRewritePatternList *patterns) {\n"; 1162 for (const auto &name : rewriterNames) { 1163 os << " patterns->insert<" << name << ">(context);\n"; 1164 } 1165 os << "}\n"; 1166 } 1167 1168 static mlir::GenRegistration 1169 genRewriters("gen-rewriters", "Generate pattern rewriters", 1170 [](const RecordKeeper &records, raw_ostream &os) { 1171 emitRewriters(records, os); 1172 return false; 1173 }); 1174