1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FormatAdapters.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/Signals.h" 36 #include "llvm/TableGen/Error.h" 37 #include "llvm/TableGen/Main.h" 38 #include "llvm/TableGen/Record.h" 39 #include "llvm/TableGen/TableGenBackend.h" 40 41 using namespace llvm; 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 namespace llvm { 46 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 47 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 48 raw_ostream &os, StringRef style) { 49 os << v.first << ":" << v.second; 50 } 51 }; 52 } // end namespace llvm 53 54 // Gets the dynamic value pack's name by removing the index suffix from 55 // `symbol`. Returns `symbol` itself if it does not contain an index. 56 // 57 // We can use `name__<index>` to access the `<index>`-th value in the dynamic 58 // value pack bound to `name`. `name` is typically the results of an 59 // multi-result op. 60 static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) { 61 StringRef name, indexStr; 62 unsigned idx = 0; 63 std::tie(name, indexStr) = symbol.rsplit("__"); 64 if (indexStr.consumeInteger(10, idx)) { 65 // The second part is not an index. 66 return symbol; 67 } 68 if (index) 69 *index = idx; 70 return name; 71 } 72 73 // Formats all values from a dynamic value pack `symbol` according to the given 74 // `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol` 75 // and `{1}` as a placeholder for the value index, which will be offsetted by 76 // `offset`. The `symbol` value pack has a total of `count` values. 77 // 78 // This extracts one value from the pack if `symbol` contains an index, 79 // otherwise it extracts all values sequentially and returns them as a 80 // comma-separated list. 81 static std::string formatValuePack(const char *fmt, StringRef symbol, 82 unsigned count, unsigned offset) { 83 auto getNthValue = [fmt, offset](StringRef results, 84 unsigned index) -> std::string { 85 return formatv(fmt, results, index + offset); 86 }; 87 88 unsigned index = 0; 89 StringRef name = getValuePackName(symbol, &index); 90 if (name != symbol) { 91 // The symbol contains an index. 92 return getNthValue(name, index); 93 } 94 95 // The symbol does not contain an index. Treat the symbol as a whole. 96 SmallVector<std::string, 4> values; 97 values.reserve(count); 98 for (unsigned i = 0; i < count; ++i) 99 values.emplace_back(getNthValue(symbol, i)); 100 return llvm::join(values, ", "); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // PatternSymbolResolver 105 //===----------------------------------------------------------------------===// 106 107 namespace { 108 // A class for resolving symbols bound in patterns. 109 // 110 // Symbols can be bound to op arguments and ops in the source pattern and ops 111 // in result patterns. For example, in 112 // 113 // ``` 114 // def : Pattern<(SrcOp:$op1 $arg0, %arg1), 115 // [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; 116 // ``` 117 // 118 // `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. 119 // `$op2` is bound to `ResOp1`. 120 // 121 // If a symbol binds to a multi-result op and it does not have the `__N` 122 // suffix, the symbol is expanded to the whole value pack generated by the 123 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 124 // only the N-th result. 125 // 126 // This class keeps track of such symbols and translates them into their bound 127 // values. 128 // 129 // Note that we also generate local variables for unnamed DAG nodes, like 130 // `(ResOp3)` in the above. Since we don't bind a symbol to the op, the 131 // generated local variable will be implicitly named. Those implicit names are 132 // not tracked in this class. 133 class PatternSymbolResolver { 134 public: 135 PatternSymbolResolver(const StringMap<Argument> &srcArgs, 136 const StringMap<const Operator *> &srcOperations); 137 138 // Marks the given `symbol` as bound to a value pack with `numValues` and 139 // returns true on success. Returns false if the `symbol` is already bound. 140 bool add(StringRef symbol, int numValues); 141 142 // Queries the substitution for the given `symbol`. Returns empty string if 143 // symbol not found. If the symbol represents a value pack, returns all the 144 // values separated via comma. 145 std::string query(StringRef symbol) const; 146 147 // Returns how many static values the given `symbol` correspond to. Returns a 148 // negative value if the given symbol is not bound. 149 // 150 // Normally a symbol would correspond to just one value; for symbols bound to 151 // multi-result ops, it can be more than one. 152 int getValueCount(StringRef symbol) const; 153 154 private: 155 // Symbols bound to arguments in source pattern. 156 const StringMap<Argument> &sourceArguments; 157 // Symbols bound to ops (for their results) in source pattern. 158 const StringMap<const Operator *> &sourceOps; 159 // Symbols bound to ops (for their results) in result patterns. 160 // Key: symbol; value: number of values inside the pack 161 StringMap<int> resultOps; 162 }; 163 } // end anonymous namespace 164 165 PatternSymbolResolver::PatternSymbolResolver( 166 const StringMap<Argument> &srcArgs, 167 const StringMap<const Operator *> &srcOperations) 168 : sourceArguments(srcArgs), sourceOps(srcOperations) {} 169 170 bool PatternSymbolResolver::add(StringRef symbol, int numValues) { 171 StringRef name = getValuePackName(symbol); 172 return resultOps.try_emplace(name, numValues).second; 173 } 174 175 std::string PatternSymbolResolver::query(StringRef symbol) const { 176 StringRef name = getValuePackName(symbol); 177 // Handle symbols bound to generated ops 178 auto resOpIt = resultOps.find(name); 179 if (resOpIt != resultOps.end()) 180 return formatValuePack("{0}.getOperation()->getResult({1})", symbol, 181 resOpIt->second, /*offset=*/0); 182 183 // Handle symbols bound to matched op arguments 184 auto srcArgIt = sourceArguments.find(symbol); 185 if (srcArgIt != sourceArguments.end()) 186 return symbol; 187 188 // Handle symbols bound to matched op results 189 auto srcOpIt = sourceOps.find(name); 190 if (srcOpIt != sourceOps.end()) 191 return formatValuePack("{0}->getResult({1})", symbol, 192 srcOpIt->second->getNumResults(), /*offset=*/0); 193 return {}; 194 } 195 196 int PatternSymbolResolver::getValueCount(StringRef symbol) const { 197 StringRef name = getValuePackName(symbol); 198 // Handle symbols bound to generated ops 199 auto resOpIt = resultOps.find(name); 200 if (resOpIt != resultOps.end()) 201 return name == symbol ? resOpIt->second : 1; 202 203 // Handle symbols bound to matched op arguments 204 if (sourceArguments.count(symbol)) 205 return 1; 206 207 // Handle symbols bound to matched op results 208 auto srcOpIt = sourceOps.find(name); 209 if (srcOpIt != sourceOps.end()) 210 return name == symbol ? srcOpIt->second->getNumResults() : 1; 211 return -1; 212 } 213 214 //===----------------------------------------------------------------------===// 215 // PatternEmitter 216 //===----------------------------------------------------------------------===// 217 218 namespace { 219 class PatternEmitter { 220 public: 221 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 222 223 // Emits the mlir::RewritePattern struct named `rewriteName`. 224 void emit(StringRef rewriteName); 225 226 private: 227 // Emits the code for matching ops. 228 void emitMatchLogic(DagNode tree); 229 230 // Emits the code for rewriting ops. 231 void emitRewriteLogic(); 232 233 //===--------------------------------------------------------------------===// 234 // Match utilities 235 //===--------------------------------------------------------------------===// 236 237 // Emits C++ statements for matching the op constrained by the given DAG 238 // `tree`. 239 void emitOpMatch(DagNode tree, int depth); 240 241 // Emits C++ statements for matching the `index`-th argument of the given DAG 242 // `tree` as an operand. 243 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 244 245 // Emits C++ statements for matching the `index`-th argument of the given DAG 246 // `tree` as an attribute. 247 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 248 249 //===--------------------------------------------------------------------===// 250 // Rewrite utilities 251 //===--------------------------------------------------------------------===// 252 253 // Entry point for handling a result pattern rooted at `resultTree` and 254 // dispatches to concrete handlers. The given tree is the `resultIndex`-th 255 // argument of the enclosing DAG. 256 std::string handleResultPattern(DagNode resultTree, int resultIndex, 257 int depth); 258 259 // Emits the C++ statement to replace the matched DAG with a value built via 260 // calling native C++ code. 261 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 262 263 // Returns the C++ expression referencing the old value serving as the 264 // replacement. 265 std::string handleReplaceWithValue(DagNode tree); 266 267 // Emits the C++ statement to build a new op out of the given DAG `tree` and 268 // returns the variable name that this op is assigned to. If the root op in 269 // DAG `tree` has a specified name, the created op will be assigned to a 270 // variable of the given name. Otherwise, a unique name will be used as the 271 // result value name. 272 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 273 274 // Returns the C++ expression to construct a constant attribute of the given 275 // `value` for the given attribute kind `attr`. 276 std::string handleConstantAttr(Attribute attr, StringRef value); 277 278 // Returns the C++ expression to build an argument from the given DAG `leaf`. 279 // `patArgName` is used to bound the argument to the source pattern. 280 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 281 282 //===--------------------------------------------------------------------===// 283 // General utilities 284 //===--------------------------------------------------------------------===// 285 286 // Collects all of the operations within the given dag tree. 287 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 288 289 // Returns a unique name for a value of the given `op`. 290 std::string getUniqueValueName(const Operator *op); 291 292 //===--------------------------------------------------------------------===// 293 // Symbol utilities 294 //===--------------------------------------------------------------------===// 295 296 // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol 297 // is already bound. 298 void addSymbol(StringRef symbol, int numValues); 299 300 // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. 301 std::string resolveSymbol(StringRef symbol); 302 303 // Returns how many static values the given DAG `node` correspond to. 304 int getNodeValueCount(DagNode node); 305 306 private: 307 // Pattern instantiation location followed by the location of multiclass 308 // prototypes used. This is intended to be used as a whole to 309 // PrintFatalError() on errors. 310 ArrayRef<llvm::SMLoc> loc; 311 // Op's TableGen Record to wrapper object 312 RecordOperatorMap *opMap; 313 // Handy wrapper for pattern being emitted 314 Pattern pattern; 315 PatternSymbolResolver symbolResolver; 316 // The next unused ID for newly created values 317 unsigned nextValueId; 318 raw_ostream &os; 319 320 // Format contexts containing placeholder substitutations. 321 FmtContext fmtCtx; 322 323 // Number of op processed. 324 int opCounter = 0; 325 }; 326 } // end anonymous namespace 327 328 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 329 raw_ostream &os) 330 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 331 symbolResolver(pattern.getSourcePatternBoundArgs(), 332 pattern.getSourcePatternBoundOps()), 333 nextValueId(0), os(os) { 334 fmtCtx.withBuilder("rewriter"); 335 } 336 337 std::string PatternEmitter::handleConstantAttr(Attribute attr, 338 StringRef value) { 339 if (!attr.isConstBuildable()) 340 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 341 " does not have the 'constBuilderCall' field"); 342 343 // TODO(jpienaar): Verify the constants here 344 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 345 } 346 347 // Helper function to match patterns. 348 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 349 Operator &op = tree.getDialectOp(opMap); 350 if (op.isVariadic()) { 351 PrintFatalError(loc, formatv("matching op '{0}' with variadic " 352 "operands/results is unsupported right now", 353 op.getOperationName())); 354 } 355 356 int indent = 4 + 2 * depth; 357 // Skip the operand matching at depth 0 as the pattern rewriter already does. 358 if (depth != 0) { 359 // Skip if there is no defining operation (e.g., arguments to function). 360 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 361 os.indent(indent) << formatv( 362 "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, 363 op.getQualCppClassName()); 364 } 365 if (tree.getNumArgs() != op.getNumArgs()) { 366 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 367 "pattern vs. {2} in definition", 368 op.getOperationName(), tree.getNumArgs(), 369 op.getNumArgs())); 370 } 371 372 // If the operand's name is set, set to that variable. 373 auto name = tree.getSymbol(); 374 if (!name.empty()) 375 os.indent(indent) << formatv("{0} = op{1};\n", name, depth); 376 377 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 378 auto opArg = op.getArg(i); 379 380 // Handle nested DAG construct first 381 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 382 os.indent(indent) << "{\n"; 383 os.indent(indent + 2) 384 << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", 385 depth + 1, depth, i); 386 emitOpMatch(argTree, depth + 1); 387 os.indent(indent + 2) 388 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 389 os.indent(indent) << "}\n"; 390 continue; 391 } 392 393 // Next handle DAG leaf: operand or attribute 394 if (opArg.is<NamedTypeConstraint *>()) { 395 emitOperandMatch(tree, i, depth, indent); 396 } else if (opArg.is<NamedAttribute *>()) { 397 emitAttributeMatch(tree, i, depth, indent); 398 } else { 399 PrintFatalError(loc, "unhandled case when matching op"); 400 } 401 } 402 } 403 404 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 405 int indent) { 406 Operator &op = tree.getDialectOp(opMap); 407 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 408 auto matcher = tree.getArgAsLeaf(index); 409 410 // If a constraint is specified, we need to generate C++ statements to 411 // check the constraint. 412 if (!matcher.isUnspecified()) { 413 if (!matcher.isOperandMatcher()) { 414 PrintFatalError( 415 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 416 op.getOperationName(), index + 1)); 417 } 418 419 // Only need to verify if the matcher's type is different from the one 420 // of op definition. 421 if (operand->constraint != matcher.getAsConstraint()) { 422 auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); 423 os.indent(indent) << "if (!(" 424 << tgfmt(matcher.getConditionTemplate(), 425 &fmtCtx.withSelf(self)) 426 << ")) return matchFailure();\n"; 427 } 428 } 429 430 // Capture the value 431 auto name = tree.getArgName(index); 432 if (!name.empty()) { 433 os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth, 434 index); 435 } 436 } 437 438 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 439 int indent) { 440 Operator &op = tree.getDialectOp(opMap); 441 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 442 const auto &attr = namedAttr->attr; 443 444 os.indent(indent) << "{\n"; 445 indent += 2; 446 os.indent(indent) << formatv( 447 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 448 attr.getStorageType(), namedAttr->name); 449 450 // TODO(antiagainst): This should use getter method to avoid duplication. 451 if (attr.hasDefaultValueInitializer()) { 452 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 453 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 454 attr.getDefaultValueInitializer()) 455 << ";\n"; 456 } else if (attr.isOptional()) { 457 // For a missing attribute that is optional according to definition, we 458 // should just capature a mlir::Attribute() to signal the missing state. 459 // That is precisely what getAttr() returns on missing attributes. 460 } else { 461 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 462 } 463 464 auto matcher = tree.getArgAsLeaf(index); 465 if (!matcher.isUnspecified()) { 466 if (!matcher.isAttrMatcher()) { 467 PrintFatalError( 468 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 469 op.getOperationName(), index + 1)); 470 } 471 472 // If a constraint is specified, we need to generate C++ statements to 473 // check the constraint. 474 os.indent(indent) << "if (!(" 475 << tgfmt(matcher.getConditionTemplate(), 476 &fmtCtx.withSelf("tblgen_attr")) 477 << ")) return matchFailure();\n"; 478 } 479 480 // Capture the value 481 auto name = tree.getArgName(index); 482 if (!name.empty()) { 483 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 484 } 485 486 indent -= 2; 487 os.indent(indent) << "}\n"; 488 } 489 490 void PatternEmitter::emitMatchLogic(DagNode tree) { 491 emitOpMatch(tree, 0); 492 493 for (auto &appliedConstraint : pattern.getConstraints()) { 494 auto &constraint = appliedConstraint.constraint; 495 auto &entities = appliedConstraint.entities; 496 497 auto condition = constraint.getConditionTemplate(); 498 auto cmd = "if (!({0})) return matchFailure();\n"; 499 500 if (isa<TypeConstraint>(constraint)) { 501 auto self = formatv("({0}->getType())", resolveSymbol(entities.front())); 502 os.indent(4) << formatv(cmd, 503 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 504 } else if (isa<AttrConstraint>(constraint)) { 505 PrintFatalError( 506 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 507 } else { 508 // TODO(b/138794486): replace formatv arguments with the exact specified 509 // args. 510 if (entities.size() > 4) { 511 PrintFatalError(loc, "only support up to 4-entity constraints now"); 512 } 513 SmallVector<std::string, 4> names; 514 int i = 0; 515 for (int e = entities.size(); i < e; ++i) 516 names.push_back(resolveSymbol(entities[i])); 517 std::string self = appliedConstraint.self; 518 if (!self.empty()) 519 self = resolveSymbol(self); 520 for (; i < 4; ++i) 521 names.push_back("<unused>"); 522 os.indent(4) << formatv(cmd, 523 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 524 names[1], names[2], names[3])); 525 } 526 } 527 } 528 529 void PatternEmitter::collectOps(DagNode tree, 530 llvm::SmallPtrSetImpl<const Operator *> &ops) { 531 // Check if this tree is an operation. 532 if (tree.isOperation()) 533 ops.insert(&tree.getDialectOp(opMap)); 534 535 // Recurse the arguments of the tree. 536 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 537 if (auto child = tree.getArgAsNestedDag(i)) 538 collectOps(child, ops); 539 } 540 541 void PatternEmitter::emit(StringRef rewriteName) { 542 // Get the DAG tree for the source pattern. 543 DagNode sourceTree = pattern.getSourcePattern(); 544 545 const Operator &rootOp = pattern.getSourceRootOp(); 546 auto rootName = rootOp.getOperationName(); 547 548 // Collect the set of result operations. 549 llvm::SmallPtrSet<const Operator *, 4> resultOps; 550 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) 551 collectOps(pattern.getResultPattern(i), resultOps); 552 553 // Emit RewritePattern for Pattern. 554 auto locs = pattern.getLocation(); 555 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 556 make_range(locs.rbegin(), locs.rend())); 557 os << formatv(R"(struct {0} : public RewritePattern { 558 {0}(MLIRContext *context) 559 : RewritePattern("{1}", {{)", 560 rewriteName, rootName); 561 interleaveComma(resultOps, os, [&](const Operator *op) { 562 os << '"' << op->getOperationName() << '"'; 563 }); 564 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 565 566 // Emit matchAndRewrite() function. 567 os << R"( 568 PatternMatchResult matchAndRewrite(Operation *op0, 569 PatternRewriter &rewriter) const override { 570 )"; 571 572 os.indent(4) << "// Variables for capturing values and attributes used for " 573 "creating ops\n"; 574 // Create local variables for storing the arguments bound to symbols. 575 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 576 auto fieldName = arg.first(); 577 if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { 578 os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(), 579 fieldName); 580 } else { 581 os.indent(4) << "Value *" << fieldName << ";\n"; 582 } 583 } 584 // Create local variables for storing the ops bound to symbols. 585 for (const auto &result : pattern.getSourcePatternBoundOps()) { 586 os.indent(4) << formatv("Operation *{0};\n", result.getKey()); 587 } 588 // TODO(jpienaar): capture ops with consistent numbering so that it can be 589 // reused for fused loc. 590 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 591 pattern.getSourcePattern().getNumOps()); 592 593 os.indent(4) << "// Match\n"; 594 os.indent(4) << "tblgen_ops[0] = op0;\n"; 595 emitMatchLogic(sourceTree); 596 os << "\n"; 597 598 os.indent(4) << "// Rewrite\n"; 599 emitRewriteLogic(); 600 601 os.indent(4) << "return matchSuccess();\n"; 602 os << " };\n"; 603 os << "};\n"; 604 } 605 606 void PatternEmitter::emitRewriteLogic() { 607 const Operator &rootOp = pattern.getSourceRootOp(); 608 int numExpectedResults = rootOp.getNumResults(); 609 int numResultPatterns = pattern.getNumResultPatterns(); 610 611 // First register all symbols bound to ops generated in result patterns. 612 for (const auto &boundOp : pattern.getResultPatternBoundOps()) { 613 addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults()); 614 } 615 616 // Only the last N static values generated are used to replace the matched 617 // root N-result op. We need to calculate the starting index (of the results 618 // of the matched op) each result pattern is to replace. 619 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 620 int replStartIndex = -1; 621 for (int i = numResultPatterns - 1; i >= 0; --i) { 622 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 623 offsets[i] = offsets[i + 1] - numValues; 624 if (offsets[i] == 0) { 625 replStartIndex = i; 626 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 627 auto error = formatv( 628 "cannot use the same multi-result op '{0}' to generate both " 629 "auxiliary values and values to be used for replacing the matched op", 630 pattern.getResultPattern(i).getSymbol()); 631 PrintFatalError(loc, error); 632 } 633 } 634 635 if (offsets.front() > 0) { 636 const char error[] = "no enough values generated to replace the matched op"; 637 PrintFatalError(loc, error); 638 } 639 640 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 641 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 642 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 643 } 644 os << "}); (void)loc;\n"; 645 646 // Collect the replacement value for each result 647 llvm::SmallVector<std::string, 2> resultValues; 648 for (int i = 0; i < numResultPatterns; ++i) { 649 DagNode resultTree = pattern.getResultPattern(i); 650 resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0)); 651 } 652 653 // Emit the final replaceOp() statement 654 os.indent(4) << "rewriter.replaceOp(op0, {"; 655 interleave( 656 ArrayRef<std::string>(resultValues).drop_front(replStartIndex), 657 [&](const std::string &name) { os << name; }, [&]() { os << ", "; }); 658 os << "});\n"; 659 } 660 661 std::string PatternEmitter::getUniqueValueName(const Operator *op) { 662 return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); 663 } 664 665 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 666 int resultIndex, int depth) { 667 if (resultTree.isNativeCodeCall()) 668 return handleReplaceWithNativeCodeCall(resultTree); 669 670 if (resultTree.isReplaceWithValue()) 671 return handleReplaceWithValue(resultTree); 672 673 // Create the op and get the local variable for it. 674 auto results = handleOpCreation(resultTree, resultIndex, depth); 675 // We need to get all the values out of this local variable if we've created a 676 // multi-result op. 677 const auto &numResults = pattern.getDialectOp(resultTree).getNumResults(); 678 return formatValuePack("{0}.getOperation()->getResult({1})", results, 679 numResults, /*offset=*/0); 680 } 681 682 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 683 assert(tree.isReplaceWithValue()); 684 685 if (tree.getNumArgs() != 1) { 686 PrintFatalError( 687 loc, "replaceWithValue directive must take exactly one argument"); 688 } 689 690 if (!tree.getSymbol().empty()) { 691 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 692 } 693 694 return resolveSymbol(tree.getArgName(0)); 695 } 696 697 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { 698 if (leaf.isConstantAttr()) { 699 auto constAttr = leaf.getAsConstantAttr(); 700 return handleConstantAttr(constAttr.getAttribute(), 701 constAttr.getConstantValue()); 702 } 703 if (leaf.isEnumAttrCase()) { 704 auto enumCase = leaf.getAsEnumAttrCase(); 705 if (enumCase.isStrCase()) 706 return handleConstantAttr(enumCase, enumCase.getSymbol()); 707 // This is an enum case backed by an IntegerAttr. We need to get its value 708 // to build the constant. 709 std::string val = std::to_string(enumCase.getValue()); 710 return handleConstantAttr(enumCase, val); 711 } 712 pattern.ensureBoundInSourcePattern(argName); 713 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 714 return argName; 715 } 716 if (leaf.isNativeCodeCall()) { 717 return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 718 } 719 PrintFatalError(loc, "unhandled case when rewriting op"); 720 } 721 722 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 723 auto fmt = tree.getNativeCodeTemplate(); 724 // TODO(b/138794486): replace formatv arguments with the exact specified args. 725 SmallVector<std::string, 8> attrs(8); 726 if (tree.getNumArgs() > 8) { 727 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 728 Twine(tree.getNumArgs())); 729 } 730 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 731 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 732 } 733 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 734 attrs[5], attrs[6], attrs[7]); 735 } 736 737 void PatternEmitter::addSymbol(StringRef symbol, int numValues) { 738 if (!symbolResolver.add(symbol, numValues)) 739 PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); 740 } 741 742 std::string PatternEmitter::resolveSymbol(StringRef symbol) { 743 auto subst = symbolResolver.query(symbol); 744 if (subst.empty()) 745 PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); 746 return subst; 747 } 748 749 int PatternEmitter::getNodeValueCount(DagNode node) { 750 if (node.isOperation()) { 751 // First to see whether this op is bound and we just want a specific result 752 // of it with `__N` suffix in symbol. 753 int count = symbolResolver.getValueCount(node.getSymbol()); 754 if (count >= 0) 755 return count; 756 757 // No symbol. Then we are using all the results. 758 return pattern.getDialectOp(node).getNumResults(); 759 } 760 // TODO(antiagainst): This considers all NativeCodeCall as returning one 761 // value. Enhance if multi-value ones are needed. 762 return 1; 763 } 764 765 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 766 int depth) { 767 Operator &resultOp = tree.getDialectOp(opMap); 768 auto numOpArgs = resultOp.getNumArgs(); 769 770 if (resultOp.isVariadic()) { 771 PrintFatalError(loc, formatv("generating op '{0}' with variadic " 772 "operands/results is unsupported now", 773 resultOp.getOperationName())); 774 } 775 776 if (numOpArgs != tree.getNumArgs()) { 777 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 778 "{1} in pattern vs. {2} in definition", 779 resultOp.getOperationName(), tree.getNumArgs(), 780 numOpArgs)); 781 } 782 783 // A map to collect all nested DAG child nodes' names, with operand index as 784 // the key. This includes both bound and unbound child nodes. Bound child 785 // nodes will additionally be tracked in `symbolResolver` so they can be 786 // referenced by other patterns. Unbound child nodes will only be used once 787 // to build this op. 788 llvm::DenseMap<unsigned, std::string> childNodeNames; 789 790 // First go through all the child nodes who are nested DAG constructs to 791 // create ops for them, so that we can use the results in the current node. 792 // This happens in a recursive manner. 793 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 794 if (auto child = tree.getArgAsNestedDag(i)) { 795 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 796 } 797 } 798 799 // Use the specified name for this op if available. Generate one otherwise. 800 std::string resultValue = tree.getSymbol(); 801 if (resultValue.empty()) 802 resultValue = getUniqueValueName(&resultOp); 803 // Strip the index to get the name for the value pack. This will be used to 804 // name the local variable for the op. 805 StringRef valuePackName = getValuePackName(resultValue); 806 807 // Then we build the new op corresponding to this DAG node. 808 809 // Right now we don't have general type inference in MLIR. Except a few 810 // special cases listed below, we need to supply types for all results 811 // when building an op. 812 bool isSameOperandsAndResultType = 813 resultOp.hasTrait("SameOperandsAndResultType"); 814 bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); 815 bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); 816 bool usePartialResults = valuePackName != resultValue; 817 818 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 819 usePartialResults || depth > 0 || resultIndex < 0) { 820 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 821 valuePackName, resultOp.getQualCppClassName()); 822 } else { 823 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 824 // generated from the source pattern root op. Then we can use the source 825 // pattern's value types to determine the value type of the generated op 826 // here. 827 828 // We need to specify the types for all results. 829 auto resultTypes = 830 formatValuePack("op0->getResult({1})->getType()", valuePackName, 831 resultOp.getNumResults(), resultIndex); 832 833 os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", 834 valuePackName, resultOp.getQualCppClassName()) 835 << (resultTypes.empty() ? "" : ", ") << resultTypes; 836 } 837 838 // Create the builder call for the result. 839 // Add operands. 840 int i = 0; 841 for (int e = resultOp.getNumOperands(); i < e; ++i) { 842 const auto &operand = resultOp.getOperand(i); 843 844 // Start each operand on its own line. 845 (os << ",\n").indent(6); 846 847 if (!operand.name.empty()) 848 os << "/*" << operand.name << "=*/"; 849 850 if (tree.isNestedDagArg(i)) { 851 os << childNodeNames[i]; 852 } else { 853 DagLeaf leaf = tree.getArgAsLeaf(i); 854 auto symbol = resolveSymbol(tree.getArgName(i)); 855 if (leaf.isNativeCodeCall()) { 856 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 857 } else { 858 os << symbol; 859 } 860 } 861 // TODO(jpienaar): verify types 862 } 863 864 // Add attributes. 865 for (int e = tree.getNumArgs(); i != e; ++i) { 866 // Start each attribute on its own line. 867 (os << ",\n").indent(6); 868 // The argument in the op definition. 869 auto opArgName = resultOp.getArgName(i); 870 if (auto subTree = tree.getArgAsNestedDag(i)) { 871 if (!subTree.isNativeCodeCall()) 872 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 873 "for creating attribute"); 874 os << formatv("/*{0}=*/{1}", opArgName, 875 handleReplaceWithNativeCodeCall(subTree)); 876 } else { 877 auto leaf = tree.getArgAsLeaf(i); 878 // The argument in the result DAG pattern. 879 auto patArgName = tree.getArgName(i); 880 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 881 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 882 auto argument = resultOp.getArg(i); 883 if (!argument.is<NamedAttribute *>()) 884 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 885 if (!patArgName.empty()) 886 os << "/*" << patArgName << "=*/"; 887 } else { 888 os << "/*" << opArgName << "=*/"; 889 } 890 os << handleOpArgument(leaf, patArgName); 891 } 892 } 893 os << "\n );\n"; 894 895 return resultValue; 896 } 897 898 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 899 emitSourceFileHeader("Rewriters", os); 900 901 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 902 auto numPatterns = patterns.size(); 903 904 // We put the map here because it can be shared among multiple patterns. 905 RecordOperatorMap recordOpMap; 906 907 std::vector<std::string> rewriterNames; 908 rewriterNames.reserve(numPatterns); 909 910 std::string baseRewriterName = "GeneratedConvert"; 911 int rewriterIndex = 0; 912 913 for (Record *p : patterns) { 914 std::string name; 915 if (p->isAnonymous()) { 916 // If no name is provided, ensure unique rewriter names simply by 917 // appending unique suffix. 918 name = baseRewriterName + llvm::utostr(rewriterIndex++); 919 } else { 920 name = p->getName(); 921 } 922 PatternEmitter(p, &recordOpMap, os).emit(name); 923 rewriterNames.push_back(std::move(name)); 924 } 925 926 // Emit function to add the generated matchers to the pattern list. 927 os << "void populateWithGenerated(MLIRContext *context, " 928 << "OwningRewritePatternList *patterns) {\n"; 929 for (const auto &name : rewriterNames) { 930 os << " patterns->insert<" << name << ">(context);\n"; 931 } 932 os << "}\n"; 933 } 934 935 static mlir::GenRegistration 936 genRewriters("gen-rewriters", "Generate pattern rewriters", 937 [](const RecordKeeper &records, raw_ostream &os) { 938 emitRewriters(records, os); 939 return false; 940 }); 941