1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FormatAdapters.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/Signals.h" 36 #include "llvm/TableGen/Error.h" 37 #include "llvm/TableGen/Main.h" 38 #include "llvm/TableGen/Record.h" 39 #include "llvm/TableGen/TableGenBackend.h" 40 41 using namespace llvm; 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 namespace llvm { 46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 47 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 48 raw_ostream &os, StringRef style) { 49 os << v.first << ":" << v.second; 50 } 51 }; 52 } // end namespace llvm 53 54 // Returns the bound symbol for the given op argument or op named `symbol`. 55 // 56 // Arguments and ops bound in the source pattern are grouped into a 57 // transient `PatternState` struct. This struct can be accessed in both 58 // `match()` and `rewrite()` via the local variable named as `s`. 59 static Twine getBoundSymbol(const StringRef &symbol) { 60 return Twine("s.") + symbol; 61 } 62 63 // Gets the dynamic value pack's name by removing the index suffix from 64 // `symbol`. Returns `symbol` itself if it does not contain an index. 65 // 66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic 67 // value pack bound to `name`. `name` is typically the results of an 68 // multi-result op. 69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) { 70 StringRef name, indexStr; 71 unsigned idx = 0; 72 std::tie(name, indexStr) = symbol.rsplit("__"); 73 if (indexStr.consumeInteger(10, idx)) { 74 // The second part is not an index. 75 return symbol; 76 } 77 if (index) 78 *index = idx; 79 return name; 80 } 81 82 // Formats all values from a dynamic value pack `symbol` according to the given 83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol` 84 // and `{1}` as a placeholder for the value index, which will be offsetted by 85 // `offset`. The `symbol` value pack has a total of `count` values. 86 // 87 // This extracts one value from the pack if `symbol` contains an index, 88 // otherwise it extracts all values sequentially and returns them as a 89 // comma-separated list. 90 static std::string formatValuePack(const char *fmt, StringRef symbol, 91 unsigned count, unsigned offset) { 92 auto getNthValue = [fmt, offset](StringRef results, 93 unsigned index) -> std::string { 94 return formatv(fmt, results, index + offset); 95 }; 96 97 unsigned index = 0; 98 StringRef name = getValuePackName(symbol, &index); 99 if (name != symbol) { 100 // The symbol contains an index. 101 return getNthValue(name, index); 102 } 103 104 // The symbol does not contain an index. Treat the symbol as a whole. 105 SmallVector<std::string, 4> values; 106 values.reserve(count); 107 for (unsigned i = 0; i < count; ++i) 108 values.emplace_back(getNthValue(symbol, i)); 109 return llvm::join(values, ", "); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // PatternSymbolResolver 114 //===----------------------------------------------------------------------===// 115 116 namespace { 117 // A class for resolving symbols bound in patterns. 118 // 119 // Symbols can be bound to op arguments and ops in the source pattern and ops 120 // in result patterns. For example, in 121 // 122 // ``` 123 // def : Pattern<(SrcOp:$op1 $arg0, %arg1), 124 // [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; 125 // ``` 126 // 127 // `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. 128 // `$op2` is bound to `ResOp1`. 129 // 130 // If a symbol binds to a multi-result op and it does not have the `__N` 131 // suffix, the symbol is expanded to the whole value pack generated by the 132 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 133 // only the N-th result. 134 // 135 // This class keeps track of such symbols and translates them into their bound 136 // values. 137 // 138 // Note that we also generate local variables for unnamed DAG nodes, like 139 // `(ResOp3)` in the above. Since we don't bind a symbol to the op, the 140 // generated local variable will be implicitly named. Those implicit names are 141 // not tracked in this class. 142 class PatternSymbolResolver { 143 public: 144 PatternSymbolResolver(const StringMap<Argument> &srcArgs, 145 const StringMap<const Operator *> &srcOperations); 146 147 // Marks the given `symbol` as bound to a value pack with `numValues` and 148 // returns true on success. Returns false if the `symbol` is already bound. 149 bool add(StringRef symbol, int numValues); 150 151 // Queries the substitution for the given `symbol`. Returns empty string if 152 // symbol not found. If the symbol represents a value pack, returns all the 153 // values separated via comma. 154 std::string query(StringRef symbol) const; 155 156 private: 157 // Symbols bound to arguments in source pattern. 158 const StringMap<Argument> &sourceArguments; 159 // Symbols bound to ops (for their results) in source pattern. 160 const StringMap<const Operator *> &sourceOps; 161 // Symbols bound to ops (for their results) in result patterns. 162 // Key: symbol; value: number of values inside the pack 163 StringMap<int> resultOps; 164 }; 165 } // end anonymous namespace 166 167 PatternSymbolResolver::PatternSymbolResolver( 168 const StringMap<Argument> &srcArgs, 169 const StringMap<const Operator *> &srcOperations) 170 : sourceArguments(srcArgs), sourceOps(srcOperations) {} 171 172 bool PatternSymbolResolver::add(StringRef symbol, int numValues) { 173 StringRef name = getValuePackName(symbol); 174 return resultOps.try_emplace(name, numValues).second; 175 } 176 177 std::string PatternSymbolResolver::query(StringRef symbol) const { 178 { 179 StringRef name = getValuePackName(symbol); 180 auto it = resultOps.find(name); 181 if (it != resultOps.end()) 182 return formatValuePack("{0}.getOperation()->getResult({1})", symbol, 183 it->second, /*offset=*/0); 184 } 185 { 186 auto it = sourceArguments.find(symbol); 187 if (it != sourceArguments.end()) 188 return getBoundSymbol(symbol).str(); 189 } 190 { 191 StringRef name = getValuePackName(symbol); 192 auto it = sourceOps.find(name); 193 if (it != sourceOps.end()) 194 return formatValuePack("{0}->getResult({1})", 195 getBoundSymbol(symbol).str(), 196 it->second->getNumResults(), /*offset=*/0); 197 } 198 return {}; 199 } 200 201 //===----------------------------------------------------------------------===// 202 // PatternEmitter 203 //===----------------------------------------------------------------------===// 204 205 namespace { 206 class PatternEmitter { 207 public: 208 static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, 209 raw_ostream &os); 210 211 private: 212 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 213 214 // Emits the mlir::RewritePattern struct named `rewriteName`. 215 void emit(StringRef rewriteName); 216 217 // Emits the match() method. 218 void emitMatchMethod(DagNode tree); 219 220 // Collects all of the operations within the given dag tree. 221 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 222 223 // Emits the rewrite() method. 224 void emitRewriteMethod(); 225 226 // Emits C++ statements for matching the op constrained by the given DAG 227 // `tree`. 228 void emitOpMatch(DagNode tree, int depth); 229 230 // Emits C++ statements for matching the `index`-th argument of the given DAG 231 // `tree` as an operand. 232 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 233 234 // Emits C++ statements for matching the `index`-th argument of the given DAG 235 // `tree` as an attribute. 236 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 237 238 // Returns a unique name for an value of the given `op`. 239 std::string getUniqueValueName(const Operator *op); 240 241 // Entry point for handling a rewrite pattern rooted at `resultTree` and 242 // dispatches to concrete handlers. The given tree is the `resultIndex`-th 243 // argument of the enclosing DAG. 244 std::string handleRewritePattern(DagNode resultTree, int resultIndex, 245 int depth); 246 247 // Emits the C++ statement to replace the matched DAG with a value built via 248 // calling native C++ code. 249 std::string emitReplaceWithNativeCodeCall(DagNode resultTree); 250 251 // Returns the C++ expression referencing the old value serving as the 252 // replacement. 253 std::string handleReplaceWithValue(DagNode tree); 254 255 // Handles the `verifyUnusedValue` directive: emitting C++ statements to check 256 // the `index`-th result of the source op is not used. 257 void handleVerifyUnusedValue(DagNode tree, int index); 258 259 // Emits the C++ statement to build a new op out of the given DAG `tree` and 260 // returns the variable name that this op is assigned to. If the root op in 261 // DAG `tree` has a specified name, the created op will be assigned to a 262 // variable of the given name. Otherwise, a unique name will be used as the 263 // result value name. 264 std::string emitOpCreate(DagNode tree, int resultIndex, int depth); 265 266 // Returns the C++ expression to construct a constant attribute of the given 267 // `value` for the given attribute kind `attr`. 268 std::string handleConstantAttr(Attribute attr, StringRef value); 269 270 // Returns the C++ expression to build an argument from the given DAG `leaf`. 271 // `patArgName` is used to bound the argument to the source pattern. 272 std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName); 273 274 // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol 275 // is already bound. 276 void addSymbol(DagNode node); 277 278 // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. 279 std::string resolveSymbol(StringRef symbol); 280 281 private: 282 // Pattern instantiation location followed by the location of multiclass 283 // prototypes used. This is intended to be used as a whole to 284 // PrintFatalError() on errors. 285 ArrayRef<llvm::SMLoc> loc; 286 // Op's TableGen Record to wrapper object 287 RecordOperatorMap *opMap; 288 // Handy wrapper for pattern being emitted 289 Pattern pattern; 290 PatternSymbolResolver symbolResolver; 291 // The next unused ID for newly created values 292 unsigned nextValueId; 293 raw_ostream &os; 294 295 // Format contexts containing placeholder substitutations for match(). 296 FmtContext matchCtx; 297 // Format contexts containing placeholder substitutations for rewrite(). 298 FmtContext rewriteCtx; 299 300 // Number of op processed. 301 int opCounter = 0; 302 }; 303 } // end anonymous namespace 304 305 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 306 raw_ostream &os) 307 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 308 symbolResolver(pattern.getSourcePatternBoundArgs(), 309 pattern.getSourcePatternBoundOps()), 310 nextValueId(0), os(os) { 311 matchCtx.withBuilder("mlir::Builder(ctx)"); 312 rewriteCtx.withBuilder("rewriter"); 313 } 314 315 std::string PatternEmitter::handleConstantAttr(Attribute attr, 316 StringRef value) { 317 if (!attr.isConstBuildable()) 318 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 319 " does not have the 'constBuilderCall' field"); 320 321 // TODO(jpienaar): Verify the constants here 322 return tgfmt(attr.getConstBuilderTemplate(), 323 &rewriteCtx.withBuilder("rewriter"), value); 324 } 325 326 // Helper function to match patterns. 327 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 328 Operator &op = tree.getDialectOp(opMap); 329 if (op.isVariadic()) { 330 PrintFatalError(loc, formatv("matching op '{0}' with variadic " 331 "operands/results is unsupported right now", 332 op.getOperationName())); 333 } 334 335 int indent = 4 + 2 * depth; 336 // Skip the operand matching at depth 0 as the pattern rewriter already does. 337 if (depth != 0) { 338 // Skip if there is no defining operation (e.g., arguments to function). 339 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 340 os.indent(indent) << formatv( 341 "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, 342 op.getQualCppClassName()); 343 } 344 if (tree.getNumArgs() != op.getNumArgs()) { 345 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 346 "pattern vs. {2} in definition", 347 op.getOperationName(), tree.getNumArgs(), 348 op.getNumArgs())); 349 } 350 351 // If the operand's name is set, set to that variable. 352 auto name = tree.getOpName(); 353 if (!name.empty()) 354 os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth); 355 356 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 357 auto opArg = op.getArg(i); 358 359 // Handle nested DAG construct first 360 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 361 os.indent(indent) << "{\n"; 362 os.indent(indent + 2) 363 << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", 364 depth + 1, depth, i); 365 emitOpMatch(argTree, depth + 1); 366 os.indent(indent + 2) 367 << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n", 368 ++opCounter, depth + 1); 369 os.indent(indent) << "}\n"; 370 continue; 371 } 372 373 // Next handle DAG leaf: operand or attribute 374 if (opArg.is<NamedTypeConstraint *>()) { 375 emitOperandMatch(tree, i, depth, indent); 376 } else if (opArg.is<NamedAttribute *>()) { 377 emitAttributeMatch(tree, i, depth, indent); 378 } else { 379 PrintFatalError(loc, "unhandled case when matching op"); 380 } 381 } 382 } 383 384 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 385 int indent) { 386 Operator &op = tree.getDialectOp(opMap); 387 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 388 auto matcher = tree.getArgAsLeaf(index); 389 390 // If a constraint is specified, we need to generate C++ statements to 391 // check the constraint. 392 if (!matcher.isUnspecified()) { 393 if (!matcher.isOperandMatcher()) { 394 PrintFatalError( 395 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 396 op.getOperationName(), index + 1)); 397 } 398 399 // Only need to verify if the matcher's type is different from the one 400 // of op definition. 401 if (operand->constraint != matcher.getAsConstraint()) { 402 auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); 403 os.indent(indent) << "if (!(" 404 << tgfmt(matcher.getConditionTemplate(), 405 &matchCtx.withSelf(self)) 406 << ")) return matchFailure();\n"; 407 } 408 } 409 410 // Capture the value 411 auto name = tree.getArgName(index); 412 if (!name.empty()) { 413 os.indent(indent) << getBoundSymbol(name) << " = op" << depth 414 << "->getOperand(" << index << ");\n"; 415 } 416 } 417 418 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 419 int indent) { 420 Operator &op = tree.getDialectOp(opMap); 421 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 422 const auto &attr = namedAttr->attr; 423 424 os.indent(indent) << "{\n"; 425 indent += 2; 426 os.indent(indent) << formatv( 427 "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 428 attr.getStorageType(), namedAttr->name); 429 430 // TODO(antiagainst): This should use getter method to avoid duplication. 431 if (attr.hasDefaultValueInitializer()) { 432 os.indent(indent) << "if (!attr) attr = " 433 << tgfmt(attr.getConstBuilderTemplate(), &matchCtx, 434 attr.getDefaultValueInitializer()) 435 << ";\n"; 436 } else if (attr.isOptional()) { 437 // For a missing attribut that is optional according to definition, we 438 // should just capature a mlir::Attribute() to signal the missing state. 439 // That is precisely what getAttr() returns on missing attributes. 440 } else { 441 os.indent(indent) << "if (!attr) return matchFailure();\n"; 442 } 443 444 auto matcher = tree.getArgAsLeaf(index); 445 if (!matcher.isUnspecified()) { 446 if (!matcher.isAttrMatcher()) { 447 PrintFatalError( 448 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 449 op.getOperationName(), index + 1)); 450 } 451 452 // If a constraint is specified, we need to generate C++ statements to 453 // check the constraint. 454 os.indent(indent) << "if (!(" 455 << tgfmt(matcher.getConditionTemplate(), 456 &matchCtx.withSelf("attr")) 457 << ")) return matchFailure();\n"; 458 } 459 460 // Capture the value 461 auto name = tree.getArgName(index); 462 if (!name.empty()) { 463 os.indent(indent) << getBoundSymbol(name) << " = attr;\n"; 464 } 465 466 indent -= 2; 467 os.indent(indent) << "}\n"; 468 } 469 470 void PatternEmitter::emitMatchMethod(DagNode tree) { 471 // Emit the heading. 472 os << R"( 473 PatternMatchResult match(Operation *op0) const override { 474 auto ctx = op0->getContext(); (void)ctx; 475 auto state = llvm::make_unique<MatchedState>(); 476 auto &s = *state; 477 s.autogeneratedRewritePatternOps[0] = op0; 478 )"; 479 480 // The rewrite pattern may specify that certain outputs should be unused in 481 // the source IR. Check it here. 482 for (int i = 0, e = pattern.getNumResults(); i < e; ++i) { 483 DagNode resultTree = pattern.getResultPattern(i); 484 if (resultTree.isVerifyUnusedValue()) { 485 handleVerifyUnusedValue(resultTree, i); 486 } 487 } 488 489 emitOpMatch(tree, 0); 490 491 for (auto &appliedConstraint : pattern.getConstraints()) { 492 auto &constraint = appliedConstraint.constraint; 493 auto &entities = appliedConstraint.entities; 494 495 auto condition = constraint.getConditionTemplate(); 496 auto cmd = "if (!({0})) return matchFailure();\n"; 497 498 if (isa<TypeConstraint>(constraint)) { 499 auto self = formatv("({0}->getType())", resolveSymbol(entities.front())); 500 os.indent(4) << formatv(cmd, 501 tgfmt(condition, &matchCtx.withSelf(self.str()))); 502 } else if (isa<AttrConstraint>(constraint)) { 503 PrintFatalError( 504 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 505 } else { 506 // TODO(fengliuai): replace formatv arguments with the exact specified 507 // args. 508 if (entities.size() > 4) { 509 PrintFatalError(loc, "only support up to 4-entity constraints now"); 510 } 511 SmallVector<std::string, 4> names; 512 int i = 0; 513 for (int e = entities.size(); i < e; ++i) 514 names.push_back(resolveSymbol(entities[i])); 515 for (; i < 4; ++i) 516 names.push_back("<unused>"); 517 os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0], 518 names[1], names[2], names[3])); 519 } 520 } 521 522 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 523 } 524 525 void PatternEmitter::collectOps(DagNode tree, 526 llvm::SmallPtrSetImpl<const Operator *> &ops) { 527 // Check if this tree is an operation. 528 if (tree.isOperation()) 529 ops.insert(&tree.getDialectOp(opMap)); 530 531 // Recurse the arguments of the tree. 532 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 533 if (auto child = tree.getArgAsNestedDag(i)) 534 collectOps(child, ops); 535 } 536 537 void PatternEmitter::emit(StringRef rewriteName) { 538 // Get the DAG tree for the source pattern 539 DagNode tree = pattern.getSourcePattern(); 540 541 const Operator &rootOp = pattern.getSourceRootOp(); 542 auto rootName = rootOp.getOperationName(); 543 544 // Collect the set of result operations. 545 llvm::SmallPtrSet<const Operator *, 4> results; 546 for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i) 547 collectOps(pattern.getResultPattern(i), results); 548 549 // Emit RewritePattern for Pattern. 550 auto locs = pattern.getLocation(); 551 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 552 make_range(locs.rbegin(), locs.rend())); 553 os << formatv(R"(struct {0} : public RewritePattern { 554 {0}(MLIRContext *context) : RewritePattern("{1}", {{)", 555 rewriteName, rootName); 556 interleaveComma(results, os, [&](const Operator *op) { 557 os << '"' << op->getOperationName() << '"'; 558 }); 559 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 560 561 // Emit matched state. 562 os << " struct MatchedState : public PatternState {\n"; 563 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 564 auto fieldName = arg.first(); 565 if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { 566 os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName 567 << ";\n"; 568 } else { 569 os.indent(4) << "Value *" << fieldName << ";\n"; 570 } 571 } 572 for (const auto &result : pattern.getSourcePatternBoundOps()) { 573 os.indent(4) << "Operation *" << result.getKey() << ";\n"; 574 } 575 // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent 576 // numbering so that it can be reused for fused loc. 577 os.indent(4) << "Operation* autogeneratedRewritePatternOps[" 578 << pattern.getSourcePattern().getNumOps() << "];\n"; 579 os << " };\n"; 580 581 emitMatchMethod(tree); 582 emitRewriteMethod(); 583 584 os << "};\n"; 585 } 586 587 void PatternEmitter::emitRewriteMethod() { 588 const Operator &rootOp = pattern.getSourceRootOp(); 589 int numExpectedResults = rootOp.getNumResults(); 590 int numProvidedResults = pattern.getNumResults(); 591 592 os << R"( 593 void rewrite(Operation *op, std::unique_ptr<PatternState> state, 594 PatternRewriter &rewriter) const override { 595 auto& s = *static_cast<MatchedState *>(state.get()); 596 auto loc = rewriter.getFusedLoc({)"; 597 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 598 os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i 599 << "]->getLoc()"; 600 } 601 os << "}); (void)loc;\n"; 602 603 // Collect the replacement value for each result 604 llvm::SmallVector<std::string, 2> resultValues; 605 for (int i = 0; i < numProvidedResults; ++i) { 606 DagNode resultTree = pattern.getResultPattern(i); 607 resultValues.push_back(handleRewritePattern(resultTree, i, 0)); 608 // Keep track of bound symbols at the top-level DAG nodes 609 addSymbol(resultTree); 610 } 611 612 // Emit the final replaceOp() statement 613 os.indent(4) << "rewriter.replaceOp(op, {"; 614 interleave( 615 // We only use the last numExpectedResults ones to replace the root op. 616 ArrayRef<std::string>(resultValues).take_back(numExpectedResults), 617 [&](const std::string &name) { os << name; }, [&]() { os << ", "; }); 618 os << "});\n }\n"; 619 } 620 621 std::string PatternEmitter::getUniqueValueName(const Operator *op) { 622 return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); 623 } 624 625 std::string PatternEmitter::handleRewritePattern(DagNode resultTree, 626 int resultIndex, int depth) { 627 if (resultTree.isNativeCodeCall()) 628 return emitReplaceWithNativeCodeCall(resultTree); 629 630 if (resultTree.isVerifyUnusedValue()) { 631 if (depth > 0) { 632 // TODO: Revisit this when we have use cases of matching an intermediate 633 // multi-result op with no uses of its certain results. 634 PrintFatalError(loc, "verifyUnusedValue directive can only be used to " 635 "verify top-level result"); 636 } 637 638 if (!resultTree.getOpName().empty()) { 639 PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); 640 } 641 642 // The C++ statements to check that this result value is unused are already 643 // emitted in the match() method. So returning a nullptr here directly 644 // should be safe because the C++ RewritePattern harness will use it to 645 // replace nothing. 646 return "nullptr"; 647 } 648 649 if (resultTree.isReplaceWithValue()) 650 return handleReplaceWithValue(resultTree); 651 652 // Create the op and get the local variable for it. 653 auto results = emitOpCreate(resultTree, resultIndex, depth); 654 // We need to get all the values out of this local variable if we've created a 655 // multi-result op. 656 const auto &numResults = pattern.getDialectOp(resultTree).getNumResults(); 657 return formatValuePack("{0}.getOperation()->getResult({1})", results, 658 numResults, /*offset=*/0); 659 } 660 661 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 662 assert(tree.isReplaceWithValue()); 663 664 if (tree.getNumArgs() != 1) { 665 PrintFatalError( 666 loc, "replaceWithValue directive must take exactly one argument"); 667 } 668 669 if (!tree.getOpName().empty()) { 670 PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); 671 } 672 673 return resolveSymbol(tree.getArgName(0)); 674 } 675 676 void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) { 677 assert(tree.isVerifyUnusedValue()); 678 679 os.indent(4) << "if (!op0->getResult(" << index 680 << ")->use_empty()) return matchFailure();\n"; 681 } 682 683 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 684 llvm::StringRef argName) { 685 if (leaf.isConstantAttr()) { 686 auto constAttr = leaf.getAsConstantAttr(); 687 return handleConstantAttr(constAttr.getAttribute(), 688 constAttr.getConstantValue()); 689 } 690 if (leaf.isEnumAttrCase()) { 691 auto enumCase = leaf.getAsEnumAttrCase(); 692 if (enumCase.isStrCase()) 693 return handleConstantAttr(enumCase, enumCase.getSymbol()); 694 // This is an enum case backed by an IntegerAttr. We need to get its value 695 // to build the constant. 696 std::string val = std::to_string(enumCase.getValue()); 697 return handleConstantAttr(enumCase, val); 698 } 699 pattern.ensureBoundInSourcePattern(argName); 700 std::string result = getBoundSymbol(argName).str(); 701 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 702 return result; 703 } 704 if (leaf.isNativeCodeCall()) { 705 return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result)); 706 } 707 PrintFatalError(loc, "unhandled case when rewriting op"); 708 } 709 710 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { 711 auto fmt = tree.getNativeCodeTemplate(); 712 // TODO(fengliuai): replace formatv arguments with the exact specified args. 713 SmallVector<std::string, 8> attrs(8); 714 if (tree.getNumArgs() > 8) { 715 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 716 Twine(tree.getNumArgs())); 717 } 718 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 719 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 720 } 721 return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3], 722 attrs[4], attrs[5], attrs[6], attrs[7]); 723 } 724 725 void PatternEmitter::addSymbol(DagNode node) { 726 StringRef symbol = node.getOpName(); 727 // Skip empty-named symbols, which happen for unbound ops in result patterns. 728 if (symbol.empty()) 729 return; 730 if (!symbolResolver.add(symbol, pattern.getDialectOp(node).getNumResults())) 731 PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); 732 } 733 734 std::string PatternEmitter::resolveSymbol(StringRef symbol) { 735 auto subst = symbolResolver.query(symbol); 736 if (subst.empty()) 737 PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); 738 return subst; 739 } 740 741 std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, 742 int depth) { 743 Operator &resultOp = tree.getDialectOp(opMap); 744 auto numOpArgs = resultOp.getNumArgs(); 745 746 if (resultOp.isVariadic()) { 747 PrintFatalError(loc, formatv("generating op '{0}' with variadic " 748 "operands/results is unsupported now", 749 resultOp.getOperationName())); 750 } 751 752 if (numOpArgs != tree.getNumArgs()) { 753 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 754 "{1} in pattern vs. {2} in definition", 755 resultOp.getOperationName(), tree.getNumArgs(), 756 numOpArgs)); 757 } 758 759 // A map to collect all nested DAG child nodes' names, with operand index as 760 // the key. This includes both bound and unbound child nodes. Bound child 761 // nodes will additionally be tracked in `symbolResolver` so they can be 762 // referenced by other patterns. Unbound child nodes will only be used once 763 // to build this op. 764 llvm::DenseMap<unsigned, std::string> childNodeNames; 765 766 // First go through all the child nodes who are nested DAG constructs to 767 // create ops for them, so that we can use the results in the current node. 768 // This happens in a recursive manner. 769 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 770 if (auto child = tree.getArgAsNestedDag(i)) { 771 childNodeNames[i] = handleRewritePattern(child, i, depth + 1); 772 // Keep track of bound symbols at the middle-level DAG nodes 773 addSymbol(child); 774 } 775 } 776 777 // Use the specified name for this op if available. Generate one otherwise. 778 std::string resultValue = tree.getOpName(); 779 if (resultValue.empty()) 780 resultValue = getUniqueValueName(&resultOp); 781 // Strip the index to get the name for the value pack. This will be used to 782 // name the local variable for the op. 783 StringRef valuePackName = getValuePackName(resultValue); 784 785 // Then we build the new op corresponding to this DAG node. 786 787 // Right now we don't have general type inference in MLIR. Except a few 788 // special cases listed below, we need to supply types for all results 789 // when building an op. 790 bool isSameOperandsAndResultType = 791 resultOp.hasTrait("SameOperandsAndResultType"); 792 bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); 793 bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); 794 bool usePartialResults = valuePackName != resultValue; 795 796 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 797 usePartialResults || depth > 0) { 798 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 799 valuePackName, resultOp.getQualCppClassName()); 800 } else { 801 // If depth == 0 we can use the equivalence of the source and target root 802 // ops in the pattern to determine the return type. 803 804 // We need to specify the types for all results. 805 auto resultTypes = 806 formatValuePack("op->getResult({1})->getType()", valuePackName, 807 resultOp.getNumResults(), resultIndex); 808 809 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 810 valuePackName, resultOp.getQualCppClassName()) 811 << (resultTypes.empty() ? "" : ", ") << resultTypes; 812 } 813 814 // Create the builder call for the result. 815 // Add operands. 816 int i = 0; 817 for (int e = resultOp.getNumOperands(); i < e; ++i) { 818 const auto &operand = resultOp.getOperand(i); 819 820 // Start each operand on its own line. 821 (os << ",\n").indent(6); 822 823 if (!operand.name.empty()) 824 os << "/*" << operand.name << "=*/"; 825 826 if (tree.isNestedDagArg(i)) { 827 os << childNodeNames[i]; 828 } else { 829 DagLeaf leaf = tree.getArgAsLeaf(i); 830 auto symbol = resolveSymbol(tree.getArgName(i)); 831 if (leaf.isNativeCodeCall()) { 832 os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol)); 833 } else { 834 os << symbol; 835 } 836 } 837 // TODO(jpienaar): verify types 838 } 839 840 // Add attributes. 841 for (int e = tree.getNumArgs(); i != e; ++i) { 842 // Start each attribute on its own line. 843 (os << ",\n").indent(6); 844 // The argument in the op definition. 845 auto opArgName = resultOp.getArgName(i); 846 if (auto subTree = tree.getArgAsNestedDag(i)) { 847 if (!subTree.isNativeCodeCall()) 848 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 849 "for creating attribute"); 850 os << formatv("/*{0}=*/{1}", opArgName, 851 emitReplaceWithNativeCodeCall(subTree)); 852 } else { 853 auto leaf = tree.getArgAsLeaf(i); 854 // The argument in the result DAG pattern. 855 auto patArgName = tree.getArgName(i); 856 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 857 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 858 auto argument = resultOp.getArg(i); 859 if (!argument.is<NamedAttribute *>()) 860 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 861 if (!patArgName.empty()) 862 os << "/*" << patArgName << "=*/"; 863 } else { 864 os << "/*" << opArgName << "=*/"; 865 } 866 os << handleOpArgument(leaf, patArgName); 867 } 868 } 869 os << "\n );\n"; 870 871 return resultValue; 872 } 873 874 void PatternEmitter::emit(StringRef rewriteName, Record *p, 875 RecordOperatorMap *mapper, raw_ostream &os) { 876 PatternEmitter(p, mapper, os).emit(rewriteName); 877 } 878 879 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 880 emitSourceFileHeader("Rewriters", os); 881 882 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 883 auto numPatterns = patterns.size(); 884 885 // We put the map here because it can be shared among multiple patterns. 886 RecordOperatorMap recordOpMap; 887 888 std::vector<std::string> rewriterNames; 889 rewriterNames.reserve(numPatterns); 890 891 std::string baseRewriterName = "GeneratedConvert"; 892 int rewriterIndex = 0; 893 894 for (Record *p : patterns) { 895 std::string name; 896 if (p->isAnonymous()) { 897 // If no name is provided, ensure unique rewriter names simply by 898 // appending unique suffix. 899 name = baseRewriterName + llvm::utostr(rewriterIndex++); 900 } else { 901 name = p->getName(); 902 } 903 PatternEmitter::emit(name, p, &recordOpMap, os); 904 rewriterNames.push_back(std::move(name)); 905 } 906 907 // Emit function to add the generated matchers to the pattern list. 908 os << "void populateWithGenerated(MLIRContext *context, " 909 << "OwningRewritePatternList *patterns) {\n"; 910 for (const auto &name : rewriterNames) { 911 os << " patterns->push_back(llvm::make_unique<" << name 912 << ">(context));\n"; 913 } 914 os << "}\n"; 915 } 916 917 static mlir::GenRegistration 918 genRewriters("gen-rewriters", "Generate pattern rewriters", 919 [](const RecordKeeper &records, raw_ostream &os) { 920 emitRewriters(records, os); 921 return false; 922 }); 923