1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Pattern.h" 19 #include "mlir/TableGen/Predicate.h" 20 #include "mlir/TableGen/Type.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::formatv; 37 using llvm::Record; 38 using llvm::RecordKeeper; 39 40 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 41 42 namespace llvm { 43 template <> 44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 46 raw_ostream &os, StringRef style) { 47 os << v.first << ":" << v.second; 48 } 49 }; 50 } // end namespace llvm 51 52 //===----------------------------------------------------------------------===// 53 // PatternEmitter 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 class PatternEmitter { 58 public: 59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 60 61 // Emits the mlir::RewritePattern struct named `rewriteName`. 62 void emit(StringRef rewriteName); 63 64 private: 65 // Emits the code for matching ops. 66 void emitMatchLogic(DagNode tree, StringRef opName); 67 68 // Emits the code for rewriting ops. 69 void emitRewriteLogic(); 70 71 //===--------------------------------------------------------------------===// 72 // Match utilities 73 //===--------------------------------------------------------------------===// 74 75 // Emits C++ statements for matching the DAG structure. 76 void emitMatch(DagNode tree, StringRef name, int depth); 77 78 // Emits C++ statements for matching using a native code call. 79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 80 81 // Emits C++ statements for matching the op constrained by the given DAG 82 // `tree` returning the op's variable name. 83 void emitOpMatch(DagNode tree, StringRef opName, int depth); 84 85 // Emits C++ statements for matching the `argIndex`-th argument of the given 86 // DAG `tree` as an operand. 87 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, 88 int depth); 89 90 // Emits C++ statements for matching the `argIndex`-th argument of the given 91 // DAG `tree` as an attribute. 92 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 93 int depth); 94 95 // Emits C++ for checking a match with a corresponding match failure 96 // diagnostic. 97 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 98 const llvm::formatv_object_base &failureFmt); 99 100 // Emits C++ for checking a match with a corresponding match failure 101 // diagnostics. 102 void emitMatchCheck(StringRef opName, const std::string &matchStr, 103 const std::string &failureStr); 104 105 //===--------------------------------------------------------------------===// 106 // Rewrite utilities 107 //===--------------------------------------------------------------------===// 108 109 // The entry point for handling a result pattern rooted at `resultTree`. This 110 // method dispatches to concrete handlers according to `resultTree`'s kind and 111 // returns a symbol representing the whole value pack. Callers are expected to 112 // further resolve the symbol according to the specific use case. 113 // 114 // `depth` is the nesting level of `resultTree`; 0 means top-level result 115 // pattern. For top-level result pattern, `resultIndex` indicates which result 116 // of the matched root op this pattern is intended to replace, which can be 117 // used to deduce the result type of the op generated from this result 118 // pattern. 119 std::string handleResultPattern(DagNode resultTree, int resultIndex, 120 int depth); 121 122 // Emits the C++ statement to replace the matched DAG with a value built via 123 // calling native C++ code. 124 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 125 126 // Returns the symbol of the old value serving as the replacement. 127 StringRef handleReplaceWithValue(DagNode tree); 128 129 // Returns the location value to use. 130 std::pair<bool, std::string> getLocation(DagNode tree); 131 132 // Returns the location value to use. 133 std::string handleLocationDirective(DagNode tree); 134 135 // Emits the C++ statement to build a new op out of the given DAG `tree` and 136 // returns the variable name that this op is assigned to. If the root op in 137 // DAG `tree` has a specified name, the created op will be assigned to a 138 // variable of the given name. Otherwise, a unique name will be used as the 139 // result value name. 140 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 141 142 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 143 144 // Emits a local variable for each value and attribute to be used for creating 145 // an op. 146 void createSeparateLocalVarsForOpArgs(DagNode node, 147 ChildNodeIndexNameMap &childNodeNames); 148 149 // Emits the concrete arguments used to call an op's builder. 150 void supplyValuesForOpArgs(DagNode node, 151 const ChildNodeIndexNameMap &childNodeNames, 152 int depth); 153 154 // Emits the local variables for holding all values as a whole and all named 155 // attributes as a whole to be used for creating an op. 156 void createAggregateLocalVarsForOpArgs( 157 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 158 159 // Returns the C++ expression to construct a constant attribute of the given 160 // `value` for the given attribute kind `attr`. 161 std::string handleConstantAttr(Attribute attr, StringRef value); 162 163 // Returns the C++ expression to build an argument from the given DAG `leaf`. 164 // `patArgName` is used to bound the argument to the source pattern. 165 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 166 167 //===--------------------------------------------------------------------===// 168 // General utilities 169 //===--------------------------------------------------------------------===// 170 171 // Collects all of the operations within the given dag tree. 172 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 173 174 // Returns a unique symbol for a local variable of the given `op`. 175 std::string getUniqueSymbol(const Operator *op); 176 177 //===--------------------------------------------------------------------===// 178 // Symbol utilities 179 //===--------------------------------------------------------------------===// 180 181 // Returns how many static values the given DAG `node` correspond to. 182 int getNodeValueCount(DagNode node); 183 184 private: 185 // Pattern instantiation location followed by the location of multiclass 186 // prototypes used. This is intended to be used as a whole to 187 // PrintFatalError() on errors. 188 ArrayRef<llvm::SMLoc> loc; 189 190 // Op's TableGen Record to wrapper object. 191 RecordOperatorMap *opMap; 192 193 // Handy wrapper for pattern being emitted. 194 Pattern pattern; 195 196 // Map for all bound symbols' info. 197 SymbolInfoMap symbolInfoMap; 198 199 // The next unused ID for newly created values. 200 unsigned nextValueId; 201 202 raw_indented_ostream os; 203 204 // Format contexts containing placeholder substitutions. 205 FmtContext fmtCtx; 206 207 // Number of op processed. 208 int opCounter = 0; 209 }; 210 } // end anonymous namespace 211 212 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 213 raw_ostream &os) 214 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 215 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 216 fmtCtx.withBuilder("rewriter"); 217 } 218 219 std::string PatternEmitter::handleConstantAttr(Attribute attr, 220 StringRef value) { 221 if (!attr.isConstBuildable()) 222 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 223 " does not have the 'constBuilderCall' field"); 224 225 // TODO: Verify the constants here 226 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 227 } 228 229 // Helper function to match patterns. 230 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 231 if (tree.isNativeCodeCall()) { 232 emitNativeCodeMatch(tree, name, depth); 233 return; 234 } 235 236 if (tree.isOperation()) { 237 emitOpMatch(tree, name, depth); 238 return; 239 } 240 241 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 242 } 243 244 // Helper function to match patterns. 245 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 246 int depth) { 247 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 248 LLVM_DEBUG(tree.print(llvm::dbgs())); 249 LLVM_DEBUG(llvm::dbgs() << '\n'); 250 251 // TODO(suderman): iterate through arguments, determine their types, output 252 // names. 253 SmallVector<std::string, 8> capture(8); 254 if (tree.getNumArgs() > 8) { 255 PrintFatalError(loc, 256 "unsupported NativeCodeCall matcher argument numbers: " + 257 Twine(tree.getNumArgs())); 258 } 259 260 raw_indented_ostream::DelimitedScope scope(os); 261 262 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 263 std::string argName = formatv("arg{0}_{1}", depth, i); 264 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 265 os << "Value " << argName << ";\n"; 266 } else { 267 auto leaf = tree.getArgAsLeaf(i); 268 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 269 os << "Attribute " << argName << ";\n"; 270 } else if (leaf.isOperandMatcher()) { 271 os << "Operation " << argName << ";\n"; 272 } 273 } 274 275 capture[i] = std::move(argName); 276 } 277 278 bool hasLocationDirective; 279 std::string locToUse; 280 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 281 282 auto fmt = tree.getNativeCodeTemplate(); 283 auto nativeCodeCall = std::string(tgfmt( 284 fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1], 285 capture[2], capture[3], capture[4], capture[5], capture[6], capture[7])); 286 287 os << "if (failed(" << nativeCodeCall << ")) return failure();\n"; 288 289 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 290 auto name = tree.getArgName(i); 291 if (!name.empty() && name != "_") { 292 os << formatv("{0} = {1};\n", name, capture[i]); 293 } 294 } 295 296 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 297 std::string argName = capture[i]; 298 299 // Handle nested DAG construct first 300 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 301 PrintFatalError( 302 loc, formatv("Matching nested tree in NativeCodecall not support for " 303 "{0} as arg {1}", 304 argName, i)); 305 } 306 307 DagLeaf leaf = tree.getArgAsLeaf(i); 308 auto constraint = leaf.getAsConstraint(); 309 310 auto self = formatv("{0}", argName); 311 emitMatchCheck( 312 opName, 313 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 314 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 315 "constraint: " 316 "'{2}'\"", 317 i, tree.getNativeCodeTemplate(), constraint.getDescription())); 318 } 319 320 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 321 } 322 323 // Helper function to match patterns. 324 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 325 Operator &op = tree.getDialectOp(opMap); 326 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 327 << op.getOperationName() << "' at depth " << depth 328 << '\n'); 329 330 std::string castedName = formatv("castedOp{0}", depth); 331 os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " 332 "(void){0};\n", 333 castedName, opName, op.getQualCppClassName()); 334 // Skip the operand matching at depth 0 as the pattern rewriter already does. 335 if (depth != 0) { 336 // Skip if there is no defining operation (e.g., arguments to function). 337 os << formatv("if (!{0}) return failure();\n", castedName); 338 } 339 if (tree.getNumArgs() != op.getNumArgs()) { 340 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 341 "pattern vs. {2} in definition", 342 op.getOperationName(), tree.getNumArgs(), 343 op.getNumArgs())); 344 } 345 346 // If the operand's name is set, set to that variable. 347 auto name = tree.getSymbol(); 348 if (!name.empty()) 349 os << formatv("{0} = {1};\n", name, castedName); 350 351 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 352 auto opArg = op.getArg(i); 353 std::string argName = formatv("op{0}", depth + 1); 354 355 // Handle nested DAG construct first 356 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 357 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 358 if (operand->isVariableLength()) { 359 auto error = formatv("use nested DAG construct to match op {0}'s " 360 "variadic operand #{1} unsupported now", 361 op.getOperationName(), i); 362 PrintFatalError(loc, error); 363 } 364 } 365 os << "{\n"; 366 367 os.indent() << formatv( 368 "auto *{0} = " 369 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 370 argName, castedName, i); 371 emitMatch(argTree, argName, depth + 1); 372 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); 373 os.unindent() << "}\n"; 374 continue; 375 } 376 377 // Next handle DAG leaf: operand or attribute 378 if (opArg.is<NamedTypeConstraint *>()) { 379 emitOperandMatch(tree, castedName, i, depth); 380 } else if (opArg.is<NamedAttribute *>()) { 381 emitAttributeMatch(tree, opName, i, depth); 382 } else { 383 PrintFatalError(loc, "unhandled case when matching op"); 384 } 385 } 386 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 387 << op.getOperationName() << "' at depth " << depth 388 << '\n'); 389 } 390 391 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 392 int argIndex, int depth) { 393 Operator &op = tree.getDialectOp(opMap); 394 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 395 auto matcher = tree.getArgAsLeaf(argIndex); 396 397 // If a constraint is specified, we need to generate C++ statements to 398 // check the constraint. 399 if (!matcher.isUnspecified()) { 400 if (!matcher.isOperandMatcher()) { 401 PrintFatalError( 402 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 403 op.getOperationName(), argIndex + 1)); 404 } 405 406 // Only need to verify if the matcher's type is different from the one 407 // of op definition. 408 Constraint constraint = matcher.getAsConstraint(); 409 if (operand->constraint != constraint) { 410 if (operand->isVariableLength()) { 411 auto error = formatv( 412 "further constrain op {0}'s variadic operand #{1} unsupported now", 413 op.getOperationName(), argIndex); 414 PrintFatalError(loc, error); 415 } 416 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", 417 opName, argIndex); 418 emitMatchCheck( 419 opName, 420 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 421 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 422 "'{2}'\"", 423 operand - op.operand_begin(), op.getOperationName(), 424 constraint.getDescription())); 425 } 426 } 427 428 // Capture the value 429 auto name = tree.getArgName(argIndex); 430 // `$_` is a special symbol to ignore op argument matching. 431 if (!name.empty() && name != "_") { 432 // We need to subtract the number of attributes before this operand to get 433 // the index in the operand list. 434 auto numPrevAttrs = std::count_if( 435 op.arg_begin(), op.arg_begin() + argIndex, 436 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 437 438 auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); 439 os << formatv("{0} = {1}.getODSOperands({2});\n", 440 res->second.getVarName(name), opName, 441 argIndex - numPrevAttrs); 442 } 443 } 444 445 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 446 int argIndex, int depth) { 447 Operator &op = tree.getDialectOp(opMap); 448 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 449 const auto &attr = namedAttr->attr; 450 451 os << "{\n"; 452 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 453 "(void)tblgen_attr;\n", 454 opName, attr.getStorageType(), namedAttr->name); 455 456 // TODO: This should use getter method to avoid duplication. 457 if (attr.hasDefaultValue()) { 458 os << "if (!tblgen_attr) tblgen_attr = " 459 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 460 attr.getDefaultValue())) 461 << ";\n"; 462 } else if (attr.isOptional()) { 463 // For a missing attribute that is optional according to definition, we 464 // should just capture a mlir::Attribute() to signal the missing state. 465 // That is precisely what getAttr() returns on missing attributes. 466 } else { 467 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 468 formatv("\"expected op '{0}' to have attribute '{1}' " 469 "of type '{2}'\"", 470 op.getOperationName(), namedAttr->name, 471 attr.getStorageType())); 472 } 473 474 auto matcher = tree.getArgAsLeaf(argIndex); 475 if (!matcher.isUnspecified()) { 476 if (!matcher.isAttrMatcher()) { 477 PrintFatalError( 478 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 479 op.getOperationName(), argIndex + 1)); 480 } 481 482 // If a constraint is specified, we need to generate C++ statements to 483 // check the constraint. 484 emitMatchCheck( 485 opName, 486 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 487 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 488 "{2}\"", 489 op.getOperationName(), namedAttr->name, 490 matcher.getAsConstraint().getDescription())); 491 } 492 493 // Capture the value 494 auto name = tree.getArgName(argIndex); 495 // `$_` is a special symbol to ignore op argument matching. 496 if (!name.empty() && name != "_") { 497 os << formatv("{0} = tblgen_attr;\n", name); 498 } 499 500 os.unindent() << "}\n"; 501 } 502 503 void PatternEmitter::emitMatchCheck( 504 StringRef opName, const FmtObjectBase &matchFmt, 505 const llvm::formatv_object_base &failureFmt) { 506 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 507 } 508 509 void PatternEmitter::emitMatchCheck(StringRef opName, 510 const std::string &matchStr, 511 const std::string &failureStr) { 512 513 os << "if (!(" << matchStr << "))"; 514 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 515 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 516 << failureStr << ";\n});"; 517 } 518 519 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 520 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 521 int depth = 0; 522 emitMatch(tree, opName, depth); 523 524 for (auto &appliedConstraint : pattern.getConstraints()) { 525 auto &constraint = appliedConstraint.constraint; 526 auto &entities = appliedConstraint.entities; 527 528 auto condition = constraint.getConditionTemplate(); 529 if (isa<TypeConstraint>(constraint)) { 530 auto self = formatv("({0}.getType())", 531 symbolInfoMap.getValueAndRangeUse(entities.front())); 532 emitMatchCheck( 533 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 534 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 535 entities.front(), constraint.getDescription())); 536 537 } else if (isa<AttrConstraint>(constraint)) { 538 PrintFatalError( 539 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 540 } else { 541 // TODO: replace formatv arguments with the exact specified 542 // args. 543 if (entities.size() > 4) { 544 PrintFatalError(loc, "only support up to 4-entity constraints now"); 545 } 546 SmallVector<std::string, 4> names; 547 int i = 0; 548 for (int e = entities.size(); i < e; ++i) 549 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 550 std::string self = appliedConstraint.self; 551 if (!self.empty()) 552 self = symbolInfoMap.getValueAndRangeUse(self); 553 for (; i < 4; ++i) 554 names.push_back("<unused>"); 555 emitMatchCheck(opName, 556 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 557 names[1], names[2], names[3]), 558 formatv("\"entities '{0}' failed to satisfy constraint: " 559 "{1}\"", 560 llvm::join(entities, ", "), 561 constraint.getDescription())); 562 } 563 } 564 565 // Some of the operands could be bound to the same symbol name, we need 566 // to enforce equality constraint on those. 567 // TODO: we should be able to emit equality checks early 568 // and short circuit unnecessary work if vars are not equal. 569 for (auto symbolInfoIt = symbolInfoMap.begin(); 570 symbolInfoIt != symbolInfoMap.end();) { 571 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 572 auto startRange = range.first; 573 auto endRange = range.second; 574 575 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 576 for (++startRange; startRange != endRange; ++startRange) { 577 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 578 emitMatchCheck( 579 opName, 580 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 581 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 582 secondOperand)); 583 } 584 585 symbolInfoIt = endRange; 586 } 587 588 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 589 } 590 591 void PatternEmitter::collectOps(DagNode tree, 592 llvm::SmallPtrSetImpl<const Operator *> &ops) { 593 // Check if this tree is an operation. 594 if (tree.isOperation()) { 595 const Operator &op = tree.getDialectOp(opMap); 596 LLVM_DEBUG(llvm::dbgs() 597 << "found operation " << op.getOperationName() << '\n'); 598 ops.insert(&op); 599 } 600 601 // Recurse the arguments of the tree. 602 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 603 if (auto child = tree.getArgAsNestedDag(i)) 604 collectOps(child, ops); 605 } 606 607 void PatternEmitter::emit(StringRef rewriteName) { 608 // Get the DAG tree for the source pattern. 609 DagNode sourceTree = pattern.getSourcePattern(); 610 611 const Operator &rootOp = pattern.getSourceRootOp(); 612 auto rootName = rootOp.getOperationName(); 613 614 // Collect the set of result operations. 615 llvm::SmallPtrSet<const Operator *, 4> resultOps; 616 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 617 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 618 collectOps(pattern.getResultPattern(i), resultOps); 619 } 620 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 621 622 // Emit RewritePattern for Pattern. 623 auto locs = pattern.getLocation(); 624 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 625 make_range(locs.rbegin(), locs.rend())); 626 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 627 {0}(::mlir::MLIRContext *context) 628 : ::mlir::RewritePattern("{1}", {{)", 629 rewriteName, rootName); 630 // Sort result operators by name. 631 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 632 resultOps.end()); 633 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 634 return lhs->getOperationName() < rhs->getOperationName(); 635 }); 636 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 637 os << '"' << op->getOperationName() << '"'; 638 }); 639 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 640 641 // Emit matchAndRewrite() function. 642 { 643 auto classScope = os.scope(); 644 os.reindent(R"( 645 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 646 ::mlir::PatternRewriter &rewriter) const override {)") 647 << '\n'; 648 { 649 auto functionScope = os.scope(); 650 651 // Register all symbols bound in the source pattern. 652 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 653 654 LLVM_DEBUG(llvm::dbgs() 655 << "start creating local variables for capturing matches\n"); 656 os << "// Variables for capturing values and attributes used while " 657 "creating ops\n"; 658 // Create local variables for storing the arguments and results bound 659 // to symbols. 660 for (const auto &symbolInfoPair : symbolInfoMap) { 661 const auto &symbol = symbolInfoPair.first; 662 const auto &info = symbolInfoPair.second; 663 664 os << info.getVarDecl(symbol); 665 } 666 // TODO: capture ops with consistent numbering so that it can be 667 // reused for fused loc. 668 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 669 pattern.getSourcePattern().getNumOps()); 670 LLVM_DEBUG(llvm::dbgs() 671 << "done creating local variables for capturing matches\n"); 672 673 os << "// Match\n"; 674 os << "tblgen_ops[0] = op0;\n"; 675 emitMatchLogic(sourceTree, "op0"); 676 677 os << "\n// Rewrite\n"; 678 emitRewriteLogic(); 679 680 os << "return ::mlir::success();\n"; 681 } 682 os << "};\n"; 683 } 684 os << "};\n\n"; 685 } 686 687 void PatternEmitter::emitRewriteLogic() { 688 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 689 const Operator &rootOp = pattern.getSourceRootOp(); 690 int numExpectedResults = rootOp.getNumResults(); 691 int numResultPatterns = pattern.getNumResultPatterns(); 692 693 // First register all symbols bound to ops generated in result patterns. 694 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 695 696 // Only the last N static values generated are used to replace the matched 697 // root N-result op. We need to calculate the starting index (of the results 698 // of the matched op) each result pattern is to replace. 699 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 700 // If we don't need to replace any value at all, set the replacement starting 701 // index as the number of result patterns so we skip all of them when trying 702 // to replace the matched op's results. 703 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 704 for (int i = numResultPatterns - 1; i >= 0; --i) { 705 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 706 offsets[i] = offsets[i + 1] - numValues; 707 if (offsets[i] == 0) { 708 if (replStartIndex == -1) 709 replStartIndex = i; 710 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 711 auto error = formatv( 712 "cannot use the same multi-result op '{0}' to generate both " 713 "auxiliary values and values to be used for replacing the matched op", 714 pattern.getResultPattern(i).getSymbol()); 715 PrintFatalError(loc, error); 716 } 717 } 718 719 if (offsets.front() > 0) { 720 const char error[] = "no enough values generated to replace the matched op"; 721 PrintFatalError(loc, error); 722 } 723 724 os << "auto odsLoc = rewriter.getFusedLoc({"; 725 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 726 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 727 } 728 os << "}); (void)odsLoc;\n"; 729 730 // Process auxiliary result patterns. 731 for (int i = 0; i < replStartIndex; ++i) { 732 DagNode resultTree = pattern.getResultPattern(i); 733 auto val = handleResultPattern(resultTree, offsets[i], 0); 734 // Normal op creation will be streamed to `os` by the above call; but 735 // NativeCodeCall will only be materialized to `os` if it is used. Here 736 // we are handling auxiliary patterns so we want the side effect even if 737 // NativeCodeCall is not replacing matched root op's results. 738 if (resultTree.isNativeCodeCall()) 739 os << val << ";\n"; 740 } 741 742 if (numExpectedResults == 0) { 743 assert(replStartIndex >= numResultPatterns && 744 "invalid auxiliary vs. replacement pattern division!"); 745 // No result to replace. Just erase the op. 746 os << "rewriter.eraseOp(op0);\n"; 747 } else { 748 // Process replacement result patterns. 749 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 750 for (int i = replStartIndex; i < numResultPatterns; ++i) { 751 DagNode resultTree = pattern.getResultPattern(i); 752 auto val = handleResultPattern(resultTree, offsets[i], 0); 753 os << "\n"; 754 // Resolve each symbol for all range use so that we can loop over them. 755 // We need an explicit cast to `SmallVector` to capture the cases where 756 // `{0}` resolves to an `Operation::result_range` as well as cases that 757 // are not iterable (e.g. vector that gets wrapped in additional braces by 758 // RewriterGen). 759 // TODO: Revisit the need for materializing a vector. 760 os << symbolInfoMap.getAllRangeUse( 761 val, 762 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 763 " tblgen_repl_values.push_back(v);\n}\n", 764 "\n"); 765 } 766 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 767 } 768 769 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 770 } 771 772 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 773 return std::string( 774 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 775 } 776 777 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 778 int resultIndex, int depth) { 779 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 780 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 781 LLVM_DEBUG(llvm::dbgs() << '\n'); 782 783 if (resultTree.isLocationDirective()) { 784 PrintFatalError(loc, 785 "location directive can only be used with op creation"); 786 } 787 788 if (resultTree.isNativeCodeCall()) { 789 auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); 790 symbolInfoMap.bindValue(symbol); 791 return symbol; 792 } 793 794 if (resultTree.isReplaceWithValue()) 795 return handleReplaceWithValue(resultTree).str(); 796 797 // Normal op creation. 798 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 799 if (resultTree.getSymbol().empty()) { 800 // This is an op not explicitly bound to a symbol in the rewrite rule. 801 // Register the auto-generated symbol for it. 802 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 803 } 804 return symbol; 805 } 806 807 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 808 assert(tree.isReplaceWithValue()); 809 810 if (tree.getNumArgs() != 1) { 811 PrintFatalError( 812 loc, "replaceWithValue directive must take exactly one argument"); 813 } 814 815 if (!tree.getSymbol().empty()) { 816 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 817 } 818 819 return tree.getArgName(0); 820 } 821 822 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 823 assert(tree.isLocationDirective()); 824 auto lookUpArgLoc = [this, &tree](int idx) { 825 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 826 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 827 }; 828 829 if (tree.getNumArgs() == 0) 830 llvm::PrintFatalError( 831 "At least one argument to location directive required"); 832 833 if (!tree.getSymbol().empty()) 834 PrintFatalError(loc, "cannot bind symbol to location"); 835 836 if (tree.getNumArgs() == 1) { 837 DagLeaf leaf = tree.getArgAsLeaf(0); 838 if (leaf.isStringAttr()) 839 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " 840 "rewriter.getContext())", 841 leaf.getStringAttr()) 842 .str(); 843 return lookUpArgLoc(0); 844 } 845 846 std::string ret; 847 llvm::raw_string_ostream os(ret); 848 std::string strAttr; 849 os << "rewriter.getFusedLoc({"; 850 bool first = true; 851 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 852 DagLeaf leaf = tree.getArgAsLeaf(i); 853 // Handle the optional string value. 854 if (leaf.isStringAttr()) { 855 if (!strAttr.empty()) 856 llvm::PrintFatalError("Only one string attribute may be specified"); 857 strAttr = leaf.getStringAttr(); 858 continue; 859 } 860 os << (first ? "" : ", ") << lookUpArgLoc(i); 861 first = false; 862 } 863 os << "}"; 864 if (!strAttr.empty()) { 865 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 866 } 867 os << ")"; 868 return os.str(); 869 } 870 871 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 872 StringRef patArgName) { 873 if (leaf.isStringAttr()) 874 PrintFatalError(loc, "raw string not supported as argument"); 875 if (leaf.isConstantAttr()) { 876 auto constAttr = leaf.getAsConstantAttr(); 877 return handleConstantAttr(constAttr.getAttribute(), 878 constAttr.getConstantValue()); 879 } 880 if (leaf.isEnumAttrCase()) { 881 auto enumCase = leaf.getAsEnumAttrCase(); 882 if (enumCase.isStrCase()) 883 return handleConstantAttr(enumCase, enumCase.getSymbol()); 884 // This is an enum case backed by an IntegerAttr. We need to get its value 885 // to build the constant. 886 std::string val = std::to_string(enumCase.getValue()); 887 return handleConstantAttr(enumCase, val); 888 } 889 890 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 891 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 892 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 893 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 894 << "' (via symbol ref)\n"); 895 return argName; 896 } 897 if (leaf.isNativeCodeCall()) { 898 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 899 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 900 << "' (via NativeCodeCall)\n"); 901 return std::string(repl); 902 } 903 PrintFatalError(loc, "unhandled case when rewriting op"); 904 } 905 906 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 907 int depth) { 908 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 909 LLVM_DEBUG(tree.print(llvm::dbgs())); 910 LLVM_DEBUG(llvm::dbgs() << '\n'); 911 912 auto fmt = tree.getNativeCodeTemplate(); 913 // TODO: replace formatv arguments with the exact specified args. 914 SmallVector<std::string, 8> attrs(8); 915 if (tree.getNumArgs() > 8) { 916 PrintFatalError(loc, 917 "unsupported NativeCodeCall replace argument numbers: " + 918 Twine(tree.getNumArgs())); 919 } 920 bool hasLocationDirective; 921 std::string locToUse; 922 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 923 924 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 925 if (tree.isNestedDagArg(i)) { 926 attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1); 927 } else { 928 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 929 } 930 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 931 << " replacement: " << attrs[i] << "\n"); 932 } 933 return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], 934 attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], 935 attrs[6], attrs[7])); 936 } 937 938 int PatternEmitter::getNodeValueCount(DagNode node) { 939 if (node.isOperation()) { 940 // If the op is bound to a symbol in the rewrite rule, query its result 941 // count from the symbol info map. 942 auto symbol = node.getSymbol(); 943 if (!symbol.empty()) { 944 return symbolInfoMap.getStaticValueCount(symbol); 945 } 946 // Otherwise this is an unbound op; we will use all its results. 947 return pattern.getDialectOp(node).getNumResults(); 948 } 949 // TODO: This considers all NativeCodeCall as returning one 950 // value. Enhance if multi-value ones are needed. 951 return 1; 952 } 953 954 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 955 auto numPatArgs = tree.getNumArgs(); 956 957 if (numPatArgs != 0) { 958 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 959 if (lastArg.isLocationDirective()) { 960 return std::make_pair(true, handleLocationDirective(lastArg)); 961 } 962 } 963 964 // If no explicit location is given, use the default, all fused, location. 965 return std::make_pair(false, "odsLoc"); 966 } 967 968 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 969 int depth) { 970 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 971 LLVM_DEBUG(tree.print(llvm::dbgs())); 972 LLVM_DEBUG(llvm::dbgs() << '\n'); 973 974 Operator &resultOp = tree.getDialectOp(opMap); 975 auto numOpArgs = resultOp.getNumArgs(); 976 auto numPatArgs = tree.getNumArgs(); 977 978 bool hasLocationDirective; 979 std::string locToUse; 980 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 981 982 auto inPattern = numPatArgs - hasLocationDirective; 983 if (numOpArgs != inPattern) { 984 PrintFatalError(loc, 985 formatv("resultant op '{0}' argument number mismatch: " 986 "{1} in pattern vs. {2} in definition", 987 resultOp.getOperationName(), inPattern, numOpArgs)); 988 } 989 990 // A map to collect all nested DAG child nodes' names, with operand index as 991 // the key. This includes both bound and unbound child nodes. 992 ChildNodeIndexNameMap childNodeNames; 993 994 // First go through all the child nodes who are nested DAG constructs to 995 // create ops for them and remember the symbol names for them, so that we can 996 // use the results in the current node. This happens in a recursive manner. 997 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 998 if (auto child = tree.getArgAsNestedDag(i)) 999 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1000 } 1001 1002 // The name of the local variable holding this op. 1003 std::string valuePackName; 1004 // The symbol for holding the result of this pattern. Note that the result of 1005 // this pattern is not necessarily the same as the variable created by this 1006 // pattern because we can use `__N` suffix to refer only a specific result if 1007 // the generated op is a multi-result op. 1008 std::string resultValue; 1009 if (tree.getSymbol().empty()) { 1010 // No symbol is explicitly bound to this op in the pattern. Generate a 1011 // unique name. 1012 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1013 } else { 1014 resultValue = std::string(tree.getSymbol()); 1015 // Strip the index to get the name for the value pack and use it to name the 1016 // local variable for the op. 1017 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1018 } 1019 1020 // Create the local variable for this op. 1021 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1022 valuePackName); 1023 1024 // Right now ODS don't have general type inference support. Except a few 1025 // special cases listed below, DRR needs to supply types for all results 1026 // when building an op. 1027 bool isSameOperandsAndResultType = 1028 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1029 bool useFirstAttr = 1030 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1031 1032 if (isSameOperandsAndResultType || useFirstAttr) { 1033 // We know how to deduce the result type for ops with these traits and we've 1034 // generated builders taking aggregate parameters. Use those builders to 1035 // create the ops. 1036 1037 // First prepare local variables for op arguments used in builder call. 1038 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1039 1040 // Then create the op. 1041 os.scope("", "\n}\n").os << formatv( 1042 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1043 valuePackName, resultOp.getQualCppClassName(), locToUse); 1044 return resultValue; 1045 } 1046 1047 bool usePartialResults = valuePackName != resultValue; 1048 1049 if (usePartialResults || depth > 0 || resultIndex < 0) { 1050 // For these cases (broadcastable ops, op results used both as auxiliary 1051 // values and replacement values, ops in nested patterns, auxiliary ops), we 1052 // still need to supply the result types when building the op. But because 1053 // we don't generate a builder automatically with ODS for them, it's the 1054 // developer's responsibility to make sure such a builder (with result type 1055 // deduction ability) exists. We go through the separate-parameter builder 1056 // here given that it's easier for developers to write compared to 1057 // aggregate-parameter builders. 1058 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1059 1060 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1061 resultOp.getQualCppClassName(), locToUse); 1062 supplyValuesForOpArgs(tree, childNodeNames, depth); 1063 os << "\n );\n}\n"; 1064 return resultValue; 1065 } 1066 1067 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 1068 // generated from the source pattern root op. Then we can use the source 1069 // pattern's value types to determine the value type of the generated op 1070 // here. 1071 1072 // First prepare local variables for op arguments used in builder call. 1073 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1074 1075 // Then prepare the result types. We need to specify the types for all 1076 // results. 1077 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 1078 "(void)tblgen_types;\n"); 1079 int numResults = resultOp.getNumResults(); 1080 if (numResults != 0) { 1081 for (int i = 0; i < numResults; ++i) 1082 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1083 " tblgen_types.push_back(v.getType());\n}\n", 1084 resultIndex + i); 1085 } 1086 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1087 "tblgen_values, tblgen_attrs);\n", 1088 valuePackName, resultOp.getQualCppClassName(), locToUse); 1089 os.unindent() << "}\n"; 1090 return resultValue; 1091 } 1092 1093 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1094 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1095 Operator &resultOp = node.getDialectOp(opMap); 1096 1097 // Now prepare operands used for building this op: 1098 // * If the operand is non-variadic, we create a `Value` local variable. 1099 // * If the operand is variadic, we create a `SmallVector<Value>` local 1100 // variable. 1101 1102 int valueIndex = 0; // An index for uniquing local variable names. 1103 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1104 const auto *operand = 1105 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 1106 // We do not need special handling for attributes. 1107 if (!operand) 1108 continue; 1109 1110 raw_indented_ostream::DelimitedScope scope(os); 1111 std::string varName; 1112 if (operand->isVariadic()) { 1113 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1114 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); 1115 std::string range; 1116 if (node.isNestedDagArg(argIndex)) { 1117 range = childNodeNames[argIndex]; 1118 } else { 1119 range = std::string(node.getArgName(argIndex)); 1120 } 1121 // Resolve the symbol for all range use so that we have a uniform way of 1122 // capturing the values. 1123 range = symbolInfoMap.getValueAndRangeUse(range); 1124 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1125 varName); 1126 } else { 1127 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1128 os << formatv("::mlir::Value {0} = ", varName); 1129 if (node.isNestedDagArg(argIndex)) { 1130 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1131 } else { 1132 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1133 auto symbol = 1134 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1135 if (leaf.isNativeCodeCall()) { 1136 os << std::string( 1137 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1138 } else { 1139 os << symbol; 1140 } 1141 } 1142 os << ";\n"; 1143 } 1144 1145 // Update to use the newly created local variable for building the op later. 1146 childNodeNames[argIndex] = varName; 1147 } 1148 } 1149 1150 void PatternEmitter::supplyValuesForOpArgs( 1151 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1152 Operator &resultOp = node.getDialectOp(opMap); 1153 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1154 argIndex != numOpArgs; ++argIndex) { 1155 // Start each argument on its own line. 1156 os << ",\n "; 1157 1158 Argument opArg = resultOp.getArg(argIndex); 1159 // Handle the case of operand first. 1160 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1161 if (!operand->name.empty()) 1162 os << "/*" << operand->name << "=*/"; 1163 os << childNodeNames.lookup(argIndex); 1164 continue; 1165 } 1166 1167 // The argument in the op definition. 1168 auto opArgName = resultOp.getArgName(argIndex); 1169 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1170 if (!subTree.isNativeCodeCall()) 1171 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1172 "for creating attribute"); 1173 os << formatv("/*{0}=*/{1}", opArgName, 1174 handleReplaceWithNativeCodeCall(subTree, depth)); 1175 } else { 1176 auto leaf = node.getArgAsLeaf(argIndex); 1177 // The argument in the result DAG pattern. 1178 auto patArgName = node.getArgName(argIndex); 1179 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1180 // TODO: Refactor out into map to avoid recomputing these. 1181 if (!opArg.is<NamedAttribute *>()) 1182 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1183 if (!patArgName.empty()) 1184 os << "/*" << patArgName << "=*/"; 1185 } else { 1186 os << "/*" << opArgName << "=*/"; 1187 } 1188 os << handleOpArgument(leaf, patArgName); 1189 } 1190 } 1191 } 1192 1193 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1194 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1195 Operator &resultOp = node.getDialectOp(opMap); 1196 1197 auto scope = os.scope(); 1198 os << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1199 "tblgen_values; (void)tblgen_values;\n"); 1200 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1201 "tblgen_attrs; (void)tblgen_attrs;\n"); 1202 1203 const char *addAttrCmd = 1204 "if (auto tmpAttr = {1}) {\n" 1205 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " 1206 "tmpAttr);\n}\n"; 1207 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1208 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1209 // The argument in the op definition. 1210 auto opArgName = resultOp.getArgName(argIndex); 1211 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1212 if (!subTree.isNativeCodeCall()) 1213 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1214 "for creating attribute"); 1215 os << formatv(addAttrCmd, opArgName, 1216 handleReplaceWithNativeCodeCall(subTree, depth + 1)); 1217 } else { 1218 auto leaf = node.getArgAsLeaf(argIndex); 1219 // The argument in the result DAG pattern. 1220 auto patArgName = node.getArgName(argIndex); 1221 os << formatv(addAttrCmd, opArgName, 1222 handleOpArgument(leaf, patArgName)); 1223 } 1224 continue; 1225 } 1226 1227 const auto *operand = 1228 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1229 std::string varName; 1230 if (operand->isVariadic()) { 1231 std::string range; 1232 if (node.isNestedDagArg(argIndex)) { 1233 range = childNodeNames.lookup(argIndex); 1234 } else { 1235 range = std::string(node.getArgName(argIndex)); 1236 } 1237 // Resolve the symbol for all range use so that we have a uniform way of 1238 // capturing the values. 1239 range = symbolInfoMap.getValueAndRangeUse(range); 1240 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1241 range); 1242 } else { 1243 os << formatv("tblgen_values.push_back("); 1244 if (node.isNestedDagArg(argIndex)) { 1245 os << symbolInfoMap.getValueAndRangeUse( 1246 childNodeNames.lookup(argIndex)); 1247 } else { 1248 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1249 auto symbol = 1250 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1251 if (leaf.isNativeCodeCall()) { 1252 os << std::string( 1253 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1254 } else { 1255 os << symbol; 1256 } 1257 } 1258 os << ");\n"; 1259 } 1260 } 1261 } 1262 1263 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1264 emitSourceFileHeader("Rewriters", os); 1265 1266 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1267 auto numPatterns = patterns.size(); 1268 1269 // We put the map here because it can be shared among multiple patterns. 1270 RecordOperatorMap recordOpMap; 1271 1272 std::vector<std::string> rewriterNames; 1273 rewriterNames.reserve(numPatterns); 1274 1275 std::string baseRewriterName = "GeneratedConvert"; 1276 int rewriterIndex = 0; 1277 1278 for (Record *p : patterns) { 1279 std::string name; 1280 if (p->isAnonymous()) { 1281 // If no name is provided, ensure unique rewriter names simply by 1282 // appending unique suffix. 1283 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1284 } else { 1285 name = std::string(p->getName()); 1286 } 1287 LLVM_DEBUG(llvm::dbgs() 1288 << "=== start generating pattern '" << name << "' ===\n"); 1289 PatternEmitter(p, &recordOpMap, os).emit(name); 1290 LLVM_DEBUG(llvm::dbgs() 1291 << "=== done generating pattern '" << name << "' ===\n"); 1292 rewriterNames.push_back(std::move(name)); 1293 } 1294 1295 // Emit function to add the generated matchers to the pattern list. 1296 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext " 1297 "*context, ::mlir::OwningRewritePatternList &patterns) {\n"; 1298 for (const auto &name : rewriterNames) { 1299 os << " patterns.insert<" << name << ">(context);\n"; 1300 } 1301 os << "}\n"; 1302 } 1303 1304 static mlir::GenRegistration 1305 genRewriters("gen-rewriters", "Generate pattern rewriters", 1306 [](const RecordKeeper &records, raw_ostream &os) { 1307 emitRewriters(records, os); 1308 return false; 1309 }); 1310