1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FormatAdapters.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/Signals.h" 36 #include "llvm/TableGen/Error.h" 37 #include "llvm/TableGen/Main.h" 38 #include "llvm/TableGen/Record.h" 39 #include "llvm/TableGen/TableGenBackend.h" 40 41 using namespace llvm; 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 namespace llvm { 46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 47 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 48 raw_ostream &os, StringRef style) { 49 os << v.first << ":" << v.second; 50 } 51 }; 52 } // end namespace llvm 53 54 // Returns the bound symbol for the given op argument or op named `symbol`. 55 // 56 // Arguments and ops bound in the source pattern are grouped into a 57 // transient `PatternState` struct. This struct can be accessed in both 58 // `match()` and `rewrite()` via the local variable named as `s`. 59 static Twine getBoundSymbol(const StringRef &symbol) { 60 return Twine("s.") + symbol; 61 } 62 63 // Gets the dynamic value pack's name by removing the index suffix from 64 // `symbol`. Returns `symbol` itself if it does not contain an index. 65 // 66 // We can use `name__<index>` to access the `<index>`-th value in the dynamic 67 // value pack bound to `name`. `name` is typically the results of an 68 // multi-result op. 69 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) { 70 StringRef name, indexStr; 71 unsigned idx = 0; 72 std::tie(name, indexStr) = symbol.rsplit("__"); 73 if (indexStr.consumeInteger(10, idx)) { 74 // The second part is not an index. 75 return symbol; 76 } 77 if (index) 78 *index = idx; 79 return name; 80 } 81 82 // Formats all values from a dynamic value pack `symbol` according to the given 83 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol` 84 // and `{1}` as a placeholder for the value index, which will be offsetted by 85 // `offset`. The `symbol` value pack has a total of `count` values. 86 // 87 // This extracts one value from the pack if `symbol` contains an index, 88 // otherwise it extracts all values sequentially and returns them as a 89 // comma-separated list. 90 static std::string formatValuePack(const char *fmt, StringRef symbol, 91 unsigned count, unsigned offset) { 92 auto getNthValue = [fmt, offset](StringRef results, 93 unsigned index) -> std::string { 94 return formatv(fmt, results, index + offset); 95 }; 96 97 unsigned index = 0; 98 StringRef name = getValuePackName(symbol, &index); 99 if (name != symbol) { 100 // The symbol contains an index. 101 return getNthValue(name, index); 102 } 103 104 // The symbol does not contain an index. Treat the symbol as a whole. 105 SmallVector<std::string, 4> values; 106 values.reserve(count); 107 for (unsigned i = 0; i < count; ++i) 108 values.emplace_back(getNthValue(symbol, i)); 109 return llvm::join(values, ", "); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // PatternSymbolResolver 114 //===----------------------------------------------------------------------===// 115 116 namespace { 117 // A class for resolving symbols bound in patterns. 118 // 119 // Symbols can be bound to op arguments and ops in the source pattern and ops 120 // in result patterns. For example, in 121 // 122 // ``` 123 // def : Pattern<(SrcOp:$op1 $arg0, %arg1), 124 // [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; 125 // ``` 126 // 127 // `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. 128 // `$op2` is bound to `ResOp1`. 129 // 130 // If a symbol binds to a multi-result op and it does not have the `__N` 131 // suffix, the symbol is expanded to the whole value pack generated by the 132 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 133 // only the N-th result. 134 // 135 // This class keeps track of such symbols and translates them into their bound 136 // values. 137 // 138 // Note that we also generate local variables for unnamed DAG nodes, like 139 // `(ResOp3)` in the above. Since we don't bind a symbol to the op, the 140 // generated local variable will be implicitly named. Those implicit names are 141 // not tracked in this class. 142 class PatternSymbolResolver { 143 public: 144 PatternSymbolResolver(const StringMap<Argument> &srcArgs, 145 const StringMap<const Operator *> &srcOperations); 146 147 // Marks the given `symbol` as bound to a value pack with `numValues` and 148 // returns true on success. Returns false if the `symbol` is already bound. 149 bool add(StringRef symbol, int numValues); 150 151 // Queries the substitution for the given `symbol`. Returns empty string if 152 // symbol not found. If the symbol represents a value pack, returns all the 153 // values separated via comma. 154 std::string query(StringRef symbol) const; 155 156 // Returns how many static values the given `symbol` correspond to. Returns a 157 // negative value if the given symbol is not bound. 158 // 159 // Normally a symbol would correspond to just one value; for symbols bound to 160 // multi-result ops, it can be more than one. 161 int getValueCount(StringRef symbol) const; 162 163 private: 164 // Symbols bound to arguments in source pattern. 165 const StringMap<Argument> &sourceArguments; 166 // Symbols bound to ops (for their results) in source pattern. 167 const StringMap<const Operator *> &sourceOps; 168 // Symbols bound to ops (for their results) in result patterns. 169 // Key: symbol; value: number of values inside the pack 170 StringMap<int> resultOps; 171 }; 172 } // end anonymous namespace 173 174 PatternSymbolResolver::PatternSymbolResolver( 175 const StringMap<Argument> &srcArgs, 176 const StringMap<const Operator *> &srcOperations) 177 : sourceArguments(srcArgs), sourceOps(srcOperations) {} 178 179 bool PatternSymbolResolver::add(StringRef symbol, int numValues) { 180 StringRef name = getValuePackName(symbol); 181 return resultOps.try_emplace(name, numValues).second; 182 } 183 184 std::string PatternSymbolResolver::query(StringRef symbol) const { 185 StringRef name = getValuePackName(symbol); 186 // Handle symbols bound to generated ops 187 auto resOpIt = resultOps.find(name); 188 if (resOpIt != resultOps.end()) 189 return formatValuePack("{0}.getOperation()->getResult({1})", symbol, 190 resOpIt->second, /*offset=*/0); 191 192 // Handle symbols bound to matched op arguments 193 auto srcArgIt = sourceArguments.find(symbol); 194 if (srcArgIt != sourceArguments.end()) 195 return getBoundSymbol(symbol).str(); 196 197 // Handle symbols bound to matched op results 198 auto srcOpIt = sourceOps.find(name); 199 if (srcOpIt != sourceOps.end()) 200 return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(), 201 srcOpIt->second->getNumResults(), /*offset=*/0); 202 return {}; 203 } 204 205 int PatternSymbolResolver::getValueCount(StringRef symbol) const { 206 StringRef name = getValuePackName(symbol); 207 // Handle symbols bound to generated ops 208 auto resOpIt = resultOps.find(name); 209 if (resOpIt != resultOps.end()) 210 return name == symbol ? resOpIt->second : 1; 211 212 // Handle symbols bound to matched op arguments 213 if (sourceArguments.count(symbol)) 214 return 1; 215 216 // Handle symbols bound to matched op results 217 auto srcOpIt = sourceOps.find(name); 218 if (srcOpIt != sourceOps.end()) 219 return name == symbol ? srcOpIt->second->getNumResults() : 1; 220 return -1; 221 } 222 223 //===----------------------------------------------------------------------===// 224 // PatternEmitter 225 //===----------------------------------------------------------------------===// 226 227 namespace { 228 class PatternEmitter { 229 public: 230 static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, 231 raw_ostream &os); 232 233 private: 234 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 235 236 // Emits the mlir::RewritePattern struct named `rewriteName`. 237 void emit(StringRef rewriteName); 238 239 // Emits the match() method. 240 void emitMatchMethod(DagNode tree); 241 242 // Collects all of the operations within the given dag tree. 243 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 244 245 // Emits the rewrite() method. 246 void emitRewriteMethod(); 247 248 // Emits C++ statements for matching the op constrained by the given DAG 249 // `tree`. 250 void emitOpMatch(DagNode tree, int depth); 251 252 // Emits C++ statements for matching the `index`-th argument of the given DAG 253 // `tree` as an operand. 254 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 255 256 // Emits C++ statements for matching the `index`-th argument of the given DAG 257 // `tree` as an attribute. 258 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 259 260 // Returns a unique name for an value of the given `op`. 261 std::string getUniqueValueName(const Operator *op); 262 263 // Entry point for handling a rewrite pattern rooted at `resultTree` and 264 // dispatches to concrete handlers. The given tree is the `resultIndex`-th 265 // argument of the enclosing DAG. 266 std::string handleRewritePattern(DagNode resultTree, int resultIndex, 267 int depth); 268 269 // Emits the C++ statement to replace the matched DAG with a value built via 270 // calling native C++ code. 271 std::string emitReplaceWithNativeCodeCall(DagNode resultTree); 272 273 // Returns the C++ expression referencing the old value serving as the 274 // replacement. 275 std::string handleReplaceWithValue(DagNode tree); 276 277 // Emits the C++ statement to build a new op out of the given DAG `tree` and 278 // returns the variable name that this op is assigned to. If the root op in 279 // DAG `tree` has a specified name, the created op will be assigned to a 280 // variable of the given name. Otherwise, a unique name will be used as the 281 // result value name. 282 std::string emitOpCreate(DagNode tree, int resultIndex, int depth); 283 284 // Returns the C++ expression to construct a constant attribute of the given 285 // `value` for the given attribute kind `attr`. 286 std::string handleConstantAttr(Attribute attr, StringRef value); 287 288 // Returns the C++ expression to build an argument from the given DAG `leaf`. 289 // `patArgName` is used to bound the argument to the source pattern. 290 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 291 292 // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol 293 // is already bound. 294 void addSymbol(StringRef symbol, int numValues); 295 296 // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. 297 std::string resolveSymbol(StringRef symbol); 298 299 // Returns how many static values the given DAG `node` correspond to. 300 int getNodeValueCount(DagNode node); 301 302 private: 303 // Pattern instantiation location followed by the location of multiclass 304 // prototypes used. This is intended to be used as a whole to 305 // PrintFatalError() on errors. 306 ArrayRef<llvm::SMLoc> loc; 307 // Op's TableGen Record to wrapper object 308 RecordOperatorMap *opMap; 309 // Handy wrapper for pattern being emitted 310 Pattern pattern; 311 PatternSymbolResolver symbolResolver; 312 // The next unused ID for newly created values 313 unsigned nextValueId; 314 raw_ostream &os; 315 316 // Format contexts containing placeholder substitutations for match(). 317 FmtContext matchCtx; 318 // Format contexts containing placeholder substitutations for rewrite(). 319 FmtContext rewriteCtx; 320 321 // Number of op processed. 322 int opCounter = 0; 323 }; 324 } // end anonymous namespace 325 326 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 327 raw_ostream &os) 328 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 329 symbolResolver(pattern.getSourcePatternBoundArgs(), 330 pattern.getSourcePatternBoundOps()), 331 nextValueId(0), os(os) { 332 matchCtx.withBuilder("mlir::Builder(ctx)"); 333 rewriteCtx.withBuilder("rewriter"); 334 } 335 336 std::string PatternEmitter::handleConstantAttr(Attribute attr, 337 StringRef value) { 338 if (!attr.isConstBuildable()) 339 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 340 " does not have the 'constBuilderCall' field"); 341 342 // TODO(jpienaar): Verify the constants here 343 return tgfmt(attr.getConstBuilderTemplate(), 344 &rewriteCtx.withBuilder("rewriter"), value); 345 } 346 347 // Helper function to match patterns. 348 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 349 Operator &op = tree.getDialectOp(opMap); 350 if (op.isVariadic()) { 351 PrintFatalError(loc, formatv("matching op '{0}' with variadic " 352 "operands/results is unsupported right now", 353 op.getOperationName())); 354 } 355 356 int indent = 4 + 2 * depth; 357 // Skip the operand matching at depth 0 as the pattern rewriter already does. 358 if (depth != 0) { 359 // Skip if there is no defining operation (e.g., arguments to function). 360 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 361 os.indent(indent) << formatv( 362 "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, 363 op.getQualCppClassName()); 364 } 365 if (tree.getNumArgs() != op.getNumArgs()) { 366 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 367 "pattern vs. {2} in definition", 368 op.getOperationName(), tree.getNumArgs(), 369 op.getNumArgs())); 370 } 371 372 // If the operand's name is set, set to that variable. 373 auto name = tree.getSymbol(); 374 if (!name.empty()) 375 os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth); 376 377 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 378 auto opArg = op.getArg(i); 379 380 // Handle nested DAG construct first 381 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 382 os.indent(indent) << "{\n"; 383 os.indent(indent + 2) 384 << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", 385 depth + 1, depth, i); 386 emitOpMatch(argTree, depth + 1); 387 os.indent(indent + 2) 388 << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n", 389 ++opCounter, depth + 1); 390 os.indent(indent) << "}\n"; 391 continue; 392 } 393 394 // Next handle DAG leaf: operand or attribute 395 if (opArg.is<NamedTypeConstraint *>()) { 396 emitOperandMatch(tree, i, depth, indent); 397 } else if (opArg.is<NamedAttribute *>()) { 398 emitAttributeMatch(tree, i, depth, indent); 399 } else { 400 PrintFatalError(loc, "unhandled case when matching op"); 401 } 402 } 403 } 404 405 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 406 int indent) { 407 Operator &op = tree.getDialectOp(opMap); 408 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 409 auto matcher = tree.getArgAsLeaf(index); 410 411 // If a constraint is specified, we need to generate C++ statements to 412 // check the constraint. 413 if (!matcher.isUnspecified()) { 414 if (!matcher.isOperandMatcher()) { 415 PrintFatalError( 416 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 417 op.getOperationName(), index + 1)); 418 } 419 420 // Only need to verify if the matcher's type is different from the one 421 // of op definition. 422 if (operand->constraint != matcher.getAsConstraint()) { 423 auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); 424 os.indent(indent) << "if (!(" 425 << tgfmt(matcher.getConditionTemplate(), 426 &matchCtx.withSelf(self)) 427 << ")) return matchFailure();\n"; 428 } 429 } 430 431 // Capture the value 432 auto name = tree.getArgName(index); 433 if (!name.empty()) { 434 os.indent(indent) << getBoundSymbol(name) << " = op" << depth 435 << "->getOperand(" << index << ");\n"; 436 } 437 } 438 439 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 440 int indent) { 441 Operator &op = tree.getDialectOp(opMap); 442 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 443 const auto &attr = namedAttr->attr; 444 445 os.indent(indent) << "{\n"; 446 indent += 2; 447 os.indent(indent) << formatv( 448 "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 449 attr.getStorageType(), namedAttr->name); 450 451 // TODO(antiagainst): This should use getter method to avoid duplication. 452 if (attr.hasDefaultValueInitializer()) { 453 os.indent(indent) << "if (!attr) attr = " 454 << tgfmt(attr.getConstBuilderTemplate(), &matchCtx, 455 attr.getDefaultValueInitializer()) 456 << ";\n"; 457 } else if (attr.isOptional()) { 458 // For a missing attribut that is optional according to definition, we 459 // should just capature a mlir::Attribute() to signal the missing state. 460 // That is precisely what getAttr() returns on missing attributes. 461 } else { 462 os.indent(indent) << "if (!attr) return matchFailure();\n"; 463 } 464 465 auto matcher = tree.getArgAsLeaf(index); 466 if (!matcher.isUnspecified()) { 467 if (!matcher.isAttrMatcher()) { 468 PrintFatalError( 469 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 470 op.getOperationName(), index + 1)); 471 } 472 473 // If a constraint is specified, we need to generate C++ statements to 474 // check the constraint. 475 os.indent(indent) << "if (!(" 476 << tgfmt(matcher.getConditionTemplate(), 477 &matchCtx.withSelf("attr")) 478 << ")) return matchFailure();\n"; 479 } 480 481 // Capture the value 482 auto name = tree.getArgName(index); 483 if (!name.empty()) { 484 os.indent(indent) << getBoundSymbol(name) << " = attr;\n"; 485 } 486 487 indent -= 2; 488 os.indent(indent) << "}\n"; 489 } 490 491 void PatternEmitter::emitMatchMethod(DagNode tree) { 492 // Emit the heading. 493 os << R"( 494 PatternMatchResult match(Operation *op0) const override { 495 auto ctx = op0->getContext(); (void)ctx; 496 auto state = llvm::make_unique<MatchedState>(); 497 auto &s = *state; 498 s.autogeneratedRewritePatternOps[0] = op0; 499 )"; 500 501 emitOpMatch(tree, 0); 502 503 for (auto &appliedConstraint : pattern.getConstraints()) { 504 auto &constraint = appliedConstraint.constraint; 505 auto &entities = appliedConstraint.entities; 506 507 auto condition = constraint.getConditionTemplate(); 508 auto cmd = "if (!({0})) return matchFailure();\n"; 509 510 if (isa<TypeConstraint>(constraint)) { 511 auto self = formatv("({0}->getType())", resolveSymbol(entities.front())); 512 os.indent(4) << formatv(cmd, 513 tgfmt(condition, &matchCtx.withSelf(self.str()))); 514 } else if (isa<AttrConstraint>(constraint)) { 515 PrintFatalError( 516 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 517 } else { 518 // TODO(b/138794486): replace formatv arguments with the exact specified 519 // args. 520 if (entities.size() > 4) { 521 PrintFatalError(loc, "only support up to 4-entity constraints now"); 522 } 523 SmallVector<std::string, 4> names; 524 int i = 0; 525 for (int e = entities.size(); i < e; ++i) 526 names.push_back(resolveSymbol(entities[i])); 527 std::string self = appliedConstraint.self; 528 if (!self.empty()) 529 self = resolveSymbol(self); 530 for (; i < 4; ++i) 531 names.push_back("<unused>"); 532 os.indent(4) << formatv(cmd, 533 tgfmt(condition, &matchCtx.withSelf(self), 534 names[0], names[1], names[2], names[3])); 535 } 536 } 537 538 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 539 } 540 541 void PatternEmitter::collectOps(DagNode tree, 542 llvm::SmallPtrSetImpl<const Operator *> &ops) { 543 // Check if this tree is an operation. 544 if (tree.isOperation()) 545 ops.insert(&tree.getDialectOp(opMap)); 546 547 // Recurse the arguments of the tree. 548 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 549 if (auto child = tree.getArgAsNestedDag(i)) 550 collectOps(child, ops); 551 } 552 553 void PatternEmitter::emit(StringRef rewriteName) { 554 // Get the DAG tree for the source pattern 555 DagNode tree = pattern.getSourcePattern(); 556 557 const Operator &rootOp = pattern.getSourceRootOp(); 558 auto rootName = rootOp.getOperationName(); 559 560 // Collect the set of result operations. 561 llvm::SmallPtrSet<const Operator *, 4> results; 562 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) 563 collectOps(pattern.getResultPattern(i), results); 564 565 // Emit RewritePattern for Pattern. 566 auto locs = pattern.getLocation(); 567 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 568 make_range(locs.rbegin(), locs.rend())); 569 os << formatv(R"(struct {0} : public RewritePattern { 570 {0}(MLIRContext *context) : RewritePattern("{1}", {{)", 571 rewriteName, rootName); 572 interleaveComma(results, os, [&](const Operator *op) { 573 os << '"' << op->getOperationName() << '"'; 574 }); 575 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 576 577 // Emit matched state. 578 os << " struct MatchedState : public PatternState {\n"; 579 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 580 auto fieldName = arg.first(); 581 if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { 582 os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName 583 << ";\n"; 584 } else { 585 os.indent(4) << "Value *" << fieldName << ";\n"; 586 } 587 } 588 for (const auto &result : pattern.getSourcePatternBoundOps()) { 589 os.indent(4) << "Operation *" << result.getKey() << ";\n"; 590 } 591 // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent 592 // numbering so that it can be reused for fused loc. 593 os.indent(4) << "Operation* autogeneratedRewritePatternOps[" 594 << pattern.getSourcePattern().getNumOps() << "];\n"; 595 os << " };\n"; 596 597 emitMatchMethod(tree); 598 emitRewriteMethod(); 599 600 os << "};\n"; 601 } 602 603 void PatternEmitter::emitRewriteMethod() { 604 const Operator &rootOp = pattern.getSourceRootOp(); 605 int numExpectedResults = rootOp.getNumResults(); 606 int numResultPatterns = pattern.getNumResultPatterns(); 607 608 // First register all symbols bound to ops generated in result patterns. 609 for (const auto &boundOp : pattern.getResultPatternBoundOps()) { 610 addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults()); 611 } 612 613 // Only the last N static values generated are used to replace the matched 614 // root N-result op. We need to calculate the starting index (of the results 615 // of the matched op) each result pattern is to replace. 616 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 617 int replStartIndex = -1; 618 for (int i = numResultPatterns - 1; i >= 0; --i) { 619 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 620 offsets[i] = offsets[i + 1] - numValues; 621 if (offsets[i] == 0) { 622 replStartIndex = i; 623 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 624 auto error = formatv( 625 "cannot use the same multi-result op '{0}' to generate both " 626 "auxiliary values and values to be used for replacing the matched op", 627 pattern.getResultPattern(i).getSymbol()); 628 PrintFatalError(loc, error); 629 } 630 } 631 632 if (offsets.front() > 0) { 633 const char error[] = "no enough values generated to replace the matched op"; 634 PrintFatalError(loc, error); 635 } 636 637 os << R"( 638 void rewrite(Operation *op, std::unique_ptr<PatternState> state, 639 PatternRewriter &rewriter) const override { 640 auto& s = *static_cast<MatchedState *>(state.get()); 641 auto loc = rewriter.getFusedLoc({)"; 642 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 643 os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i 644 << "]->getLoc()"; 645 } 646 os << "}); (void)loc;\n"; 647 648 // Collect the replacement value for each result 649 llvm::SmallVector<std::string, 2> resultValues; 650 for (int i = 0; i < numResultPatterns; ++i) { 651 DagNode resultTree = pattern.getResultPattern(i); 652 resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0)); 653 } 654 655 // Emit the final replaceOp() statement 656 os.indent(4) << "rewriter.replaceOp(op, {"; 657 interleave( 658 ArrayRef<std::string>(resultValues).drop_front(replStartIndex), 659 [&](const std::string &name) { os << name; }, [&]() { os << ", "; }); 660 os << "});\n }\n"; 661 } 662 663 std::string PatternEmitter::getUniqueValueName(const Operator *op) { 664 return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); 665 } 666 667 std::string PatternEmitter::handleRewritePattern(DagNode resultTree, 668 int resultIndex, int depth) { 669 if (resultTree.isNativeCodeCall()) 670 return emitReplaceWithNativeCodeCall(resultTree); 671 672 if (resultTree.isReplaceWithValue()) 673 return handleReplaceWithValue(resultTree); 674 675 // Create the op and get the local variable for it. 676 auto results = emitOpCreate(resultTree, resultIndex, depth); 677 // We need to get all the values out of this local variable if we've created a 678 // multi-result op. 679 const auto &numResults = pattern.getDialectOp(resultTree).getNumResults(); 680 return formatValuePack("{0}.getOperation()->getResult({1})", results, 681 numResults, /*offset=*/0); 682 } 683 684 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 685 assert(tree.isReplaceWithValue()); 686 687 if (tree.getNumArgs() != 1) { 688 PrintFatalError( 689 loc, "replaceWithValue directive must take exactly one argument"); 690 } 691 692 if (!tree.getSymbol().empty()) { 693 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 694 } 695 696 return resolveSymbol(tree.getArgName(0)); 697 } 698 699 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { 700 if (leaf.isConstantAttr()) { 701 auto constAttr = leaf.getAsConstantAttr(); 702 return handleConstantAttr(constAttr.getAttribute(), 703 constAttr.getConstantValue()); 704 } 705 if (leaf.isEnumAttrCase()) { 706 auto enumCase = leaf.getAsEnumAttrCase(); 707 if (enumCase.isStrCase()) 708 return handleConstantAttr(enumCase, enumCase.getSymbol()); 709 // This is an enum case backed by an IntegerAttr. We need to get its value 710 // to build the constant. 711 std::string val = std::to_string(enumCase.getValue()); 712 return handleConstantAttr(enumCase, val); 713 } 714 pattern.ensureBoundInSourcePattern(argName); 715 std::string result = getBoundSymbol(argName).str(); 716 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 717 return result; 718 } 719 if (leaf.isNativeCodeCall()) { 720 return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result)); 721 } 722 PrintFatalError(loc, "unhandled case when rewriting op"); 723 } 724 725 std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { 726 auto fmt = tree.getNativeCodeTemplate(); 727 // TODO(b/138794486): replace formatv arguments with the exact specified args. 728 SmallVector<std::string, 8> attrs(8); 729 if (tree.getNumArgs() > 8) { 730 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 731 Twine(tree.getNumArgs())); 732 } 733 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 734 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 735 } 736 return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3], 737 attrs[4], attrs[5], attrs[6], attrs[7]); 738 } 739 740 void PatternEmitter::addSymbol(StringRef symbol, int numValues) { 741 if (!symbolResolver.add(symbol, numValues)) 742 PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); 743 } 744 745 std::string PatternEmitter::resolveSymbol(StringRef symbol) { 746 auto subst = symbolResolver.query(symbol); 747 if (subst.empty()) 748 PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); 749 return subst; 750 } 751 752 int PatternEmitter::getNodeValueCount(DagNode node) { 753 if (node.isOperation()) { 754 // First to see whether this op is bound and we just want a specific result 755 // of it with `__N` suffix in symbol. 756 int count = symbolResolver.getValueCount(node.getSymbol()); 757 if (count >= 0) 758 return count; 759 760 // No symbol. Then we are using all the results. 761 return pattern.getDialectOp(node).getNumResults(); 762 } 763 // TODO(antiagainst): This considers all NativeCodeCall as returning one 764 // value. Enhance if multi-value ones are needed. 765 return 1; 766 } 767 768 std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, 769 int depth) { 770 Operator &resultOp = tree.getDialectOp(opMap); 771 auto numOpArgs = resultOp.getNumArgs(); 772 773 if (resultOp.isVariadic()) { 774 PrintFatalError(loc, formatv("generating op '{0}' with variadic " 775 "operands/results is unsupported now", 776 resultOp.getOperationName())); 777 } 778 779 if (numOpArgs != tree.getNumArgs()) { 780 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 781 "{1} in pattern vs. {2} in definition", 782 resultOp.getOperationName(), tree.getNumArgs(), 783 numOpArgs)); 784 } 785 786 // A map to collect all nested DAG child nodes' names, with operand index as 787 // the key. This includes both bound and unbound child nodes. Bound child 788 // nodes will additionally be tracked in `symbolResolver` so they can be 789 // referenced by other patterns. Unbound child nodes will only be used once 790 // to build this op. 791 llvm::DenseMap<unsigned, std::string> childNodeNames; 792 793 // First go through all the child nodes who are nested DAG constructs to 794 // create ops for them, so that we can use the results in the current node. 795 // This happens in a recursive manner. 796 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 797 if (auto child = tree.getArgAsNestedDag(i)) { 798 childNodeNames[i] = handleRewritePattern(child, i, depth + 1); 799 } 800 } 801 802 // Use the specified name for this op if available. Generate one otherwise. 803 std::string resultValue = tree.getSymbol(); 804 if (resultValue.empty()) 805 resultValue = getUniqueValueName(&resultOp); 806 // Strip the index to get the name for the value pack. This will be used to 807 // name the local variable for the op. 808 StringRef valuePackName = getValuePackName(resultValue); 809 810 // Then we build the new op corresponding to this DAG node. 811 812 // Right now we don't have general type inference in MLIR. Except a few 813 // special cases listed below, we need to supply types for all results 814 // when building an op. 815 bool isSameOperandsAndResultType = 816 resultOp.hasTrait("SameOperandsAndResultType"); 817 bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); 818 bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); 819 bool usePartialResults = valuePackName != resultValue; 820 821 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 822 usePartialResults || depth > 0 || resultIndex < 0) { 823 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 824 valuePackName, resultOp.getQualCppClassName()); 825 } else { 826 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 827 // generated from the source pattern root op. Then we can use the source 828 // pattern's value types to determine the value type of the generated op 829 // here. 830 831 // We need to specify the types for all results. 832 auto resultTypes = 833 formatValuePack("op->getResult({1})->getType()", valuePackName, 834 resultOp.getNumResults(), resultIndex); 835 836 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 837 valuePackName, resultOp.getQualCppClassName()) 838 << (resultTypes.empty() ? "" : ", ") << resultTypes; 839 } 840 841 // Create the builder call for the result. 842 // Add operands. 843 int i = 0; 844 for (int e = resultOp.getNumOperands(); i < e; ++i) { 845 const auto &operand = resultOp.getOperand(i); 846 847 // Start each operand on its own line. 848 (os << ",\n").indent(6); 849 850 if (!operand.name.empty()) 851 os << "/*" << operand.name << "=*/"; 852 853 if (tree.isNestedDagArg(i)) { 854 os << childNodeNames[i]; 855 } else { 856 DagLeaf leaf = tree.getArgAsLeaf(i); 857 auto symbol = resolveSymbol(tree.getArgName(i)); 858 if (leaf.isNativeCodeCall()) { 859 os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol)); 860 } else { 861 os << symbol; 862 } 863 } 864 // TODO(jpienaar): verify types 865 } 866 867 // Add attributes. 868 for (int e = tree.getNumArgs(); i != e; ++i) { 869 // Start each attribute on its own line. 870 (os << ",\n").indent(6); 871 // The argument in the op definition. 872 auto opArgName = resultOp.getArgName(i); 873 if (auto subTree = tree.getArgAsNestedDag(i)) { 874 if (!subTree.isNativeCodeCall()) 875 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 876 "for creating attribute"); 877 os << formatv("/*{0}=*/{1}", opArgName, 878 emitReplaceWithNativeCodeCall(subTree)); 879 } else { 880 auto leaf = tree.getArgAsLeaf(i); 881 // The argument in the result DAG pattern. 882 auto patArgName = tree.getArgName(i); 883 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 884 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 885 auto argument = resultOp.getArg(i); 886 if (!argument.is<NamedAttribute *>()) 887 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 888 if (!patArgName.empty()) 889 os << "/*" << patArgName << "=*/"; 890 } else { 891 os << "/*" << opArgName << "=*/"; 892 } 893 os << handleOpArgument(leaf, patArgName); 894 } 895 } 896 os << "\n );\n"; 897 898 return resultValue; 899 } 900 901 void PatternEmitter::emit(StringRef rewriteName, Record *p, 902 RecordOperatorMap *mapper, raw_ostream &os) { 903 PatternEmitter(p, mapper, os).emit(rewriteName); 904 } 905 906 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 907 emitSourceFileHeader("Rewriters", os); 908 909 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 910 auto numPatterns = patterns.size(); 911 912 // We put the map here because it can be shared among multiple patterns. 913 RecordOperatorMap recordOpMap; 914 915 std::vector<std::string> rewriterNames; 916 rewriterNames.reserve(numPatterns); 917 918 std::string baseRewriterName = "GeneratedConvert"; 919 int rewriterIndex = 0; 920 921 for (Record *p : patterns) { 922 std::string name; 923 if (p->isAnonymous()) { 924 // If no name is provided, ensure unique rewriter names simply by 925 // appending unique suffix. 926 name = baseRewriterName + llvm::utostr(rewriterIndex++); 927 } else { 928 name = p->getName(); 929 } 930 PatternEmitter::emit(name, p, &recordOpMap, os); 931 rewriterNames.push_back(std::move(name)); 932 } 933 934 // Emit function to add the generated matchers to the pattern list. 935 os << "void populateWithGenerated(MLIRContext *context, " 936 << "OwningRewritePatternList *patterns) {\n"; 937 for (const auto &name : rewriterNames) { 938 os << " patterns->push_back(llvm::make_unique<" << name 939 << ">(context));\n"; 940 } 941 os << "}\n"; 942 } 943 944 static mlir::GenRegistration 945 genRewriters("gen-rewriters", "Generate pattern rewriters", 946 [](const RecordKeeper &records, raw_ostream &os) { 947 emitRewriters(records, os); 948 return false; 949 }); 950