1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace llvm; 43 using namespace mlir; 44 using namespace mlir::tblgen; 45 46 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 47 48 namespace llvm { 49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 51 raw_ostream &os, StringRef style) { 52 os << v.first << ":" << v.second; 53 } 54 }; 55 } // end namespace llvm 56 57 //===----------------------------------------------------------------------===// 58 // PatternEmitter 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 class PatternEmitter { 63 public: 64 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 65 66 // Emits the mlir::RewritePattern struct named `rewriteName`. 67 void emit(StringRef rewriteName); 68 69 private: 70 // Emits the code for matching ops. 71 void emitMatchLogic(DagNode tree); 72 73 // Emits the code for rewriting ops. 74 void emitRewriteLogic(); 75 76 //===--------------------------------------------------------------------===// 77 // Match utilities 78 //===--------------------------------------------------------------------===// 79 80 // Emits C++ statements for matching the op constrained by the given DAG 81 // `tree`. 82 void emitOpMatch(DagNode tree, int depth); 83 84 // Emits C++ statements for matching the `index`-th argument of the given DAG 85 // `tree` as an operand. 86 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 87 88 // Emits C++ statements for matching the `index`-th argument of the given DAG 89 // `tree` as an attribute. 90 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 91 92 //===--------------------------------------------------------------------===// 93 // Rewrite utilities 94 //===--------------------------------------------------------------------===// 95 96 // The entry point for handling a result pattern rooted at `resultTree`. This 97 // method dispatches to concrete handlers according to `resultTree`'s kind and 98 // returns a symbol representing the whole value pack. Callers are expected to 99 // further resolve the symbol according to the specific use case. 100 // 101 // `depth` is the nesting level of `resultTree`; 0 means top-level result 102 // pattern. For top-level result pattern, `resultIndex` indicates which result 103 // of the matched root op this pattern is intended to replace, which can be 104 // used to deduce the result type of the op generated from this result 105 // pattern. 106 std::string handleResultPattern(DagNode resultTree, int resultIndex, 107 int depth); 108 109 // Emits the C++ statement to replace the matched DAG with a value built via 110 // calling native C++ code. 111 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 112 113 // Returns the C++ expression referencing the old value serving as the 114 // replacement. 115 std::string handleReplaceWithValue(DagNode tree); 116 117 // Emits the C++ statement to build a new op out of the given DAG `tree` and 118 // returns the variable name that this op is assigned to. If the root op in 119 // DAG `tree` has a specified name, the created op will be assigned to a 120 // variable of the given name. Otherwise, a unique name will be used as the 121 // result value name. 122 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 123 124 // Returns the C++ expression to construct a constant attribute of the given 125 // `value` for the given attribute kind `attr`. 126 std::string handleConstantAttr(Attribute attr, StringRef value); 127 128 // Returns the C++ expression to build an argument from the given DAG `leaf`. 129 // `patArgName` is used to bound the argument to the source pattern. 130 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 131 132 //===--------------------------------------------------------------------===// 133 // General utilities 134 //===--------------------------------------------------------------------===// 135 136 // Collects all of the operations within the given dag tree. 137 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 138 139 // Returns a unique symbol for a local variable of the given `op`. 140 std::string getUniqueSymbol(const Operator *op); 141 142 //===--------------------------------------------------------------------===// 143 // Symbol utilities 144 //===--------------------------------------------------------------------===// 145 146 // Returns how many static values the given DAG `node` correspond to. 147 int getNodeValueCount(DagNode node); 148 149 private: 150 // Pattern instantiation location followed by the location of multiclass 151 // prototypes used. This is intended to be used as a whole to 152 // PrintFatalError() on errors. 153 ArrayRef<llvm::SMLoc> loc; 154 155 // Op's TableGen Record to wrapper object. 156 RecordOperatorMap *opMap; 157 158 // Handy wrapper for pattern being emitted. 159 Pattern pattern; 160 161 // Map for all bound symbols' info. 162 SymbolInfoMap symbolInfoMap; 163 164 // The next unused ID for newly created values. 165 unsigned nextValueId; 166 167 raw_ostream &os; 168 169 // Format contexts containing placeholder substitutions. 170 FmtContext fmtCtx; 171 172 // Number of op processed. 173 int opCounter = 0; 174 }; 175 } // end anonymous namespace 176 177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 178 raw_ostream &os) 179 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 180 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 181 fmtCtx.withBuilder("rewriter"); 182 } 183 184 std::string PatternEmitter::handleConstantAttr(Attribute attr, 185 StringRef value) { 186 if (!attr.isConstBuildable()) 187 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 188 " does not have the 'constBuilderCall' field"); 189 190 // TODO(jpienaar): Verify the constants here 191 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 192 } 193 194 // Helper function to match patterns. 195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 196 Operator &op = tree.getDialectOp(opMap); 197 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 198 << op.getOperationName() << "' at depth " << depth 199 << '\n'); 200 201 int indent = 4 + 2 * depth; 202 os.indent(indent) << formatv( 203 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 204 depth, op.getQualCppClassName()); 205 // Skip the operand matching at depth 0 as the pattern rewriter already does. 206 if (depth != 0) { 207 // Skip if there is no defining operation (e.g., arguments to function). 208 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 209 depth); 210 } 211 if (tree.getNumArgs() != op.getNumArgs()) { 212 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 213 "pattern vs. {2} in definition", 214 op.getOperationName(), tree.getNumArgs(), 215 op.getNumArgs())); 216 } 217 218 // If the operand's name is set, set to that variable. 219 auto name = tree.getSymbol(); 220 if (!name.empty()) 221 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 222 223 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 224 auto opArg = op.getArg(i); 225 226 // Handle nested DAG construct first 227 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 228 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 229 if (operand->isVariadic()) { 230 auto error = formatv("use nested DAG construct to match op {0}'s " 231 "variadic operand #{1} unsupported now", 232 op.getOperationName(), i); 233 PrintFatalError(loc, error); 234 } 235 } 236 os.indent(indent) << "{\n"; 237 238 os.indent(indent + 2) << formatv( 239 "auto *op{0} = " 240 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 241 depth + 1, depth, i); 242 emitOpMatch(argTree, depth + 1); 243 os.indent(indent + 2) 244 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 245 os.indent(indent) << "}\n"; 246 continue; 247 } 248 249 // Next handle DAG leaf: operand or attribute 250 if (opArg.is<NamedTypeConstraint *>()) { 251 emitOperandMatch(tree, i, depth, indent); 252 } else if (opArg.is<NamedAttribute *>()) { 253 emitAttributeMatch(tree, i, depth, indent); 254 } else { 255 PrintFatalError(loc, "unhandled case when matching op"); 256 } 257 } 258 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 259 << op.getOperationName() << "' at depth " << depth 260 << '\n'); 261 } 262 263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 264 int indent) { 265 Operator &op = tree.getDialectOp(opMap); 266 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 267 auto matcher = tree.getArgAsLeaf(index); 268 269 // If a constraint is specified, we need to generate C++ statements to 270 // check the constraint. 271 if (!matcher.isUnspecified()) { 272 if (!matcher.isOperandMatcher()) { 273 PrintFatalError( 274 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 275 op.getOperationName(), index + 1)); 276 } 277 278 // Only need to verify if the matcher's type is different from the one 279 // of op definition. 280 if (operand->constraint != matcher.getAsConstraint()) { 281 if (operand->isVariadic()) { 282 auto error = formatv( 283 "further constrain op {0}'s variadic operand #{1} unsupported now", 284 op.getOperationName(), index); 285 PrintFatalError(loc, error); 286 } 287 auto self = 288 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 289 depth, index); 290 os.indent(indent) << "if (!(" 291 << tgfmt(matcher.getConditionTemplate(), 292 &fmtCtx.withSelf(self)) 293 << ")) return matchFailure();\n"; 294 } 295 } 296 297 // Capture the value 298 auto name = tree.getArgName(index); 299 if (!name.empty()) { 300 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 301 name, depth, index); 302 } 303 } 304 305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 306 int indent) { 307 Operator &op = tree.getDialectOp(opMap); 308 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 309 const auto &attr = namedAttr->attr; 310 311 os.indent(indent) << "{\n"; 312 indent += 2; 313 os.indent(indent) << formatv( 314 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 315 attr.getStorageType(), namedAttr->name); 316 317 // TODO(antiagainst): This should use getter method to avoid duplication. 318 if (attr.hasDefaultValueInitializer()) { 319 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 320 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 321 attr.getDefaultValueInitializer()) 322 << ";\n"; 323 } else if (attr.isOptional()) { 324 // For a missing attribute that is optional according to definition, we 325 // should just capture a mlir::Attribute() to signal the missing state. 326 // That is precisely what getAttr() returns on missing attributes. 327 } else { 328 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 329 } 330 331 auto matcher = tree.getArgAsLeaf(index); 332 if (!matcher.isUnspecified()) { 333 if (!matcher.isAttrMatcher()) { 334 PrintFatalError( 335 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 336 op.getOperationName(), index + 1)); 337 } 338 339 // If a constraint is specified, we need to generate C++ statements to 340 // check the constraint. 341 os.indent(indent) << "if (!(" 342 << tgfmt(matcher.getConditionTemplate(), 343 &fmtCtx.withSelf("tblgen_attr")) 344 << ")) return matchFailure();\n"; 345 } 346 347 // Capture the value 348 auto name = tree.getArgName(index); 349 if (!name.empty()) { 350 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 351 } 352 353 indent -= 2; 354 os.indent(indent) << "}\n"; 355 } 356 357 void PatternEmitter::emitMatchLogic(DagNode tree) { 358 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 359 emitOpMatch(tree, 0); 360 361 for (auto &appliedConstraint : pattern.getConstraints()) { 362 auto &constraint = appliedConstraint.constraint; 363 auto &entities = appliedConstraint.entities; 364 365 auto condition = constraint.getConditionTemplate(); 366 auto cmd = "if (!({0})) return matchFailure();\n"; 367 368 if (isa<TypeConstraint>(constraint)) { 369 auto self = formatv("({0}->getType())", 370 symbolInfoMap.getValueAndRangeUse(entities.front())); 371 os.indent(4) << formatv(cmd, 372 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 373 } else if (isa<AttrConstraint>(constraint)) { 374 PrintFatalError( 375 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 376 } else { 377 // TODO(b/138794486): replace formatv arguments with the exact specified 378 // args. 379 if (entities.size() > 4) { 380 PrintFatalError(loc, "only support up to 4-entity constraints now"); 381 } 382 SmallVector<std::string, 4> names; 383 int i = 0; 384 for (int e = entities.size(); i < e; ++i) 385 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 386 std::string self = appliedConstraint.self; 387 if (!self.empty()) 388 self = symbolInfoMap.getValueAndRangeUse(self); 389 for (; i < 4; ++i) 390 names.push_back("<unused>"); 391 os.indent(4) << formatv(cmd, 392 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 393 names[1], names[2], names[3])); 394 } 395 } 396 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 397 } 398 399 void PatternEmitter::collectOps(DagNode tree, 400 llvm::SmallPtrSetImpl<const Operator *> &ops) { 401 // Check if this tree is an operation. 402 if (tree.isOperation()) { 403 const Operator &op = tree.getDialectOp(opMap); 404 LLVM_DEBUG(llvm::dbgs() 405 << "found operation " << op.getOperationName() << '\n'); 406 ops.insert(&op); 407 } 408 409 // Recurse the arguments of the tree. 410 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 411 if (auto child = tree.getArgAsNestedDag(i)) 412 collectOps(child, ops); 413 } 414 415 void PatternEmitter::emit(StringRef rewriteName) { 416 // Get the DAG tree for the source pattern. 417 DagNode sourceTree = pattern.getSourcePattern(); 418 419 const Operator &rootOp = pattern.getSourceRootOp(); 420 auto rootName = rootOp.getOperationName(); 421 422 // Collect the set of result operations. 423 llvm::SmallPtrSet<const Operator *, 4> resultOps; 424 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 425 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 426 collectOps(pattern.getResultPattern(i), resultOps); 427 } 428 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 429 430 // Emit RewritePattern for Pattern. 431 auto locs = pattern.getLocation(); 432 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 433 make_range(locs.rbegin(), locs.rend())); 434 os << formatv(R"(struct {0} : public RewritePattern { 435 {0}(MLIRContext *context) 436 : RewritePattern("{1}", {{)", 437 rewriteName, rootName); 438 // Sort result operators by name. 439 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 440 resultOps.end()); 441 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 442 return lhs->getOperationName() < rhs->getOperationName(); 443 }); 444 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 445 os << '"' << op->getOperationName() << '"'; 446 }); 447 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 448 449 // Emit matchAndRewrite() function. 450 os << R"( 451 PatternMatchResult matchAndRewrite(Operation *op0, 452 PatternRewriter &rewriter) const override { 453 )"; 454 455 // Register all symbols bound in the source pattern. 456 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 457 458 LLVM_DEBUG( 459 llvm::dbgs() << "start creating local variables for capturing matches\n"); 460 os.indent(4) << "// Variables for capturing values and attributes used for " 461 "creating ops\n"; 462 // Create local variables for storing the arguments and results bound 463 // to symbols. 464 for (const auto &symbolInfoPair : symbolInfoMap) { 465 StringRef symbol = symbolInfoPair.getKey(); 466 auto &info = symbolInfoPair.getValue(); 467 os.indent(4) << info.getVarDecl(symbol); 468 } 469 // TODO(jpienaar): capture ops with consistent numbering so that it can be 470 // reused for fused loc. 471 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 472 pattern.getSourcePattern().getNumOps()); 473 LLVM_DEBUG( 474 llvm::dbgs() << "done creating local variables for capturing matches\n"); 475 476 os.indent(4) << "// Match\n"; 477 os.indent(4) << "tblgen_ops[0] = op0;\n"; 478 emitMatchLogic(sourceTree); 479 os << "\n"; 480 481 os.indent(4) << "// Rewrite\n"; 482 emitRewriteLogic(); 483 484 os.indent(4) << "return matchSuccess();\n"; 485 os << " };\n"; 486 os << "};\n"; 487 } 488 489 void PatternEmitter::emitRewriteLogic() { 490 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 491 const Operator &rootOp = pattern.getSourceRootOp(); 492 int numExpectedResults = rootOp.getNumResults(); 493 int numResultPatterns = pattern.getNumResultPatterns(); 494 495 // First register all symbols bound to ops generated in result patterns. 496 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 497 498 // Only the last N static values generated are used to replace the matched 499 // root N-result op. We need to calculate the starting index (of the results 500 // of the matched op) each result pattern is to replace. 501 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 502 // If we don't need to replace any value at all, set the replacement starting 503 // index as the number of result patterns so we skip all of them when trying 504 // to replace the matched op's results. 505 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 506 for (int i = numResultPatterns - 1; i >= 0; --i) { 507 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 508 offsets[i] = offsets[i + 1] - numValues; 509 if (offsets[i] == 0) { 510 if (replStartIndex == -1) 511 replStartIndex = i; 512 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 513 auto error = formatv( 514 "cannot use the same multi-result op '{0}' to generate both " 515 "auxiliary values and values to be used for replacing the matched op", 516 pattern.getResultPattern(i).getSymbol()); 517 PrintFatalError(loc, error); 518 } 519 } 520 521 if (offsets.front() > 0) { 522 const char error[] = "no enough values generated to replace the matched op"; 523 PrintFatalError(loc, error); 524 } 525 526 os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"; 527 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 528 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 529 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 530 } 531 os << "}); (void)loc;\n"; 532 533 // Process auxiliary result patterns. 534 for (int i = 0; i < replStartIndex; ++i) { 535 DagNode resultTree = pattern.getResultPattern(i); 536 auto val = handleResultPattern(resultTree, offsets[i], 0); 537 // Normal op creation will be streamed to `os` by the above call; but 538 // NativeCodeCall will only be materialized to `os` if it is used. Here 539 // we are handling auxiliary patterns so we want the side effect even if 540 // NativeCodeCall is not replacing matched root op's results. 541 if (resultTree.isNativeCodeCall()) 542 os.indent(4) << val << ";\n"; 543 } 544 545 if (numExpectedResults == 0) { 546 assert(replStartIndex >= numResultPatterns && 547 "invalid auxiliary vs. replacement pattern division!"); 548 // No result to replace. Just erase the op. 549 os.indent(4) << "rewriter.eraseOp(op0);\n"; 550 } else { 551 // Process replacement result patterns. 552 os.indent(4) << "SmallVector<Value *, 4> tblgen_values;"; 553 for (int i = replStartIndex; i < numResultPatterns; ++i) { 554 DagNode resultTree = pattern.getResultPattern(i); 555 auto val = handleResultPattern(resultTree, offsets[i], 0); 556 os.indent(4) << "\n"; 557 // Resolve each symbol for all range use so that we can loop over them. 558 os << symbolInfoMap.getAllRangeUse( 559 val, " for (auto *v : {0}) tblgen_values.push_back(v);", "\n"); 560 } 561 os.indent(4) << "\n"; 562 os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n"; 563 } 564 565 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 566 } 567 568 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 569 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 570 } 571 572 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 573 int resultIndex, int depth) { 574 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 575 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 576 LLVM_DEBUG(llvm::dbgs() << '\n'); 577 578 if (resultTree.isNativeCodeCall()) { 579 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 580 symbolInfoMap.bindValue(symbol); 581 return symbol; 582 } 583 584 if (resultTree.isReplaceWithValue()) { 585 return handleReplaceWithValue(resultTree); 586 } 587 588 // Normal op creation. 589 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 590 if (resultTree.getSymbol().empty()) { 591 // This is an op not explicitly bound to a symbol in the rewrite rule. 592 // Register the auto-generated symbol for it. 593 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 594 } 595 return symbol; 596 } 597 598 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 599 assert(tree.isReplaceWithValue()); 600 601 if (tree.getNumArgs() != 1) { 602 PrintFatalError( 603 loc, "replaceWithValue directive must take exactly one argument"); 604 } 605 606 if (!tree.getSymbol().empty()) { 607 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 608 } 609 610 return tree.getArgName(0); 611 } 612 613 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 614 StringRef patArgName) { 615 if (leaf.isConstantAttr()) { 616 auto constAttr = leaf.getAsConstantAttr(); 617 return handleConstantAttr(constAttr.getAttribute(), 618 constAttr.getConstantValue()); 619 } 620 if (leaf.isEnumAttrCase()) { 621 auto enumCase = leaf.getAsEnumAttrCase(); 622 if (enumCase.isStrCase()) 623 return handleConstantAttr(enumCase, enumCase.getSymbol()); 624 // This is an enum case backed by an IntegerAttr. We need to get its value 625 // to build the constant. 626 std::string val = std::to_string(enumCase.getValue()); 627 return handleConstantAttr(enumCase, val); 628 } 629 630 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 631 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 632 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 633 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 634 << "' (via symbol ref)\n"); 635 return argName; 636 } 637 if (leaf.isNativeCodeCall()) { 638 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 639 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 640 << "' (via NativeCodeCall)\n"); 641 return repl; 642 } 643 PrintFatalError(loc, "unhandled case when rewriting op"); 644 } 645 646 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 647 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 648 LLVM_DEBUG(tree.print(llvm::dbgs())); 649 LLVM_DEBUG(llvm::dbgs() << '\n'); 650 651 auto fmt = tree.getNativeCodeTemplate(); 652 // TODO(b/138794486): replace formatv arguments with the exact specified args. 653 SmallVector<std::string, 8> attrs(8); 654 if (tree.getNumArgs() > 8) { 655 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 656 Twine(tree.getNumArgs())); 657 } 658 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 659 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 660 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i 661 << " replacement: " << attrs[i] << "\n"); 662 } 663 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 664 attrs[5], attrs[6], attrs[7]); 665 } 666 667 int PatternEmitter::getNodeValueCount(DagNode node) { 668 if (node.isOperation()) { 669 // If the op is bound to a symbol in the rewrite rule, query its result 670 // count from the symbol info map. 671 auto symbol = node.getSymbol(); 672 if (!symbol.empty()) { 673 return symbolInfoMap.getStaticValueCount(symbol); 674 } 675 // Otherwise this is an unbound op; we will use all its results. 676 return pattern.getDialectOp(node).getNumResults(); 677 } 678 // TODO(antiagainst): This considers all NativeCodeCall as returning one 679 // value. Enhance if multi-value ones are needed. 680 return 1; 681 } 682 683 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 684 int depth) { 685 Operator &resultOp = tree.getDialectOp(opMap); 686 auto numOpArgs = resultOp.getNumArgs(); 687 688 if (numOpArgs != tree.getNumArgs()) { 689 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 690 "{1} in pattern vs. {2} in definition", 691 resultOp.getOperationName(), tree.getNumArgs(), 692 numOpArgs)); 693 } 694 695 // A map to collect all nested DAG child nodes' names, with operand index as 696 // the key. This includes both bound and unbound child nodes. 697 llvm::DenseMap<unsigned, std::string> childNodeNames; 698 699 // First go through all the child nodes who are nested DAG constructs to 700 // create ops for them and remember the symbol names for them, so that we can 701 // use the results in the current node. This happens in a recursive manner. 702 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 703 if (auto child = tree.getArgAsNestedDag(i)) { 704 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 705 } 706 } 707 708 // The name of the local variable holding this op. 709 std::string valuePackName; 710 // The symbol for holding the result of this pattern. Note that the result of 711 // this pattern is not necessarily the same as the variable created by this 712 // pattern because we can use `__N` suffix to refer only a specific result if 713 // the generated op is a multi-result op. 714 std::string resultValue; 715 if (tree.getSymbol().empty()) { 716 // No symbol is explicitly bound to this op in the pattern. Generate a 717 // unique name. 718 valuePackName = resultValue = getUniqueSymbol(&resultOp); 719 } else { 720 resultValue = tree.getSymbol(); 721 // Strip the index to get the name for the value pack and use it to name the 722 // local variable for the op. 723 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 724 } 725 726 // Create the local variable for this op. 727 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 728 valuePackName); 729 os.indent(4) << "{\n"; 730 731 // Now prepare operands used for building this op: 732 // * If the operand is non-variadic, we create a `Value*` local variable. 733 // * If the operand is variadic, we create a `SmallVector<Value*>` local 734 // variable. 735 736 int argIndex = 0; // The current index to this op's ODS argument 737 int valueIndex = 0; // An index for uniquing local variable names. 738 for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { 739 const auto &operand = resultOp.getOperand(argIndex); 740 std::string varName; 741 if (operand.isVariadic()) { 742 varName = formatv("tblgen_values_{0}", valueIndex++); 743 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 744 std::string range; 745 if (tree.isNestedDagArg(argIndex)) { 746 range = childNodeNames[argIndex]; 747 } else { 748 range = tree.getArgName(argIndex); 749 } 750 // Resolve the symbol for all range use so that we have a uniform way of 751 // capturing the values. 752 range = symbolInfoMap.getValueAndRangeUse(range); 753 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 754 varName); 755 } else { 756 varName = formatv("tblgen_value_{0}", valueIndex++); 757 os.indent(6) << formatv("Value *{0} = ", varName); 758 if (tree.isNestedDagArg(argIndex)) { 759 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 760 } else { 761 DagLeaf leaf = tree.getArgAsLeaf(argIndex); 762 auto symbol = 763 symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex)); 764 if (leaf.isNativeCodeCall()) { 765 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 766 } else { 767 os << symbol; 768 } 769 } 770 os << ";\n"; 771 } 772 773 // Update to use the newly created local variable for building the op later. 774 childNodeNames[argIndex] = varName; 775 } 776 777 // Then we create the builder call. 778 779 // Right now we don't have general type inference in MLIR. Except a few 780 // special cases listed below, we need to supply types for all results 781 // when building an op. 782 bool isSameOperandsAndResultType = 783 resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); 784 bool isBroadcastable = 785 resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 786 bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); 787 bool usePartialResults = valuePackName != resultValue; 788 789 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 790 usePartialResults || depth > 0 || resultIndex < 0) { 791 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 792 resultOp.getQualCppClassName()); 793 } else { 794 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 795 // generated from the source pattern root op. Then we can use the source 796 // pattern's value types to determine the value type of the generated op 797 // here. 798 799 // We need to specify the types for all results. 800 int numResults = resultOp.getNumResults(); 801 if (numResults != 0) { 802 os.indent(6) << "tblgen_types.clear();\n"; 803 for (int i = 0; i < numResults; ++i) { 804 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) " 805 "tblgen_types.push_back(v->getType());\n", 806 resultIndex + i); 807 } 808 } 809 810 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 811 resultOp.getQualCppClassName()); 812 if (numResults != 0) 813 os.indent(6) << ", tblgen_types"; 814 } 815 816 // Add operands for the builder all. 817 for (int i = 0; i < argIndex; ++i) { 818 const auto &operand = resultOp.getOperand(i); 819 // Start each operand on its own line. 820 (os << ",\n").indent(8); 821 if (!operand.name.empty()) { 822 os << "/*" << operand.name << "=*/"; 823 } 824 os << childNodeNames[i]; 825 // TODO(jpienaar): verify types 826 } 827 828 // Add attributes for the builder call. 829 for (; argIndex != numOpArgs; ++argIndex) { 830 // Start each attribute on its own line. 831 (os << ",\n").indent(8); 832 // The argument in the op definition. 833 auto opArgName = resultOp.getArgName(argIndex); 834 if (auto subTree = tree.getArgAsNestedDag(argIndex)) { 835 if (!subTree.isNativeCodeCall()) 836 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 837 "for creating attribute"); 838 os << formatv("/*{0}=*/{1}", opArgName, 839 handleReplaceWithNativeCodeCall(subTree)); 840 } else { 841 auto leaf = tree.getArgAsLeaf(argIndex); 842 // The argument in the result DAG pattern. 843 auto patArgName = tree.getArgName(argIndex); 844 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 845 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 846 auto argument = resultOp.getArg(argIndex); 847 if (!argument.is<NamedAttribute *>()) 848 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 849 if (!patArgName.empty()) 850 os << "/*" << patArgName << "=*/"; 851 } else { 852 os << "/*" << opArgName << "=*/"; 853 } 854 os << handleOpArgument(leaf, patArgName); 855 } 856 } 857 os << "\n );\n"; 858 os.indent(4) << "}\n"; 859 860 return resultValue; 861 } 862 863 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 864 emitSourceFileHeader("Rewriters", os); 865 866 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 867 auto numPatterns = patterns.size(); 868 869 // We put the map here because it can be shared among multiple patterns. 870 RecordOperatorMap recordOpMap; 871 872 std::vector<std::string> rewriterNames; 873 rewriterNames.reserve(numPatterns); 874 875 std::string baseRewriterName = "GeneratedConvert"; 876 int rewriterIndex = 0; 877 878 for (Record *p : patterns) { 879 std::string name; 880 if (p->isAnonymous()) { 881 // If no name is provided, ensure unique rewriter names simply by 882 // appending unique suffix. 883 name = baseRewriterName + llvm::utostr(rewriterIndex++); 884 } else { 885 name = p->getName(); 886 } 887 LLVM_DEBUG(llvm::dbgs() 888 << "=== start generating pattern '" << name << "' ===\n"); 889 PatternEmitter(p, &recordOpMap, os).emit(name); 890 LLVM_DEBUG(llvm::dbgs() 891 << "=== done generating pattern '" << name << "' ===\n"); 892 rewriterNames.push_back(std::move(name)); 893 } 894 895 // Emit function to add the generated matchers to the pattern list. 896 os << "void populateWithGenerated(MLIRContext *context, " 897 << "OwningRewritePatternList *patterns) {\n"; 898 for (const auto &name : rewriterNames) { 899 os << " patterns->insert<" << name << ">(context);\n"; 900 } 901 os << "}\n"; 902 } 903 904 static mlir::GenRegistration 905 genRewriters("gen-rewriters", "Generate pattern rewriters", 906 [](const RecordKeeper &records, raw_ostream &os) { 907 emitRewriters(records, os); 908 return false; 909 }); 910