1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FormatAdapters.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/Signals.h" 36 #include "llvm/TableGen/Error.h" 37 #include "llvm/TableGen/Main.h" 38 #include "llvm/TableGen/Record.h" 39 #include "llvm/TableGen/TableGenBackend.h" 40 41 using namespace llvm; 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 namespace llvm { 46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 47 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 48 raw_ostream &os, StringRef style) { 49 os << v.first << ":" << v.second; 50 } 51 }; 52 } // end namespace llvm 53 54 //===----------------------------------------------------------------------===// 55 // PatternEmitter 56 //===----------------------------------------------------------------------===// 57 58 namespace { 59 class PatternEmitter { 60 public: 61 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 62 63 // Emits the mlir::RewritePattern struct named `rewriteName`. 64 void emit(StringRef rewriteName); 65 66 private: 67 // Emits the code for matching ops. 68 void emitMatchLogic(DagNode tree); 69 70 // Emits the code for rewriting ops. 71 void emitRewriteLogic(); 72 73 //===--------------------------------------------------------------------===// 74 // Match utilities 75 //===--------------------------------------------------------------------===// 76 77 // Emits C++ statements for matching the op constrained by the given DAG 78 // `tree`. 79 void emitOpMatch(DagNode tree, int depth); 80 81 // Emits C++ statements for matching the `index`-th argument of the given DAG 82 // `tree` as an operand. 83 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 84 85 // Emits C++ statements for matching the `index`-th argument of the given DAG 86 // `tree` as an attribute. 87 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 88 89 //===--------------------------------------------------------------------===// 90 // Rewrite utilities 91 //===--------------------------------------------------------------------===// 92 93 // Entry point for handling a result pattern rooted at `resultTree` and 94 // dispatches to concrete handlers. The given tree is the `resultIndex`-th 95 // argument of the enclosing DAG. 96 std::string handleResultPattern(DagNode resultTree, int resultIndex, 97 int depth); 98 99 // Emits the C++ statement to replace the matched DAG with a value built via 100 // calling native C++ code. 101 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 102 103 // Returns the C++ expression referencing the old value serving as the 104 // replacement. 105 std::string handleReplaceWithValue(DagNode tree); 106 107 // Emits the C++ statement to build a new op out of the given DAG `tree` and 108 // returns the variable name that this op is assigned to. If the root op in 109 // DAG `tree` has a specified name, the created op will be assigned to a 110 // variable of the given name. Otherwise, a unique name will be used as the 111 // result value name. 112 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 113 114 // Returns the C++ expression to construct a constant attribute of the given 115 // `value` for the given attribute kind `attr`. 116 std::string handleConstantAttr(Attribute attr, StringRef value); 117 118 // Returns the C++ expression to build an argument from the given DAG `leaf`. 119 // `patArgName` is used to bound the argument to the source pattern. 120 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 121 122 //===--------------------------------------------------------------------===// 123 // General utilities 124 //===--------------------------------------------------------------------===// 125 126 // Collects all of the operations within the given dag tree. 127 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 128 129 // Returns a unique symbol for a local variable of the given `op`. 130 std::string getUniqueSymbol(const Operator *op); 131 132 //===--------------------------------------------------------------------===// 133 // Symbol utilities 134 //===--------------------------------------------------------------------===// 135 136 // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. 137 std::string resolveSymbol(StringRef symbol); 138 139 // Returns how many static values the given DAG `node` correspond to. 140 int getNodeValueCount(DagNode node); 141 142 private: 143 // Pattern instantiation location followed by the location of multiclass 144 // prototypes used. This is intended to be used as a whole to 145 // PrintFatalError() on errors. 146 ArrayRef<llvm::SMLoc> loc; 147 148 // Op's TableGen Record to wrapper object. 149 RecordOperatorMap *opMap; 150 151 // Handy wrapper for pattern being emitted. 152 Pattern pattern; 153 154 // Map for all bound symbols' info. 155 SymbolInfoMap symbolInfoMap; 156 157 // The next unused ID for newly created values. 158 unsigned nextValueId; 159 160 raw_ostream &os; 161 162 // Format contexts containing placeholder substitutations. 163 FmtContext fmtCtx; 164 165 // Number of op processed. 166 int opCounter = 0; 167 }; 168 } // end anonymous namespace 169 170 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 171 raw_ostream &os) 172 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 173 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 174 fmtCtx.withBuilder("rewriter"); 175 } 176 177 std::string PatternEmitter::handleConstantAttr(Attribute attr, 178 StringRef value) { 179 if (!attr.isConstBuildable()) 180 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 181 " does not have the 'constBuilderCall' field"); 182 183 // TODO(jpienaar): Verify the constants here 184 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 185 } 186 187 // Helper function to match patterns. 188 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 189 Operator &op = tree.getDialectOp(opMap); 190 if (op.isVariadic()) { 191 PrintFatalError(loc, formatv("matching op '{0}' with variadic " 192 "operands/results is unsupported right now", 193 op.getOperationName())); 194 } 195 196 int indent = 4 + 2 * depth; 197 os.indent(indent) << formatv( 198 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 199 depth, op.getQualCppClassName()); 200 // Skip the operand matching at depth 0 as the pattern rewriter already does. 201 if (depth != 0) { 202 // Skip if there is no defining operation (e.g., arguments to function). 203 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 204 depth); 205 } 206 if (tree.getNumArgs() != op.getNumArgs()) { 207 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 208 "pattern vs. {2} in definition", 209 op.getOperationName(), tree.getNumArgs(), 210 op.getNumArgs())); 211 } 212 213 // If the operand's name is set, set to that variable. 214 auto name = tree.getSymbol(); 215 if (!name.empty()) 216 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 217 218 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 219 auto opArg = op.getArg(i); 220 221 // Handle nested DAG construct first 222 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 223 os.indent(indent) << "{\n"; 224 os.indent(indent + 2) 225 << formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n", 226 depth + 1, depth, i); 227 emitOpMatch(argTree, depth + 1); 228 os.indent(indent + 2) 229 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 230 os.indent(indent) << "}\n"; 231 continue; 232 } 233 234 // Next handle DAG leaf: operand or attribute 235 if (opArg.is<NamedTypeConstraint *>()) { 236 emitOperandMatch(tree, i, depth, indent); 237 } else if (opArg.is<NamedAttribute *>()) { 238 emitAttributeMatch(tree, i, depth, indent); 239 } else { 240 PrintFatalError(loc, "unhandled case when matching op"); 241 } 242 } 243 } 244 245 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 246 int indent) { 247 Operator &op = tree.getDialectOp(opMap); 248 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 249 auto matcher = tree.getArgAsLeaf(index); 250 251 // If a constraint is specified, we need to generate C++ statements to 252 // check the constraint. 253 if (!matcher.isUnspecified()) { 254 if (!matcher.isOperandMatcher()) { 255 PrintFatalError( 256 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 257 op.getOperationName(), index + 1)); 258 } 259 260 // Only need to verify if the matcher's type is different from the one 261 // of op definition. 262 if (operand->constraint != matcher.getAsConstraint()) { 263 auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); 264 os.indent(indent) << "if (!(" 265 << tgfmt(matcher.getConditionTemplate(), 266 &fmtCtx.withSelf(self)) 267 << ")) return matchFailure();\n"; 268 } 269 } 270 271 // Capture the value 272 auto name = tree.getArgName(index); 273 if (!name.empty()) { 274 os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth, 275 index); 276 } 277 } 278 279 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 280 int indent) { 281 Operator &op = tree.getDialectOp(opMap); 282 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 283 const auto &attr = namedAttr->attr; 284 285 os.indent(indent) << "{\n"; 286 indent += 2; 287 os.indent(indent) << formatv( 288 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 289 attr.getStorageType(), namedAttr->name); 290 291 // TODO(antiagainst): This should use getter method to avoid duplication. 292 if (attr.hasDefaultValueInitializer()) { 293 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 294 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 295 attr.getDefaultValueInitializer()) 296 << ";\n"; 297 } else if (attr.isOptional()) { 298 // For a missing attribute that is optional according to definition, we 299 // should just capature a mlir::Attribute() to signal the missing state. 300 // That is precisely what getAttr() returns on missing attributes. 301 } else { 302 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 303 } 304 305 auto matcher = tree.getArgAsLeaf(index); 306 if (!matcher.isUnspecified()) { 307 if (!matcher.isAttrMatcher()) { 308 PrintFatalError( 309 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 310 op.getOperationName(), index + 1)); 311 } 312 313 // If a constraint is specified, we need to generate C++ statements to 314 // check the constraint. 315 os.indent(indent) << "if (!(" 316 << tgfmt(matcher.getConditionTemplate(), 317 &fmtCtx.withSelf("tblgen_attr")) 318 << ")) return matchFailure();\n"; 319 } 320 321 // Capture the value 322 auto name = tree.getArgName(index); 323 if (!name.empty()) { 324 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 325 } 326 327 indent -= 2; 328 os.indent(indent) << "}\n"; 329 } 330 331 void PatternEmitter::emitMatchLogic(DagNode tree) { 332 emitOpMatch(tree, 0); 333 334 for (auto &appliedConstraint : pattern.getConstraints()) { 335 auto &constraint = appliedConstraint.constraint; 336 auto &entities = appliedConstraint.entities; 337 338 auto condition = constraint.getConditionTemplate(); 339 auto cmd = "if (!({0})) return matchFailure();\n"; 340 341 if (isa<TypeConstraint>(constraint)) { 342 auto self = formatv("({0}->getType())", resolveSymbol(entities.front())); 343 os.indent(4) << formatv(cmd, 344 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 345 } else if (isa<AttrConstraint>(constraint)) { 346 PrintFatalError( 347 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 348 } else { 349 // TODO(b/138794486): replace formatv arguments with the exact specified 350 // args. 351 if (entities.size() > 4) { 352 PrintFatalError(loc, "only support up to 4-entity constraints now"); 353 } 354 SmallVector<std::string, 4> names; 355 int i = 0; 356 for (int e = entities.size(); i < e; ++i) 357 names.push_back(resolveSymbol(entities[i])); 358 std::string self = appliedConstraint.self; 359 if (!self.empty()) 360 self = resolveSymbol(self); 361 for (; i < 4; ++i) 362 names.push_back("<unused>"); 363 os.indent(4) << formatv(cmd, 364 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 365 names[1], names[2], names[3])); 366 } 367 } 368 } 369 370 void PatternEmitter::collectOps(DagNode tree, 371 llvm::SmallPtrSetImpl<const Operator *> &ops) { 372 // Check if this tree is an operation. 373 if (tree.isOperation()) 374 ops.insert(&tree.getDialectOp(opMap)); 375 376 // Recurse the arguments of the tree. 377 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 378 if (auto child = tree.getArgAsNestedDag(i)) 379 collectOps(child, ops); 380 } 381 382 void PatternEmitter::emit(StringRef rewriteName) { 383 // Get the DAG tree for the source pattern. 384 DagNode sourceTree = pattern.getSourcePattern(); 385 386 const Operator &rootOp = pattern.getSourceRootOp(); 387 auto rootName = rootOp.getOperationName(); 388 389 // Collect the set of result operations. 390 llvm::SmallPtrSet<const Operator *, 4> resultOps; 391 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) 392 collectOps(pattern.getResultPattern(i), resultOps); 393 394 // Emit RewritePattern for Pattern. 395 auto locs = pattern.getLocation(); 396 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 397 make_range(locs.rbegin(), locs.rend())); 398 os << formatv(R"(struct {0} : public RewritePattern { 399 {0}(MLIRContext *context) 400 : RewritePattern("{1}", {{)", 401 rewriteName, rootName); 402 interleaveComma(resultOps, os, [&](const Operator *op) { 403 os << '"' << op->getOperationName() << '"'; 404 }); 405 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 406 407 // Emit matchAndRewrite() function. 408 os << R"( 409 PatternMatchResult matchAndRewrite(Operation *op0, 410 PatternRewriter &rewriter) const override { 411 )"; 412 413 // Register all symbols bound in the source pattern. 414 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 415 416 os.indent(4) << "// Variables for capturing values and attributes used for " 417 "creating ops\n"; 418 // Create local variables for storing the arguments and results bound 419 // to symbols. 420 for (const auto &symbolInfoPair : symbolInfoMap) { 421 StringRef symbol = symbolInfoPair.getKey(); 422 auto &info = symbolInfoPair.getValue(); 423 os.indent(4) << info.getVarDecl(symbol); 424 } 425 // TODO(jpienaar): capture ops with consistent numbering so that it can be 426 // reused for fused loc. 427 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 428 pattern.getSourcePattern().getNumOps()); 429 430 os.indent(4) << "// Match\n"; 431 os.indent(4) << "tblgen_ops[0] = op0;\n"; 432 emitMatchLogic(sourceTree); 433 os << "\n"; 434 435 os.indent(4) << "// Rewrite\n"; 436 emitRewriteLogic(); 437 438 os.indent(4) << "return matchSuccess();\n"; 439 os << " };\n"; 440 os << "};\n"; 441 } 442 443 void PatternEmitter::emitRewriteLogic() { 444 const Operator &rootOp = pattern.getSourceRootOp(); 445 int numExpectedResults = rootOp.getNumResults(); 446 int numResultPatterns = pattern.getNumResultPatterns(); 447 448 // First register all symbols bound to ops generated in result patterns. 449 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 450 451 // Only the last N static values generated are used to replace the matched 452 // root N-result op. We need to calculate the starting index (of the results 453 // of the matched op) each result pattern is to replace. 454 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 455 // If we don't need to replace any value at all, set the replacement starting 456 // index as the number of result patterns so we skip all of them when trying 457 // to replace the matched op's results. 458 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 459 for (int i = numResultPatterns - 1; i >= 0; --i) { 460 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 461 offsets[i] = offsets[i + 1] - numValues; 462 if (offsets[i] == 0) { 463 if (replStartIndex == -1) 464 replStartIndex = i; 465 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 466 auto error = formatv( 467 "cannot use the same multi-result op '{0}' to generate both " 468 "auxiliary values and values to be used for replacing the matched op", 469 pattern.getResultPattern(i).getSymbol()); 470 PrintFatalError(loc, error); 471 } 472 } 473 474 if (offsets.front() > 0) { 475 const char error[] = "no enough values generated to replace the matched op"; 476 PrintFatalError(loc, error); 477 } 478 479 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 480 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 481 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 482 } 483 os << "}); (void)loc;\n"; 484 485 // Collect the replacement value for each result 486 llvm::SmallVector<std::string, 2> resultValues; 487 for (int i = 0; i < numResultPatterns; ++i) { 488 DagNode resultTree = pattern.getResultPattern(i); 489 resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0)); 490 } 491 492 // Emit the final replaceOp() statement 493 os.indent(4) << "rewriter.replaceOp(op0, {"; 494 interleaveComma( 495 ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os, 496 [&](const std::string &symbol) { os << resolveSymbol(symbol); }); 497 os << "});\n"; 498 } 499 500 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 501 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 502 } 503 504 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 505 int resultIndex, int depth) { 506 if (resultTree.isNativeCodeCall()) { 507 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 508 symbolInfoMap.bindValue(symbol); 509 return symbol; 510 } 511 512 if (resultTree.isReplaceWithValue()) { 513 return handleReplaceWithValue(resultTree); 514 } 515 516 // Normal op creation. 517 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 518 if (resultTree.getSymbol().empty()) { 519 // This is an op not explicitly bound to a symbol in the rewrite rule. 520 // Register the auto-generated symbol for it. 521 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 522 } 523 return symbol; 524 } 525 526 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 527 assert(tree.isReplaceWithValue()); 528 529 if (tree.getNumArgs() != 1) { 530 PrintFatalError( 531 loc, "replaceWithValue directive must take exactly one argument"); 532 } 533 534 if (!tree.getSymbol().empty()) { 535 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 536 } 537 538 return resolveSymbol(tree.getArgName(0)); 539 } 540 541 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { 542 if (leaf.isConstantAttr()) { 543 auto constAttr = leaf.getAsConstantAttr(); 544 return handleConstantAttr(constAttr.getAttribute(), 545 constAttr.getConstantValue()); 546 } 547 if (leaf.isEnumAttrCase()) { 548 auto enumCase = leaf.getAsEnumAttrCase(); 549 if (enumCase.isStrCase()) 550 return handleConstantAttr(enumCase, enumCase.getSymbol()); 551 // This is an enum case backed by an IntegerAttr. We need to get its value 552 // to build the constant. 553 std::string val = std::to_string(enumCase.getValue()); 554 return handleConstantAttr(enumCase, val); 555 } 556 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 557 return argName; 558 } 559 if (leaf.isNativeCodeCall()) { 560 return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 561 } 562 PrintFatalError(loc, "unhandled case when rewriting op"); 563 } 564 565 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 566 auto fmt = tree.getNativeCodeTemplate(); 567 // TODO(b/138794486): replace formatv arguments with the exact specified args. 568 SmallVector<std::string, 8> attrs(8); 569 if (tree.getNumArgs() > 8) { 570 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 571 Twine(tree.getNumArgs())); 572 } 573 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 574 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 575 } 576 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 577 attrs[5], attrs[6], attrs[7]); 578 } 579 580 std::string PatternEmitter::resolveSymbol(StringRef symbol) { 581 auto subst = symbolInfoMap.getValueAndRangeUse(symbol); 582 if (subst.empty()) { 583 PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); 584 } 585 return subst; 586 } 587 588 int PatternEmitter::getNodeValueCount(DagNode node) { 589 if (node.isOperation()) { 590 // If the op is bound to a symbol in the rewrite rule, query its result 591 // count from the symbol info map. 592 auto symbol = node.getSymbol(); 593 if (!symbol.empty()) { 594 return symbolInfoMap.getStaticValueCount(symbol); 595 } 596 // Otherwise this is an unbound op; we will use all its results. 597 return pattern.getDialectOp(node).getNumResults(); 598 } 599 // TODO(antiagainst): This considers all NativeCodeCall as returning one 600 // value. Enhance if multi-value ones are needed. 601 return 1; 602 } 603 604 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 605 int depth) { 606 Operator &resultOp = tree.getDialectOp(opMap); 607 auto numOpArgs = resultOp.getNumArgs(); 608 609 if (resultOp.isVariadic()) { 610 PrintFatalError(loc, formatv("generating op '{0}' with variadic " 611 "operands/results is unsupported now", 612 resultOp.getOperationName())); 613 } 614 615 if (numOpArgs != tree.getNumArgs()) { 616 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 617 "{1} in pattern vs. {2} in definition", 618 resultOp.getOperationName(), tree.getNumArgs(), 619 numOpArgs)); 620 } 621 622 // A map to collect all nested DAG child nodes' names, with operand index as 623 // the key. This includes both bound and unbound child nodes. Bound child 624 // nodes will additionally be tracked in `symbolResolver` so they can be 625 // referenced by other patterns. Unbound child nodes will only be used once 626 // to build this op. 627 llvm::DenseMap<unsigned, std::string> childNodeNames; 628 629 // First go through all the child nodes who are nested DAG constructs to 630 // create ops for them, so that we can use the results in the current node. 631 // This happens in a recursive manner. 632 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 633 if (auto child = tree.getArgAsNestedDag(i)) { 634 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 635 } 636 } 637 638 // Use the specified name for this op if available. Generate one otherwise. 639 std::string resultValue = tree.getSymbol(); 640 if (resultValue.empty()) 641 resultValue = getUniqueSymbol(&resultOp); 642 // Strip the index to get the name for the value pack. This will be used to 643 // name the local variable for the op. 644 StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue); 645 646 // Then we build the new op corresponding to this DAG node. 647 648 // Right now we don't have general type inference in MLIR. Except a few 649 // special cases listed below, we need to supply types for all results 650 // when building an op. 651 bool isSameOperandsAndResultType = 652 resultOp.hasTrait("SameOperandsAndResultType"); 653 bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); 654 bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); 655 bool usePartialResults = valuePackName != resultValue; 656 657 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 658 usePartialResults || depth > 0 || resultIndex < 0) { 659 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 660 valuePackName, resultOp.getQualCppClassName()); 661 } else { 662 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 663 // generated from the source pattern root op. Then we can use the source 664 // pattern's value types to determine the value type of the generated op 665 // here. 666 667 // We need to specify the types for all results. 668 SmallVector<std::string, 4> resultTypes; 669 int numResults = resultOp.getNumResults(); 670 resultTypes.reserve(numResults); 671 for (int i = 0; i < numResults; ++i) { 672 resultTypes.push_back( 673 formatv("op0->getResult({0})->getType()", resultIndex + i)); 674 } 675 676 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 677 valuePackName, resultOp.getQualCppClassName()) 678 << (resultTypes.empty() ? "" : ", ") 679 << llvm::join(resultTypes, ", "); 680 } 681 682 // Create the builder call for the result. 683 // Add operands. 684 int argIndex = 0; 685 for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { 686 const auto &operand = resultOp.getOperand(argIndex); 687 688 // Start each operand on its own line. 689 (os << ",\n").indent(6); 690 691 if (!operand.name.empty()) 692 os << "/*" << operand.name << "=*/"; 693 694 if (tree.isNestedDagArg(argIndex)) { 695 os << childNodeNames[argIndex]; 696 } else { 697 DagLeaf leaf = tree.getArgAsLeaf(argIndex); 698 auto symbol = resolveSymbol(tree.getArgName(argIndex)); 699 if (leaf.isNativeCodeCall()) { 700 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 701 } else { 702 os << symbol; 703 } 704 } 705 // TODO(jpienaar): verify types 706 } 707 708 // Add attributes. 709 for (; argIndex != numOpArgs; ++argIndex) { 710 // Start each attribute on its own line. 711 (os << ",\n").indent(6); 712 // The argument in the op definition. 713 auto opArgName = resultOp.getArgName(argIndex); 714 if (auto subTree = tree.getArgAsNestedDag(argIndex)) { 715 if (!subTree.isNativeCodeCall()) 716 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 717 "for creating attribute"); 718 os << formatv("/*{0}=*/{1}", opArgName, 719 handleReplaceWithNativeCodeCall(subTree)); 720 } else { 721 auto leaf = tree.getArgAsLeaf(argIndex); 722 // The argument in the result DAG pattern. 723 auto patArgName = tree.getArgName(argIndex); 724 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 725 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 726 auto argument = resultOp.getArg(argIndex); 727 if (!argument.is<NamedAttribute *>()) 728 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 729 if (!patArgName.empty()) 730 os << "/*" << patArgName << "=*/"; 731 } else { 732 os << "/*" << opArgName << "=*/"; 733 } 734 os << handleOpArgument(leaf, patArgName); 735 } 736 } 737 os << "\n );\n"; 738 739 return resultValue; 740 } 741 742 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 743 emitSourceFileHeader("Rewriters", os); 744 745 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 746 auto numPatterns = patterns.size(); 747 748 // We put the map here because it can be shared among multiple patterns. 749 RecordOperatorMap recordOpMap; 750 751 std::vector<std::string> rewriterNames; 752 rewriterNames.reserve(numPatterns); 753 754 std::string baseRewriterName = "GeneratedConvert"; 755 int rewriterIndex = 0; 756 757 for (Record *p : patterns) { 758 std::string name; 759 if (p->isAnonymous()) { 760 // If no name is provided, ensure unique rewriter names simply by 761 // appending unique suffix. 762 name = baseRewriterName + llvm::utostr(rewriterIndex++); 763 } else { 764 name = p->getName(); 765 } 766 PatternEmitter(p, &recordOpMap, os).emit(name); 767 rewriterNames.push_back(std::move(name)); 768 } 769 770 // Emit function to add the generated matchers to the pattern list. 771 os << "void populateWithGenerated(MLIRContext *context, " 772 << "OwningRewritePatternList *patterns) {\n"; 773 for (const auto &name : rewriterNames) { 774 os << " patterns->insert<" << name << ">(context);\n"; 775 } 776 os << "}\n"; 777 } 778 779 static mlir::GenRegistration 780 genRewriters("gen-rewriters", "Generate pattern rewriters", 781 [](const RecordKeeper &records, raw_ostream &os) { 782 emitRewriters(records, os); 783 return false; 784 }); 785