1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/STLExtras.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Pattern.h" 19 #include "mlir/TableGen/Predicate.h" 20 #include "mlir/TableGen/Type.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::formatv; 37 using llvm::Record; 38 using llvm::RecordKeeper; 39 40 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 41 42 namespace llvm { 43 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 44 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 45 raw_ostream &os, StringRef style) { 46 os << v.first << ":" << v.second; 47 } 48 }; 49 } // end namespace llvm 50 51 //===----------------------------------------------------------------------===// 52 // PatternEmitter 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 class PatternEmitter { 57 public: 58 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 59 60 // Emits the mlir::RewritePattern struct named `rewriteName`. 61 void emit(StringRef rewriteName); 62 63 private: 64 // Emits the code for matching ops. 65 void emitMatchLogic(DagNode tree); 66 67 // Emits the code for rewriting ops. 68 void emitRewriteLogic(); 69 70 //===--------------------------------------------------------------------===// 71 // Match utilities 72 //===--------------------------------------------------------------------===// 73 74 // Emits C++ statements for matching the op constrained by the given DAG 75 // `tree`. 76 void emitOpMatch(DagNode tree, int depth); 77 78 // Emits C++ statements for matching the `argIndex`-th argument of the given 79 // DAG `tree` as an operand. 80 void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent); 81 82 // Emits C++ statements for matching the `argIndex`-th argument of the given 83 // DAG `tree` as an attribute. 84 void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent); 85 86 //===--------------------------------------------------------------------===// 87 // Rewrite utilities 88 //===--------------------------------------------------------------------===// 89 90 // The entry point for handling a result pattern rooted at `resultTree`. This 91 // method dispatches to concrete handlers according to `resultTree`'s kind and 92 // returns a symbol representing the whole value pack. Callers are expected to 93 // further resolve the symbol according to the specific use case. 94 // 95 // `depth` is the nesting level of `resultTree`; 0 means top-level result 96 // pattern. For top-level result pattern, `resultIndex` indicates which result 97 // of the matched root op this pattern is intended to replace, which can be 98 // used to deduce the result type of the op generated from this result 99 // pattern. 100 std::string handleResultPattern(DagNode resultTree, int resultIndex, 101 int depth); 102 103 // Emits the C++ statement to replace the matched DAG with a value built via 104 // calling native C++ code. 105 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 106 107 // Returns the C++ expression referencing the old value serving as the 108 // replacement. 109 std::string handleReplaceWithValue(DagNode tree); 110 111 // Emits the C++ statement to build a new op out of the given DAG `tree` and 112 // returns the variable name that this op is assigned to. If the root op in 113 // DAG `tree` has a specified name, the created op will be assigned to a 114 // variable of the given name. Otherwise, a unique name will be used as the 115 // result value name. 116 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 117 118 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 119 120 // Emits a local variable for each value and attribute to be used for creating 121 // an op. 122 void createSeparateLocalVarsForOpArgs(DagNode node, 123 ChildNodeIndexNameMap &childNodeNames); 124 125 // Emits the concrete arguments used to call a op's builder. 126 void supplyValuesForOpArgs(DagNode node, 127 const ChildNodeIndexNameMap &childNodeNames); 128 129 // Emits the local variables for holding all values as a whole and all named 130 // attributes as a whole to be used for creating an op. 131 void createAggregateLocalVarsForOpArgs( 132 DagNode node, const ChildNodeIndexNameMap &childNodeNames); 133 134 // Returns the C++ expression to construct a constant attribute of the given 135 // `value` for the given attribute kind `attr`. 136 std::string handleConstantAttr(Attribute attr, StringRef value); 137 138 // Returns the C++ expression to build an argument from the given DAG `leaf`. 139 // `patArgName` is used to bound the argument to the source pattern. 140 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 141 142 //===--------------------------------------------------------------------===// 143 // General utilities 144 //===--------------------------------------------------------------------===// 145 146 // Collects all of the operations within the given dag tree. 147 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 148 149 // Returns a unique symbol for a local variable of the given `op`. 150 std::string getUniqueSymbol(const Operator *op); 151 152 //===--------------------------------------------------------------------===// 153 // Symbol utilities 154 //===--------------------------------------------------------------------===// 155 156 // Returns how many static values the given DAG `node` correspond to. 157 int getNodeValueCount(DagNode node); 158 159 private: 160 // Pattern instantiation location followed by the location of multiclass 161 // prototypes used. This is intended to be used as a whole to 162 // PrintFatalError() on errors. 163 ArrayRef<llvm::SMLoc> loc; 164 165 // Op's TableGen Record to wrapper object. 166 RecordOperatorMap *opMap; 167 168 // Handy wrapper for pattern being emitted. 169 Pattern pattern; 170 171 // Map for all bound symbols' info. 172 SymbolInfoMap symbolInfoMap; 173 174 // The next unused ID for newly created values. 175 unsigned nextValueId; 176 177 raw_ostream &os; 178 179 // Format contexts containing placeholder substitutions. 180 FmtContext fmtCtx; 181 182 // Number of op processed. 183 int opCounter = 0; 184 }; 185 } // end anonymous namespace 186 187 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 188 raw_ostream &os) 189 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 190 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 191 fmtCtx.withBuilder("rewriter"); 192 } 193 194 std::string PatternEmitter::handleConstantAttr(Attribute attr, 195 StringRef value) { 196 if (!attr.isConstBuildable()) 197 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 198 " does not have the 'constBuilderCall' field"); 199 200 // TODO(jpienaar): Verify the constants here 201 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 202 } 203 204 // Helper function to match patterns. 205 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 206 Operator &op = tree.getDialectOp(opMap); 207 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 208 << op.getOperationName() << "' at depth " << depth 209 << '\n'); 210 211 int indent = 4 + 2 * depth; 212 os.indent(indent) << formatv( 213 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 214 depth, op.getQualCppClassName()); 215 // Skip the operand matching at depth 0 as the pattern rewriter already does. 216 if (depth != 0) { 217 // Skip if there is no defining operation (e.g., arguments to function). 218 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 219 depth); 220 } 221 if (tree.getNumArgs() != op.getNumArgs()) { 222 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 223 "pattern vs. {2} in definition", 224 op.getOperationName(), tree.getNumArgs(), 225 op.getNumArgs())); 226 } 227 228 // If the operand's name is set, set to that variable. 229 auto name = tree.getSymbol(); 230 if (!name.empty()) 231 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 232 233 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 234 auto opArg = op.getArg(i); 235 236 // Handle nested DAG construct first 237 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 238 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 239 if (operand->isVariadic()) { 240 auto error = formatv("use nested DAG construct to match op {0}'s " 241 "variadic operand #{1} unsupported now", 242 op.getOperationName(), i); 243 PrintFatalError(loc, error); 244 } 245 } 246 os.indent(indent) << "{\n"; 247 248 os.indent(indent + 2) << formatv( 249 "auto *op{0} = " 250 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 251 depth + 1, depth, i); 252 emitOpMatch(argTree, depth + 1); 253 os.indent(indent + 2) 254 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 255 os.indent(indent) << "}\n"; 256 continue; 257 } 258 259 // Next handle DAG leaf: operand or attribute 260 if (opArg.is<NamedTypeConstraint *>()) { 261 emitOperandMatch(tree, i, depth, indent); 262 } else if (opArg.is<NamedAttribute *>()) { 263 emitAttributeMatch(tree, i, depth, indent); 264 } else { 265 PrintFatalError(loc, "unhandled case when matching op"); 266 } 267 } 268 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 269 << op.getOperationName() << "' at depth " << depth 270 << '\n'); 271 } 272 273 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, 274 int indent) { 275 Operator &op = tree.getDialectOp(opMap); 276 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 277 auto matcher = tree.getArgAsLeaf(argIndex); 278 279 // If a constraint is specified, we need to generate C++ statements to 280 // check the constraint. 281 if (!matcher.isUnspecified()) { 282 if (!matcher.isOperandMatcher()) { 283 PrintFatalError( 284 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 285 op.getOperationName(), argIndex + 1)); 286 } 287 288 // Only need to verify if the matcher's type is different from the one 289 // of op definition. 290 if (operand->constraint != matcher.getAsConstraint()) { 291 if (operand->isVariadic()) { 292 auto error = formatv( 293 "further constrain op {0}'s variadic operand #{1} unsupported now", 294 op.getOperationName(), argIndex); 295 PrintFatalError(loc, error); 296 } 297 auto self = 298 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 299 depth, argIndex); 300 os.indent(indent) << "if (!(" 301 << tgfmt(matcher.getConditionTemplate(), 302 &fmtCtx.withSelf(self)) 303 << ")) return matchFailure();\n"; 304 } 305 } 306 307 // Capture the value 308 auto name = tree.getArgName(argIndex); 309 // `$_` is a special symbol to ignore op argument matching. 310 if (!name.empty() && name != "_") { 311 // We need to subtract the number of attributes before this operand to get 312 // the index in the operand list. 313 auto numPrevAttrs = std::count_if( 314 op.arg_begin(), op.arg_begin() + argIndex, 315 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 316 317 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 318 name, depth, argIndex - numPrevAttrs); 319 } 320 } 321 322 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, 323 int indent) { 324 325 Operator &op = tree.getDialectOp(opMap); 326 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 327 const auto &attr = namedAttr->attr; 328 329 os.indent(indent) << "{\n"; 330 indent += 2; 331 os.indent(indent) << formatv( 332 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 333 attr.getStorageType(), namedAttr->name); 334 335 // TODO(antiagainst): This should use getter method to avoid duplication. 336 if (attr.hasDefaultValue()) { 337 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 338 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 339 attr.getDefaultValue()) 340 << ";\n"; 341 } else if (attr.isOptional()) { 342 // For a missing attribute that is optional according to definition, we 343 // should just capture a mlir::Attribute() to signal the missing state. 344 // That is precisely what getAttr() returns on missing attributes. 345 } else { 346 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 347 } 348 349 auto matcher = tree.getArgAsLeaf(argIndex); 350 if (!matcher.isUnspecified()) { 351 if (!matcher.isAttrMatcher()) { 352 PrintFatalError( 353 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 354 op.getOperationName(), argIndex + 1)); 355 } 356 357 // If a constraint is specified, we need to generate C++ statements to 358 // check the constraint. 359 os.indent(indent) << "if (!(" 360 << tgfmt(matcher.getConditionTemplate(), 361 &fmtCtx.withSelf("tblgen_attr")) 362 << ")) return matchFailure();\n"; 363 } 364 365 // Capture the value 366 auto name = tree.getArgName(argIndex); 367 // `$_` is a special symbol to ignore op argument matching. 368 if (!name.empty() && name != "_") { 369 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 370 } 371 372 indent -= 2; 373 os.indent(indent) << "}\n"; 374 } 375 376 void PatternEmitter::emitMatchLogic(DagNode tree) { 377 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 378 emitOpMatch(tree, 0); 379 380 for (auto &appliedConstraint : pattern.getConstraints()) { 381 auto &constraint = appliedConstraint.constraint; 382 auto &entities = appliedConstraint.entities; 383 384 auto condition = constraint.getConditionTemplate(); 385 auto cmd = "if (!({0})) return matchFailure();\n"; 386 387 if (isa<TypeConstraint>(constraint)) { 388 auto self = formatv("({0}->getType())", 389 symbolInfoMap.getValueAndRangeUse(entities.front())); 390 os.indent(4) << formatv(cmd, 391 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 392 } else if (isa<AttrConstraint>(constraint)) { 393 PrintFatalError( 394 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 395 } else { 396 // TODO(b/138794486): replace formatv arguments with the exact specified 397 // args. 398 if (entities.size() > 4) { 399 PrintFatalError(loc, "only support up to 4-entity constraints now"); 400 } 401 SmallVector<std::string, 4> names; 402 int i = 0; 403 for (int e = entities.size(); i < e; ++i) 404 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 405 std::string self = appliedConstraint.self; 406 if (!self.empty()) 407 self = symbolInfoMap.getValueAndRangeUse(self); 408 for (; i < 4; ++i) 409 names.push_back("<unused>"); 410 os.indent(4) << formatv(cmd, 411 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 412 names[1], names[2], names[3])); 413 } 414 } 415 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 416 } 417 418 void PatternEmitter::collectOps(DagNode tree, 419 llvm::SmallPtrSetImpl<const Operator *> &ops) { 420 // Check if this tree is an operation. 421 if (tree.isOperation()) { 422 const Operator &op = tree.getDialectOp(opMap); 423 LLVM_DEBUG(llvm::dbgs() 424 << "found operation " << op.getOperationName() << '\n'); 425 ops.insert(&op); 426 } 427 428 // Recurse the arguments of the tree. 429 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 430 if (auto child = tree.getArgAsNestedDag(i)) 431 collectOps(child, ops); 432 } 433 434 void PatternEmitter::emit(StringRef rewriteName) { 435 // Get the DAG tree for the source pattern. 436 DagNode sourceTree = pattern.getSourcePattern(); 437 438 const Operator &rootOp = pattern.getSourceRootOp(); 439 auto rootName = rootOp.getOperationName(); 440 441 // Collect the set of result operations. 442 llvm::SmallPtrSet<const Operator *, 4> resultOps; 443 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 444 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 445 collectOps(pattern.getResultPattern(i), resultOps); 446 } 447 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 448 449 // Emit RewritePattern for Pattern. 450 auto locs = pattern.getLocation(); 451 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 452 make_range(locs.rbegin(), locs.rend())); 453 os << formatv(R"(struct {0} : public RewritePattern { 454 {0}(MLIRContext *context) 455 : RewritePattern("{1}", {{)", 456 rewriteName, rootName); 457 // Sort result operators by name. 458 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 459 resultOps.end()); 460 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 461 return lhs->getOperationName() < rhs->getOperationName(); 462 }); 463 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 464 os << '"' << op->getOperationName() << '"'; 465 }); 466 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 467 468 // Emit matchAndRewrite() function. 469 os << R"( 470 PatternMatchResult matchAndRewrite(Operation *op0, 471 PatternRewriter &rewriter) const override { 472 )"; 473 474 // Register all symbols bound in the source pattern. 475 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 476 477 LLVM_DEBUG( 478 llvm::dbgs() << "start creating local variables for capturing matches\n"); 479 os.indent(4) << "// Variables for capturing values and attributes used for " 480 "creating ops\n"; 481 // Create local variables for storing the arguments and results bound 482 // to symbols. 483 for (const auto &symbolInfoPair : symbolInfoMap) { 484 StringRef symbol = symbolInfoPair.getKey(); 485 auto &info = symbolInfoPair.getValue(); 486 os.indent(4) << info.getVarDecl(symbol); 487 } 488 // TODO(jpienaar): capture ops with consistent numbering so that it can be 489 // reused for fused loc. 490 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 491 pattern.getSourcePattern().getNumOps()); 492 LLVM_DEBUG( 493 llvm::dbgs() << "done creating local variables for capturing matches\n"); 494 495 os.indent(4) << "// Match\n"; 496 os.indent(4) << "tblgen_ops[0] = op0;\n"; 497 emitMatchLogic(sourceTree); 498 os << "\n"; 499 500 os.indent(4) << "// Rewrite\n"; 501 emitRewriteLogic(); 502 503 os.indent(4) << "return matchSuccess();\n"; 504 os << " };\n"; 505 os << "};\n"; 506 } 507 508 void PatternEmitter::emitRewriteLogic() { 509 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 510 const Operator &rootOp = pattern.getSourceRootOp(); 511 int numExpectedResults = rootOp.getNumResults(); 512 int numResultPatterns = pattern.getNumResultPatterns(); 513 514 // First register all symbols bound to ops generated in result patterns. 515 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 516 517 // Only the last N static values generated are used to replace the matched 518 // root N-result op. We need to calculate the starting index (of the results 519 // of the matched op) each result pattern is to replace. 520 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 521 // If we don't need to replace any value at all, set the replacement starting 522 // index as the number of result patterns so we skip all of them when trying 523 // to replace the matched op's results. 524 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 525 for (int i = numResultPatterns - 1; i >= 0; --i) { 526 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 527 offsets[i] = offsets[i + 1] - numValues; 528 if (offsets[i] == 0) { 529 if (replStartIndex == -1) 530 replStartIndex = i; 531 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 532 auto error = formatv( 533 "cannot use the same multi-result op '{0}' to generate both " 534 "auxiliary values and values to be used for replacing the matched op", 535 pattern.getResultPattern(i).getSymbol()); 536 PrintFatalError(loc, error); 537 } 538 } 539 540 if (offsets.front() > 0) { 541 const char error[] = "no enough values generated to replace the matched op"; 542 PrintFatalError(loc, error); 543 } 544 545 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 546 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 547 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 548 } 549 os << "}); (void)loc;\n"; 550 551 // Process auxiliary result patterns. 552 for (int i = 0; i < replStartIndex; ++i) { 553 DagNode resultTree = pattern.getResultPattern(i); 554 auto val = handleResultPattern(resultTree, offsets[i], 0); 555 // Normal op creation will be streamed to `os` by the above call; but 556 // NativeCodeCall will only be materialized to `os` if it is used. Here 557 // we are handling auxiliary patterns so we want the side effect even if 558 // NativeCodeCall is not replacing matched root op's results. 559 if (resultTree.isNativeCodeCall()) 560 os.indent(4) << val << ";\n"; 561 } 562 563 if (numExpectedResults == 0) { 564 assert(replStartIndex >= numResultPatterns && 565 "invalid auxiliary vs. replacement pattern division!"); 566 // No result to replace. Just erase the op. 567 os.indent(4) << "rewriter.eraseOp(op0);\n"; 568 } else { 569 // Process replacement result patterns. 570 os.indent(4) << "SmallVector<Value, 4> tblgen_repl_values;\n"; 571 for (int i = replStartIndex; i < numResultPatterns; ++i) { 572 DagNode resultTree = pattern.getResultPattern(i); 573 auto val = handleResultPattern(resultTree, offsets[i], 0); 574 os.indent(4) << "\n"; 575 // Resolve each symbol for all range use so that we can loop over them. 576 os << symbolInfoMap.getAllRangeUse( 577 val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }", 578 "\n"); 579 } 580 os.indent(4) << "\n"; 581 os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n"; 582 } 583 584 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 585 } 586 587 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 588 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 589 } 590 591 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 592 int resultIndex, int depth) { 593 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 594 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 595 LLVM_DEBUG(llvm::dbgs() << '\n'); 596 597 if (resultTree.isNativeCodeCall()) { 598 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 599 symbolInfoMap.bindValue(symbol); 600 return symbol; 601 } 602 603 if (resultTree.isReplaceWithValue()) { 604 return handleReplaceWithValue(resultTree); 605 } 606 607 // Normal op creation. 608 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 609 if (resultTree.getSymbol().empty()) { 610 // This is an op not explicitly bound to a symbol in the rewrite rule. 611 // Register the auto-generated symbol for it. 612 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 613 } 614 return symbol; 615 } 616 617 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 618 assert(tree.isReplaceWithValue()); 619 620 if (tree.getNumArgs() != 1) { 621 PrintFatalError( 622 loc, "replaceWithValue directive must take exactly one argument"); 623 } 624 625 if (!tree.getSymbol().empty()) { 626 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 627 } 628 629 return tree.getArgName(0); 630 } 631 632 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 633 StringRef patArgName) { 634 if (leaf.isConstantAttr()) { 635 auto constAttr = leaf.getAsConstantAttr(); 636 return handleConstantAttr(constAttr.getAttribute(), 637 constAttr.getConstantValue()); 638 } 639 if (leaf.isEnumAttrCase()) { 640 auto enumCase = leaf.getAsEnumAttrCase(); 641 if (enumCase.isStrCase()) 642 return handleConstantAttr(enumCase, enumCase.getSymbol()); 643 // This is an enum case backed by an IntegerAttr. We need to get its value 644 // to build the constant. 645 std::string val = std::to_string(enumCase.getValue()); 646 return handleConstantAttr(enumCase, val); 647 } 648 649 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 650 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 651 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 652 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 653 << "' (via symbol ref)\n"); 654 return argName; 655 } 656 if (leaf.isNativeCodeCall()) { 657 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 658 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 659 << "' (via NativeCodeCall)\n"); 660 return repl; 661 } 662 PrintFatalError(loc, "unhandled case when rewriting op"); 663 } 664 665 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 666 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 667 LLVM_DEBUG(tree.print(llvm::dbgs())); 668 LLVM_DEBUG(llvm::dbgs() << '\n'); 669 670 auto fmt = tree.getNativeCodeTemplate(); 671 // TODO(b/138794486): replace formatv arguments with the exact specified args. 672 SmallVector<std::string, 8> attrs(8); 673 if (tree.getNumArgs() > 8) { 674 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 675 Twine(tree.getNumArgs())); 676 } 677 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 678 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 679 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 680 << " replacement: " << attrs[i] << "\n"); 681 } 682 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 683 attrs[5], attrs[6], attrs[7]); 684 } 685 686 int PatternEmitter::getNodeValueCount(DagNode node) { 687 if (node.isOperation()) { 688 // If the op is bound to a symbol in the rewrite rule, query its result 689 // count from the symbol info map. 690 auto symbol = node.getSymbol(); 691 if (!symbol.empty()) { 692 return symbolInfoMap.getStaticValueCount(symbol); 693 } 694 // Otherwise this is an unbound op; we will use all its results. 695 return pattern.getDialectOp(node).getNumResults(); 696 } 697 // TODO(antiagainst): This considers all NativeCodeCall as returning one 698 // value. Enhance if multi-value ones are needed. 699 return 1; 700 } 701 702 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 703 int depth) { 704 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 705 LLVM_DEBUG(tree.print(llvm::dbgs())); 706 LLVM_DEBUG(llvm::dbgs() << '\n'); 707 708 Operator &resultOp = tree.getDialectOp(opMap); 709 auto numOpArgs = resultOp.getNumArgs(); 710 711 if (numOpArgs != tree.getNumArgs()) { 712 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 713 "{1} in pattern vs. {2} in definition", 714 resultOp.getOperationName(), tree.getNumArgs(), 715 numOpArgs)); 716 } 717 718 // A map to collect all nested DAG child nodes' names, with operand index as 719 // the key. This includes both bound and unbound child nodes. 720 ChildNodeIndexNameMap childNodeNames; 721 722 // First go through all the child nodes who are nested DAG constructs to 723 // create ops for them and remember the symbol names for them, so that we can 724 // use the results in the current node. This happens in a recursive manner. 725 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 726 if (auto child = tree.getArgAsNestedDag(i)) { 727 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 728 } 729 } 730 731 // The name of the local variable holding this op. 732 std::string valuePackName; 733 // The symbol for holding the result of this pattern. Note that the result of 734 // this pattern is not necessarily the same as the variable created by this 735 // pattern because we can use `__N` suffix to refer only a specific result if 736 // the generated op is a multi-result op. 737 std::string resultValue; 738 if (tree.getSymbol().empty()) { 739 // No symbol is explicitly bound to this op in the pattern. Generate a 740 // unique name. 741 valuePackName = resultValue = getUniqueSymbol(&resultOp); 742 } else { 743 resultValue = tree.getSymbol(); 744 // Strip the index to get the name for the value pack and use it to name the 745 // local variable for the op. 746 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 747 } 748 749 // Create the local variable for this op. 750 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 751 valuePackName); 752 os.indent(4) << "{\n"; 753 754 // Right now ODS don't have general type inference support. Except a few 755 // special cases listed below, DRR needs to supply types for all results 756 // when building an op. 757 bool isSameOperandsAndResultType = 758 resultOp.getTrait("OpTrait::SameOperandsAndResultType"); 759 bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType"); 760 761 if (isSameOperandsAndResultType || useFirstAttr) { 762 // We know how to deduce the result type for ops with these traits and we've 763 // generated builders taking aggregate parameters. Use those builders to 764 // create the ops. 765 766 // First prepare local variables for op arguments used in builder call. 767 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 768 // Then create the op. 769 os.indent(6) << formatv( 770 "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", 771 valuePackName, resultOp.getQualCppClassName()); 772 os.indent(4) << "}\n"; 773 return resultValue; 774 } 775 776 bool isBroadcastable = 777 resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 778 bool usePartialResults = valuePackName != resultValue; 779 780 if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) { 781 // For these cases (broadcastable ops, op results used both as auxiliary 782 // values and replacement values, ops in nested patterns, auxiliary ops), we 783 // still need to supply the result types when building the op. But because 784 // we don't generate a builder automatically with ODS for them, it's the 785 // developer's responsiblity to make sure such a builder (with result type 786 // deduction ability) exists. We go through the separate-parameter builder 787 // here given that it's easier for developers to write compared to 788 // aggregate-parameter builders. 789 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 790 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 791 resultOp.getQualCppClassName()); 792 supplyValuesForOpArgs(tree, childNodeNames); 793 os << "\n );\n"; 794 os.indent(4) << "}\n"; 795 return resultValue; 796 } 797 798 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 799 // generated from the source pattern root op. Then we can use the source 800 // pattern's value types to determine the value type of the generated op 801 // here. 802 803 // First prepare local variables for op arguments used in builder call. 804 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 805 806 // Then prepare the result types. We need to specify the types for all 807 // results. 808 os.indent(6) << formatv( 809 "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"); 810 int numResults = resultOp.getNumResults(); 811 if (numResults != 0) { 812 for (int i = 0; i < numResults; ++i) 813 os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{" 814 "tblgen_types.push_back(v->getType()); }\n", 815 resultIndex + i); 816 } 817 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " 818 "tblgen_values, tblgen_attrs);\n", 819 valuePackName, resultOp.getQualCppClassName()); 820 os.indent(4) << "}\n"; 821 return resultValue; 822 } 823 824 void PatternEmitter::createSeparateLocalVarsForOpArgs( 825 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 826 Operator &resultOp = node.getDialectOp(opMap); 827 828 // Now prepare operands used for building this op: 829 // * If the operand is non-variadic, we create a `Value` local variable. 830 // * If the operand is variadic, we create a `SmallVector<Value>` local 831 // variable. 832 833 int valueIndex = 0; // An index for uniquing local variable names. 834 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 835 const auto *operand = 836 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 837 if (!operand) { 838 // We do not need special handling for attributes. 839 continue; 840 } 841 842 std::string varName; 843 if (operand->isVariadic()) { 844 varName = formatv("tblgen_values_{0}", valueIndex++); 845 os.indent(6) << formatv("SmallVector<Value, 4> {0};\n", varName); 846 std::string range; 847 if (node.isNestedDagArg(argIndex)) { 848 range = childNodeNames[argIndex]; 849 } else { 850 range = node.getArgName(argIndex); 851 } 852 // Resolve the symbol for all range use so that we have a uniform way of 853 // capturing the values. 854 range = symbolInfoMap.getValueAndRangeUse(range); 855 os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range, 856 varName); 857 } else { 858 varName = formatv("tblgen_value_{0}", valueIndex++); 859 os.indent(6) << formatv("Value {0} = ", varName); 860 if (node.isNestedDagArg(argIndex)) { 861 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 862 } else { 863 DagLeaf leaf = node.getArgAsLeaf(argIndex); 864 auto symbol = 865 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 866 if (leaf.isNativeCodeCall()) { 867 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 868 } else { 869 os << symbol; 870 } 871 } 872 os << ";\n"; 873 } 874 875 // Update to use the newly created local variable for building the op later. 876 childNodeNames[argIndex] = varName; 877 } 878 } 879 880 void PatternEmitter::supplyValuesForOpArgs( 881 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 882 Operator &resultOp = node.getDialectOp(opMap); 883 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 884 argIndex != numOpArgs; ++argIndex) { 885 // Start each argument on its own line. 886 (os << ",\n").indent(8); 887 888 Argument opArg = resultOp.getArg(argIndex); 889 // Handle the case of operand first. 890 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 891 if (!operand->name.empty()) 892 os << "/*" << operand->name << "=*/"; 893 os << childNodeNames.lookup(argIndex); 894 continue; 895 } 896 897 // The argument in the op definition. 898 auto opArgName = resultOp.getArgName(argIndex); 899 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 900 if (!subTree.isNativeCodeCall()) 901 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 902 "for creating attribute"); 903 os << formatv("/*{0}=*/{1}", opArgName, 904 handleReplaceWithNativeCodeCall(subTree)); 905 } else { 906 auto leaf = node.getArgAsLeaf(argIndex); 907 // The argument in the result DAG pattern. 908 auto patArgName = node.getArgName(argIndex); 909 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 910 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 911 if (!opArg.is<NamedAttribute *>()) 912 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 913 if (!patArgName.empty()) 914 os << "/*" << patArgName << "=*/"; 915 } else { 916 os << "/*" << opArgName << "=*/"; 917 } 918 os << handleOpArgument(leaf, patArgName); 919 } 920 } 921 } 922 923 void PatternEmitter::createAggregateLocalVarsForOpArgs( 924 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 925 Operator &resultOp = node.getDialectOp(opMap); 926 927 os.indent(6) << formatv( 928 "SmallVector<Value, 4> tblgen_values; (void)tblgen_values;\n"); 929 os.indent(6) << formatv( 930 "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n"); 931 932 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 933 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 934 const char *addAttrCmd = "if ({1}) {{" 935 " tblgen_attrs.emplace_back(rewriter." 936 "getIdentifier(\"{0}\"), {1}); }\n"; 937 // The argument in the op definition. 938 auto opArgName = resultOp.getArgName(argIndex); 939 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 940 if (!subTree.isNativeCodeCall()) 941 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 942 "for creating attribute"); 943 os.indent(6) << formatv(addAttrCmd, opArgName, 944 handleReplaceWithNativeCodeCall(subTree)); 945 } else { 946 auto leaf = node.getArgAsLeaf(argIndex); 947 // The argument in the result DAG pattern. 948 auto patArgName = node.getArgName(argIndex); 949 os.indent(6) << formatv(addAttrCmd, opArgName, 950 handleOpArgument(leaf, patArgName)); 951 } 952 continue; 953 } 954 955 const auto *operand = 956 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 957 std::string varName; 958 if (operand->isVariadic()) { 959 std::string range; 960 if (node.isNestedDagArg(argIndex)) { 961 range = childNodeNames.lookup(argIndex); 962 } else { 963 range = node.getArgName(argIndex); 964 } 965 // Resolve the symbol for all range use so that we have a uniform way of 966 // capturing the values. 967 range = symbolInfoMap.getValueAndRangeUse(range); 968 os.indent(6) << formatv( 969 "for (auto v : {0}) tblgen_values.push_back(v);\n", range); 970 } else { 971 os.indent(6) << formatv("tblgen_values.push_back(", varName); 972 if (node.isNestedDagArg(argIndex)) { 973 os << symbolInfoMap.getValueAndRangeUse( 974 childNodeNames.lookup(argIndex)); 975 } else { 976 DagLeaf leaf = node.getArgAsLeaf(argIndex); 977 auto symbol = 978 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 979 if (leaf.isNativeCodeCall()) { 980 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 981 } else { 982 os << symbol; 983 } 984 } 985 os << ");\n"; 986 } 987 } 988 } 989 990 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 991 emitSourceFileHeader("Rewriters", os); 992 993 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 994 auto numPatterns = patterns.size(); 995 996 // We put the map here because it can be shared among multiple patterns. 997 RecordOperatorMap recordOpMap; 998 999 std::vector<std::string> rewriterNames; 1000 rewriterNames.reserve(numPatterns); 1001 1002 std::string baseRewriterName = "GeneratedConvert"; 1003 int rewriterIndex = 0; 1004 1005 for (Record *p : patterns) { 1006 std::string name; 1007 if (p->isAnonymous()) { 1008 // If no name is provided, ensure unique rewriter names simply by 1009 // appending unique suffix. 1010 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1011 } else { 1012 name = p->getName(); 1013 } 1014 LLVM_DEBUG(llvm::dbgs() 1015 << "=== start generating pattern '" << name << "' ===\n"); 1016 PatternEmitter(p, &recordOpMap, os).emit(name); 1017 LLVM_DEBUG(llvm::dbgs() 1018 << "=== done generating pattern '" << name << "' ===\n"); 1019 rewriterNames.push_back(std::move(name)); 1020 } 1021 1022 // Emit function to add the generated matchers to the pattern list. 1023 os << "void populateWithGenerated(MLIRContext *context, " 1024 << "OwningRewritePatternList *patterns) {\n"; 1025 for (const auto &name : rewriterNames) { 1026 os << " patterns->insert<" << name << ">(context);\n"; 1027 } 1028 os << "}\n"; 1029 } 1030 1031 static mlir::GenRegistration 1032 genRewriters("gen-rewriters", "Generate pattern rewriters", 1033 [](const RecordKeeper &records, raw_ostream &os) { 1034 emitRewriters(records, os); 1035 return false; 1036 }); 1037