1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FormatAdapters.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/Signals.h" 36 #include "llvm/TableGen/Error.h" 37 #include "llvm/TableGen/Main.h" 38 #include "llvm/TableGen/Record.h" 39 #include "llvm/TableGen/TableGenBackend.h" 40 41 using namespace llvm; 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 namespace llvm { 46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 47 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 48 raw_ostream &os, StringRef style) { 49 os << v.first << ":" << v.second; 50 } 51 }; 52 } // end namespace llvm 53 54 // Returns the bound symbol for the given op argument or op named `symbol`. 55 // 56 // Arguments and ops bound in the source pattern are grouped into a 57 // transient `PatternState` struct. This struct can be accessed in both 58 // `match()` and `rewrite()` via the local variable named as `s`. 59 static Twine getBoundSymbol(const StringRef &symbol) { 60 return Twine("s.") + symbol; 61 } 62 63 // Gets the dynamic value pack's name by removing the index suffix from 64 // `symbol`. Returns `symbol` itself if it does not contain an index. 65 // 66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic 67 // value pack bound to `name`. `name` is typically the results of an 68 // multi-result op. 69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) { 70 StringRef name, indexStr; 71 unsigned idx = 0; 72 std::tie(name, indexStr) = symbol.rsplit("__"); 73 if (indexStr.consumeInteger(10, idx)) { 74 // The second part is not an index. 75 return symbol; 76 } 77 if (index) 78 *index = idx; 79 return name; 80 } 81 82 // Formats all values from a dynamic value pack `symbol` according to the given 83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol` 84 // and `{1}` as a placeholder for the value index, which will be offsetted by 85 // `offset`. The `symbol` value pack has a total of `count` values. 86 // 87 // This extracts one value from the pack if `symbol` contains an index, 88 // otherwise it extracts all values sequentially and returns them as a 89 // comma-separated list. 90 static std::string formatValuePack(const char *fmt, StringRef symbol, 91 unsigned count, unsigned offset) { 92 auto getNthValue = [fmt, offset](StringRef results, 93 unsigned index) -> std::string { 94 return formatv(fmt, results, index + offset); 95 }; 96 97 unsigned index = 0; 98 StringRef name = getValuePackName(symbol, &index); 99 if (name != symbol) { 100 // The symbol contains an index. 101 return getNthValue(name, index); 102 } 103 104 // The symbol does not contain an index. Treat the symbol as a whole. 105 SmallVector<std::string, 4> values; 106 values.reserve(count); 107 for (unsigned i = 0; i < count; ++i) 108 values.emplace_back(getNthValue(symbol, i)); 109 return llvm::join(values, ", "); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // PatternSymbolResolver 114 //===----------------------------------------------------------------------===// 115 116 namespace { 117 // A class for resolving symbols bound in patterns. 118 // 119 // Symbols can be bound to op arguments and ops in the source pattern and ops 120 // in result patterns. For example, in 121 // 122 // ``` 123 // def : Pattern<(SrcOp:$op1 $arg0, %arg1), 124 // [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; 125 // ``` 126 // 127 // `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. 128 // `$op2` is bound to `ResOp1`. 129 // 130 // If a symbol binds to a multi-result op and it does not have the `__N` 131 // suffix, the symbol is expanded to the whole value pack generated by the 132 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 133 // only the N-th result. 134 // 135 // This class keeps track of such symbols and translates them into their bound 136 // values. 137 // 138 // Note that we also generate local variables for unnamed DAG nodes, like 139 // `(ResOp3)` in the above. Since we don't bind a symbol to the op, the 140 // generated local variable will be implicitly named. Those implicit names are 141 // not tracked in this class. 142 class PatternSymbolResolver { 143 public: 144 PatternSymbolResolver(const StringMap<Argument> &srcArgs, 145 const StringMap<const Operator *> &srcOperations); 146 147 // Marks the given `symbol` as bound to a value pack with `numValues` and 148 // returns true on success. Returns false if the `symbol` is already bound. 149 bool add(StringRef symbol, int numValues); 150 151 // Queries the substitution for the given `symbol`. Returns empty string if 152 // symbol not found. If the symbol represents a value pack, returns all the 153 // values separated via comma. 154 std::string query(StringRef symbol) const; 155 156 // Returns how many static values the given `symbol` correspond to. Returns a 157 // negative value if the given symbol is not bound. 158 // 159 // Normally a symbol would correspond to just one value; for symbols bound to 160 // multi-result ops, it can be more than one. 161 int getValueCount(StringRef symbol) const; 162 163 private: 164 // Symbols bound to arguments in source pattern. 165 const StringMap<Argument> &sourceArguments; 166 // Symbols bound to ops (for their results) in source pattern. 167 const StringMap<const Operator *> &sourceOps; 168 // Symbols bound to ops (for their results) in result patterns. 169 // Key: symbol; value: number of values inside the pack 170 StringMap<int> resultOps; 171 }; 172 } // end anonymous namespace 173 174 PatternSymbolResolver::PatternSymbolResolver( 175 const StringMap<Argument> &srcArgs, 176 const StringMap<const Operator *> &srcOperations) 177 : sourceArguments(srcArgs), sourceOps(srcOperations) {} 178 179 bool PatternSymbolResolver::add(StringRef symbol, int numValues) { 180 StringRef name = getValuePackName(symbol); 181 return resultOps.try_emplace(name, numValues).second; 182 } 183 184 std::string PatternSymbolResolver::query(StringRef symbol) const { 185 StringRef name = getValuePackName(symbol); 186 // Handle symbols bound to generated ops 187 auto resOpIt = resultOps.find(name); 188 if (resOpIt != resultOps.end()) 189 return formatValuePack("{0}.getOperation()->getResult({1})", symbol, 190 resOpIt->second, /*offset=*/0); 191 192 // Handle symbols bound to matched op arguments 193 auto srcArgIt = sourceArguments.find(symbol); 194 if (srcArgIt != sourceArguments.end()) 195 return getBoundSymbol(symbol).str(); 196 197 // Handle symbols bound to matched op results 198 auto srcOpIt = sourceOps.find(name); 199 if (srcOpIt != sourceOps.end()) 200 return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(), 201 srcOpIt->second->getNumResults(), /*offset=*/0); 202 return {}; 203 } 204 205 int PatternSymbolResolver::getValueCount(StringRef symbol) const { 206 StringRef name = getValuePackName(symbol); 207 // Handle symbols bound to generated ops 208 auto resOpIt = resultOps.find(name); 209 if (resOpIt != resultOps.end()) 210 return name == symbol ? resOpIt->second : 1; 211 212 // Handle symbols bound to matched op arguments 213 if (sourceArguments.count(symbol)) 214 return 1; 215 216 // Handle symbols bound to matched op results 217 auto srcOpIt = sourceOps.find(name); 218 if (srcOpIt != sourceOps.end()) 219 return name == symbol ? srcOpIt->second->getNumResults() : 1; 220 return -1; 221 } 222 223 //===----------------------------------------------------------------------===// 224 // PatternEmitter 225 //===----------------------------------------------------------------------===// 226 227 namespace { 228 class PatternEmitter { 229 public: 230 static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, 231 raw_ostream &os); 232 233 private: 234 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 235 236 // Emits the mlir::RewritePattern struct named `rewriteName`. 237 void emit(StringRef rewriteName); 238 239 // Emits the match() method. 240 void emitMatchMethod(DagNode tree); 241 242 // Collects all of the operations within the given dag tree. 243 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 244 245 // Emits the rewrite() method. 246 void emitRewriteMethod(); 247 248 // Emits C++ statements for matching the op constrained by the given DAG 249 // `tree`. 250 void emitOpMatch(DagNode tree, int depth); 251 252 // Emits C++ statements for matching the `index`-th argument of the given DAG 253 // `tree` as an operand. 254 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 255 256 // Emits C++ statements for matching the `index`-th argument of the given DAG 257 // `tree` as an attribute. 258 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 259 260 // Returns a unique name for an value of the given `op`. 261 std::string getUniqueValueName(const Operator *op); 262 263 // Entry point for handling a rewrite pattern rooted at `resultTree` and 264 // dispatches to concrete handlers. The given tree is the `resultIndex`-th 265 // argument of the enclosing DAG. 266 std::string handleRewritePattern(DagNode resultTree, int resultIndex, 267 int depth); 268 269 // Emits the C++ statement to replace the matched DAG with a value built via 270 // calling native C++ code. 271 std::string emitReplaceWithNativeCodeCall(DagNode resultTree); 272 273 // Returns the C++ expression referencing the old value serving as the 274 // replacement. 275 std::string handleReplaceWithValue(DagNode tree); 276 277 // Handles the `verifyUnusedValue` directive: emitting C++ statements to check 278 // the `index`-th result of the source op is not used. 279 void handleVerifyUnusedValue(DagNode tree, int index); 280 281 // Emits the C++ statement to build a new op out of the given DAG `tree` and 282 // returns the variable name that this op is assigned to. If the root op in 283 // DAG `tree` has a specified name, the created op will be assigned to a 284 // variable of the given name. Otherwise, a unique name will be used as the 285 // result value name. 286 std::string emitOpCreate(DagNode tree, int resultIndex, int depth); 287 288 // Returns the C++ expression to construct a constant attribute of the given 289 // `value` for the given attribute kind `attr`. 290 std::string handleConstantAttr(Attribute attr, StringRef value); 291 292 // Returns the C++ expression to build an argument from the given DAG `leaf`. 293 // `patArgName` is used to bound the argument to the source pattern. 294 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 295 296 // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol 297 // is already bound. 298 void addSymbol(StringRef symbol, int numValues); 299 300 // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. 301 std::string resolveSymbol(StringRef symbol); 302 303 // Returns how many static values the given DAG `node` correspond to. 304 int getNodeValueCount(DagNode node); 305 306 private: 307 // Pattern instantiation location followed by the location of multiclass 308 // prototypes used. This is intended to be used as a whole to 309 // PrintFatalError() on errors. 310 ArrayRef<llvm::SMLoc> loc; 311 // Op's TableGen Record to wrapper object 312 RecordOperatorMap *opMap; 313 // Handy wrapper for pattern being emitted 314 Pattern pattern; 315 PatternSymbolResolver symbolResolver; 316 // The next unused ID for newly created values 317 unsigned nextValueId; 318 raw_ostream &os; 319 320 // Format contexts containing placeholder substitutations for match(). 321 FmtContext matchCtx; 322 // Format contexts containing placeholder substitutations for rewrite(). 323 FmtContext rewriteCtx; 324 325 // Number of op processed. 326 int opCounter = 0; 327 }; 328 } // end anonymous namespace 329 330 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 331 raw_ostream &os) 332 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 333 symbolResolver(pattern.getSourcePatternBoundArgs(), 334 pattern.getSourcePatternBoundOps()), 335 nextValueId(0), os(os) { 336 matchCtx.withBuilder("mlir::Builder(ctx)"); 337 rewriteCtx.withBuilder("rewriter"); 338 } 339 340 std::string PatternEmitter::handleConstantAttr(Attribute attr, 341 StringRef value) { 342 if (!attr.isConstBuildable()) 343 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 344 " does not have the 'constBuilderCall' field"); 345 346 // TODO(jpienaar): Verify the constants here 347 return tgfmt(attr.getConstBuilderTemplate(), 348 &rewriteCtx.withBuilder("rewriter"), value); 349 } 350 351 // Helper function to match patterns. 352 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 353 Operator &op = tree.getDialectOp(opMap); 354 if (op.isVariadic()) { 355 PrintFatalError(loc, formatv("matching op '{0}' with variadic " 356 "operands/results is unsupported right now", 357 op.getOperationName())); 358 } 359 360 int indent = 4 + 2 * depth; 361 // Skip the operand matching at depth 0 as the pattern rewriter already does. 362 if (depth != 0) { 363 // Skip if there is no defining operation (e.g., arguments to function). 364 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 365 os.indent(indent) << formatv( 366 "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, 367 op.getQualCppClassName()); 368 } 369 if (tree.getNumArgs() != op.getNumArgs()) { 370 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 371 "pattern vs. {2} in definition", 372 op.getOperationName(), tree.getNumArgs(), 373 op.getNumArgs())); 374 } 375 376 // If the operand's name is set, set to that variable. 377 auto name = tree.getSymbol(); 378 if (!name.empty()) 379 os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth); 380 381 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 382 auto opArg = op.getArg(i); 383 384 // Handle nested DAG construct first 385 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 386 os.indent(indent) << "{\n"; 387 os.indent(indent + 2) 388 << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", 389 depth + 1, depth, i); 390 emitOpMatch(argTree, depth + 1); 391 os.indent(indent + 2) 392 << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n", 393 ++opCounter, depth + 1); 394 os.indent(indent) << "}\n"; 395 continue; 396 } 397 398 // Next handle DAG leaf: operand or attribute 399 if (opArg.is<NamedTypeConstraint *>()) { 400 emitOperandMatch(tree, i, depth, indent); 401 } else if (opArg.is<NamedAttribute *>()) { 402 emitAttributeMatch(tree, i, depth, indent); 403 } else { 404 PrintFatalError(loc, "unhandled case when matching op"); 405 } 406 } 407 } 408 409 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 410 int indent) { 411 Operator &op = tree.getDialectOp(opMap); 412 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 413 auto matcher = tree.getArgAsLeaf(index); 414 415 // If a constraint is specified, we need to generate C++ statements to 416 // check the constraint. 417 if (!matcher.isUnspecified()) { 418 if (!matcher.isOperandMatcher()) { 419 PrintFatalError( 420 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 421 op.getOperationName(), index + 1)); 422 } 423 424 // Only need to verify if the matcher's type is different from the one 425 // of op definition. 426 if (operand->constraint != matcher.getAsConstraint()) { 427 auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); 428 os.indent(indent) << "if (!(" 429 << tgfmt(matcher.getConditionTemplate(), 430 &matchCtx.withSelf(self)) 431 << ")) return matchFailure();\n"; 432 } 433 } 434 435 // Capture the value 436 auto name = tree.getArgName(index); 437 if (!name.empty()) { 438 os.indent(indent) << getBoundSymbol(name) << " = op" << depth 439 << "->getOperand(" << index << ");\n"; 440 } 441 } 442 443 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 444 int indent) { 445 Operator &op = tree.getDialectOp(opMap); 446 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 447 const auto &attr = namedAttr->attr; 448 449 os.indent(indent) << "{\n"; 450 indent += 2; 451 os.indent(indent) << formatv( 452 "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 453 attr.getStorageType(), namedAttr->name); 454 455 // TODO(antiagainst): This should use getter method to avoid duplication. 456 if (attr.hasDefaultValueInitializer()) { 457 os.indent(indent) << "if (!attr) attr = " 458 << tgfmt(attr.getConstBuilderTemplate(), &matchCtx, 459 attr.getDefaultValueInitializer()) 460 << ";\n"; 461 } else if (attr.isOptional()) { 462 // For a missing attribut that is optional according to definition, we 463 // should just capature a mlir::Attribute() to signal the missing state. 464 // That is precisely what getAttr() returns on missing attributes. 465 } else { 466 os.indent(indent) << "if (!attr) return matchFailure();\n"; 467 } 468 469 auto matcher = tree.getArgAsLeaf(index); 470 if (!matcher.isUnspecified()) { 471 if (!matcher.isAttrMatcher()) { 472 PrintFatalError( 473 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 474 op.getOperationName(), index + 1)); 475 } 476 477 // If a constraint is specified, we need to generate C++ statements to 478 // check the constraint. 479 os.indent(indent) << "if (!(" 480 << tgfmt(matcher.getConditionTemplate(), 481 &matchCtx.withSelf("attr")) 482 << ")) return matchFailure();\n"; 483 } 484 485 // Capture the value 486 auto name = tree.getArgName(index); 487 if (!name.empty()) { 488 os.indent(indent) << getBoundSymbol(name) << " = attr;\n"; 489 } 490 491 indent -= 2; 492 os.indent(indent) << "}\n"; 493 } 494 495 void PatternEmitter::emitMatchMethod(DagNode tree) { 496 // Emit the heading. 497 os << R"( 498 PatternMatchResult match(Operation *op0) const override { 499 auto ctx = op0->getContext(); (void)ctx; 500 auto state = llvm::make_unique<MatchedState>(); 501 auto &s = *state; 502 s.autogeneratedRewritePatternOps[0] = op0; 503 )"; 504 505 // The rewrite pattern may specify that certain outputs should be unused in 506 // the source IR. Check it here. 507 for (int i = 0, e = pattern.getNumResultPatterns(); i < e; ++i) { 508 DagNode resultTree = pattern.getResultPattern(i); 509 if (resultTree.isVerifyUnusedValue()) { 510 handleVerifyUnusedValue(resultTree, i); 511 } 512 } 513 514 emitOpMatch(tree, 0); 515 516 for (auto &appliedConstraint : pattern.getConstraints()) { 517 auto &constraint = appliedConstraint.constraint; 518 auto &entities = appliedConstraint.entities; 519 520 auto condition = constraint.getConditionTemplate(); 521 auto cmd = "if (!({0})) return matchFailure();\n"; 522 523 if (isa<TypeConstraint>(constraint)) { 524 auto self = formatv("({0}->getType())", resolveSymbol(entities.front())); 525 os.indent(4) << formatv(cmd, 526 tgfmt(condition, &matchCtx.withSelf(self.str()))); 527 } else if (isa<AttrConstraint>(constraint)) { 528 PrintFatalError( 529 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 530 } else { 531 // TODO(fengliuai): replace formatv arguments with the exact specified 532 // args. 533 if (entities.size() > 4) { 534 PrintFatalError(loc, "only support up to 4-entity constraints now"); 535 } 536 SmallVector<std::string, 4> names; 537 int i = 0; 538 for (int e = entities.size(); i < e; ++i) 539 names.push_back(resolveSymbol(entities[i])); 540 for (; i < 4; ++i) 541 names.push_back("<unused>"); 542 os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0], 543 names[1], names[2], names[3])); 544 } 545 } 546 547 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 548 } 549 550 void PatternEmitter::collectOps(DagNode tree, 551 llvm::SmallPtrSetImpl<const Operator *> &ops) { 552 // Check if this tree is an operation. 553 if (tree.isOperation()) 554 ops.insert(&tree.getDialectOp(opMap)); 555 556 // Recurse the arguments of the tree. 557 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 558 if (auto child = tree.getArgAsNestedDag(i)) 559 collectOps(child, ops); 560 } 561 562 void PatternEmitter::emit(StringRef rewriteName) { 563 // Get the DAG tree for the source pattern 564 DagNode tree = pattern.getSourcePattern(); 565 566 const Operator &rootOp = pattern.getSourceRootOp(); 567 auto rootName = rootOp.getOperationName(); 568 569 // Collect the set of result operations. 570 llvm::SmallPtrSet<const Operator *, 4> results; 571 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) 572 collectOps(pattern.getResultPattern(i), results); 573 574 // Emit RewritePattern for Pattern. 575 auto locs = pattern.getLocation(); 576 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 577 make_range(locs.rbegin(), locs.rend())); 578 os << formatv(R"(struct {0} : public RewritePattern { 579 {0}(MLIRContext *context) : RewritePattern("{1}", {{)", 580 rewriteName, rootName); 581 interleaveComma(results, os, [&](const Operator *op) { 582 os << '"' << op->getOperationName() << '"'; 583 }); 584 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 585 586 // Emit matched state. 587 os << " struct MatchedState : public PatternState {\n"; 588 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 589 auto fieldName = arg.first(); 590 if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { 591 os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName 592 << ";\n"; 593 } else { 594 os.indent(4) << "Value *" << fieldName << ";\n"; 595 } 596 } 597 for (const auto &result : pattern.getSourcePatternBoundOps()) { 598 os.indent(4) << "Operation *" << result.getKey() << ";\n"; 599 } 600 // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent 601 // numbering so that it can be reused for fused loc. 602 os.indent(4) << "Operation* autogeneratedRewritePatternOps[" 603 << pattern.getSourcePattern().getNumOps() << "];\n"; 604 os << " };\n"; 605 606 emitMatchMethod(tree); 607 emitRewriteMethod(); 608 609 os << "};\n"; 610 } 611 612 void PatternEmitter::emitRewriteMethod() { 613 const Operator &rootOp = pattern.getSourceRootOp(); 614 int numExpectedResults = rootOp.getNumResults(); 615 int numResultPatterns = pattern.getNumResultPatterns(); 616 617 // First register all symbols bound to ops generated in result patterns. 618 for (const auto &boundOp : pattern.getResultPatternBoundOps()) { 619 addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults()); 620 } 621 622 // Only the last N static values generated are used to replace the matched 623 // root N-result op. We need to calculate the starting index (of the results 624 // of the matched op) each result pattern is to replace. 625 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 626 int replStartIndex = -1; 627 for (int i = numResultPatterns - 1; i >= 0; --i) { 628 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 629 offsets[i] = offsets[i + 1] - numValues; 630 if (offsets[i] == 0) { 631 replStartIndex = i; 632 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 633 auto error = formatv( 634 "cannot use the same multi-result op '{0}' to generate both " 635 "auxiliary values and values to be used for replacing the matched op", 636 pattern.getResultPattern(i).getSymbol()); 637 PrintFatalError(loc, error); 638 } 639 } 640 641 if (offsets.front() > 0) { 642 const char error[] = "no enough values generated to replace the matched op"; 643 PrintFatalError(loc, error); 644 } 645 646 os << R"( 647 void rewrite(Operation *op, std::unique_ptr<PatternState> state, 648 PatternRewriter &rewriter) const override { 649 auto& s = *static_cast<MatchedState *>(state.get()); 650 auto loc = rewriter.getFusedLoc({)"; 651 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 652 os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i 653 << "]->getLoc()"; 654 } 655 os << "}); (void)loc;\n"; 656 657 // Collect the replacement value for each result 658 llvm::SmallVector<std::string, 2> resultValues; 659 for (int i = 0; i < numResultPatterns; ++i) { 660 DagNode resultTree = pattern.getResultPattern(i); 661 resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0)); 662 } 663 664 // Emit the final replaceOp() statement 665 os.indent(4) << "rewriter.replaceOp(op, {"; 666 interleave( 667 ArrayRef<std::string>(resultValues).drop_front(replStartIndex), 668 [&](const std::string &name) { os << name; }, [&]() { os << ", "; }); 669 os << "});\n }\n"; 670 } 671 672 std::string PatternEmitter::getUniqueValueName(const Operator *op) { 673 return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); 674 } 675 676 std::string PatternEmitter::handleRewritePattern(DagNode resultTree, 677 int resultIndex, int depth) { 678 if (resultTree.isNativeCodeCall()) 679 return emitReplaceWithNativeCodeCall(resultTree); 680 681 if (resultTree.isVerifyUnusedValue()) { 682 if (depth > 0) { 683 // TODO: Revisit this when we have use cases of matching an intermediate 684 // multi-result op with no uses of its certain results. 685 PrintFatalError(loc, "verifyUnusedValue directive can only be used to " 686 "verify top-level result"); 687 } 688 689 if (!resultTree.getSymbol().empty()) { 690 PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); 691 } 692 693 // The C++ statements to check that this result value is unused are already 694 // emitted in the match() method. So returning a nullptr here directly 695 // should be safe because the C++ RewritePattern harness will use it to 696 // replace nothing. 697 return "nullptr"; 698 } 699 700 if (resultTree.isReplaceWithValue()) 701 return handleReplaceWithValue(resultTree); 702 703 // Create the op and get the local variable for it. 704 auto results = emitOpCreate(resultTree, resultIndex, depth); 705 // We need to get all the values out of this local variable if we've created a 706 // multi-result op. 707 const auto &numResults = pattern.getDialectOp(resultTree).getNumResults(); 708 return formatValuePack("{0}.getOperation()->getResult({1})", results, 709 numResults, /*offset=*/0); 710 } 711 712 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 713 assert(tree.isReplaceWithValue()); 714 715 if (tree.getNumArgs() != 1) { 716 PrintFatalError( 717 loc, "replaceWithValue directive must take exactly one argument"); 718 } 719 720 if (!tree.getSymbol().empty()) { 721 PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); 722 } 723 724 return resolveSymbol(tree.getArgName(0)); 725 } 726 727 void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) { 728 assert(tree.isVerifyUnusedValue()); 729 730 os.indent(4) << "if (!op0->getResult(" << index 731 << ")->use_empty()) return matchFailure();\n"; 732 } 733 734 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { 735 if (leaf.isConstantAttr()) { 736 auto constAttr = leaf.getAsConstantAttr(); 737 return handleConstantAttr(constAttr.getAttribute(), 738 constAttr.getConstantValue()); 739 } 740 if (leaf.isEnumAttrCase()) { 741 auto enumCase = leaf.getAsEnumAttrCase(); 742 if (enumCase.isStrCase()) 743 return handleConstantAttr(enumCase, enumCase.getSymbol()); 744 // This is an enum case backed by an IntegerAttr. We need to get its value 745 // to build the constant. 746 std::string val = std::to_string(enumCase.getValue()); 747 return handleConstantAttr(enumCase, val); 748 } 749 pattern.ensureBoundInSourcePattern(argName); 750 std::string result = getBoundSymbol(argName).str(); 751 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 752 return result; 753 } 754 if (leaf.isNativeCodeCall()) { 755 return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result)); 756 } 757 PrintFatalError(loc, "unhandled case when rewriting op"); 758 } 759 760 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { 761 auto fmt = tree.getNativeCodeTemplate(); 762 // TODO(fengliuai): replace formatv arguments with the exact specified args. 763 SmallVector<std::string, 8> attrs(8); 764 if (tree.getNumArgs() > 8) { 765 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 766 Twine(tree.getNumArgs())); 767 } 768 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 769 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 770 } 771 return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3], 772 attrs[4], attrs[5], attrs[6], attrs[7]); 773 } 774 775 void PatternEmitter::addSymbol(StringRef symbol, int numValues) { 776 if (!symbolResolver.add(symbol, numValues)) 777 PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); 778 } 779 780 std::string PatternEmitter::resolveSymbol(StringRef symbol) { 781 auto subst = symbolResolver.query(symbol); 782 if (subst.empty()) 783 PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); 784 return subst; 785 } 786 787 int PatternEmitter::getNodeValueCount(DagNode node) { 788 if (node.isOperation()) { 789 // First to see whether this op is bound and we just want a specific result 790 // of it with `__N` suffix in symbol. 791 int count = symbolResolver.getValueCount(node.getSymbol()); 792 if (count >= 0) 793 return count; 794 795 // No symbol. Then we are using all the results. 796 return pattern.getDialectOp(node).getNumResults(); 797 } 798 // TODO(antiagainst): This considers all NativeCodeCall as returning one 799 // value. Enhance if multi-value ones are needed. 800 return 1; 801 } 802 803 std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, 804 int depth) { 805 Operator &resultOp = tree.getDialectOp(opMap); 806 auto numOpArgs = resultOp.getNumArgs(); 807 808 if (resultOp.isVariadic()) { 809 PrintFatalError(loc, formatv("generating op '{0}' with variadic " 810 "operands/results is unsupported now", 811 resultOp.getOperationName())); 812 } 813 814 if (numOpArgs != tree.getNumArgs()) { 815 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 816 "{1} in pattern vs. {2} in definition", 817 resultOp.getOperationName(), tree.getNumArgs(), 818 numOpArgs)); 819 } 820 821 // A map to collect all nested DAG child nodes' names, with operand index as 822 // the key. This includes both bound and unbound child nodes. Bound child 823 // nodes will additionally be tracked in `symbolResolver` so they can be 824 // referenced by other patterns. Unbound child nodes will only be used once 825 // to build this op. 826 llvm::DenseMap<unsigned, std::string> childNodeNames; 827 828 // First go through all the child nodes who are nested DAG constructs to 829 // create ops for them, so that we can use the results in the current node. 830 // This happens in a recursive manner. 831 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 832 if (auto child = tree.getArgAsNestedDag(i)) { 833 childNodeNames[i] = handleRewritePattern(child, i, depth + 1); 834 } 835 } 836 837 // Use the specified name for this op if available. Generate one otherwise. 838 std::string resultValue = tree.getSymbol(); 839 if (resultValue.empty()) 840 resultValue = getUniqueValueName(&resultOp); 841 // Strip the index to get the name for the value pack. This will be used to 842 // name the local variable for the op. 843 StringRef valuePackName = getValuePackName(resultValue); 844 845 // Then we build the new op corresponding to this DAG node. 846 847 // Right now we don't have general type inference in MLIR. Except a few 848 // special cases listed below, we need to supply types for all results 849 // when building an op. 850 bool isSameOperandsAndResultType = 851 resultOp.hasTrait("SameOperandsAndResultType"); 852 bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); 853 bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); 854 bool usePartialResults = valuePackName != resultValue; 855 856 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 857 usePartialResults || depth > 0 || resultIndex < 0) { 858 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 859 valuePackName, resultOp.getQualCppClassName()); 860 } else { 861 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 862 // generated from the source pattern root op. Then we can use the source 863 // pattern's value types to determine the value type of the generated op 864 // here. 865 866 // We need to specify the types for all results. 867 auto resultTypes = 868 formatValuePack("op->getResult({1})->getType()", valuePackName, 869 resultOp.getNumResults(), resultIndex); 870 871 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 872 valuePackName, resultOp.getQualCppClassName()) 873 << (resultTypes.empty() ? "" : ", ") << resultTypes; 874 } 875 876 // Create the builder call for the result. 877 // Add operands. 878 int i = 0; 879 for (int e = resultOp.getNumOperands(); i < e; ++i) { 880 const auto &operand = resultOp.getOperand(i); 881 882 // Start each operand on its own line. 883 (os << ",\n").indent(6); 884 885 if (!operand.name.empty()) 886 os << "/*" << operand.name << "=*/"; 887 888 if (tree.isNestedDagArg(i)) { 889 os << childNodeNames[i]; 890 } else { 891 DagLeaf leaf = tree.getArgAsLeaf(i); 892 auto symbol = resolveSymbol(tree.getArgName(i)); 893 if (leaf.isNativeCodeCall()) { 894 os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol)); 895 } else { 896 os << symbol; 897 } 898 } 899 // TODO(jpienaar): verify types 900 } 901 902 // Add attributes. 903 for (int e = tree.getNumArgs(); i != e; ++i) { 904 // Start each attribute on its own line. 905 (os << ",\n").indent(6); 906 // The argument in the op definition. 907 auto opArgName = resultOp.getArgName(i); 908 if (auto subTree = tree.getArgAsNestedDag(i)) { 909 if (!subTree.isNativeCodeCall()) 910 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 911 "for creating attribute"); 912 os << formatv("/*{0}=*/{1}", opArgName, 913 emitReplaceWithNativeCodeCall(subTree)); 914 } else { 915 auto leaf = tree.getArgAsLeaf(i); 916 // The argument in the result DAG pattern. 917 auto patArgName = tree.getArgName(i); 918 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 919 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 920 auto argument = resultOp.getArg(i); 921 if (!argument.is<NamedAttribute *>()) 922 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 923 if (!patArgName.empty()) 924 os << "/*" << patArgName << "=*/"; 925 } else { 926 os << "/*" << opArgName << "=*/"; 927 } 928 os << handleOpArgument(leaf, patArgName); 929 } 930 } 931 os << "\n );\n"; 932 933 return resultValue; 934 } 935 936 void PatternEmitter::emit(StringRef rewriteName, Record *p, 937 RecordOperatorMap *mapper, raw_ostream &os) { 938 PatternEmitter(p, mapper, os).emit(rewriteName); 939 } 940 941 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 942 emitSourceFileHeader("Rewriters", os); 943 944 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 945 auto numPatterns = patterns.size(); 946 947 // We put the map here because it can be shared among multiple patterns. 948 RecordOperatorMap recordOpMap; 949 950 std::vector<std::string> rewriterNames; 951 rewriterNames.reserve(numPatterns); 952 953 std::string baseRewriterName = "GeneratedConvert"; 954 int rewriterIndex = 0; 955 956 for (Record *p : patterns) { 957 std::string name; 958 if (p->isAnonymous()) { 959 // If no name is provided, ensure unique rewriter names simply by 960 // appending unique suffix. 961 name = baseRewriterName + llvm::utostr(rewriterIndex++); 962 } else { 963 name = p->getName(); 964 } 965 PatternEmitter::emit(name, p, &recordOpMap, os); 966 rewriterNames.push_back(std::move(name)); 967 } 968 969 // Emit function to add the generated matchers to the pattern list. 970 os << "void populateWithGenerated(MLIRContext *context, " 971 << "OwningRewritePatternList *patterns) {\n"; 972 for (const auto &name : rewriterNames) { 973 os << " patterns->push_back(llvm::make_unique<" << name 974 << ">(context));\n"; 975 } 976 os << "}\n"; 977 } 978 979 static mlir::GenRegistration 980 genRewriters("gen-rewriters", "Generate pattern rewriters", 981 [](const RecordKeeper &records, raw_ostream &os) { 982 emitRewriters(records, os); 983 return false; 984 }); 985