1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace llvm; 43 using namespace mlir; 44 using namespace mlir::tblgen; 45 46 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 47 48 namespace llvm { 49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 51 raw_ostream &os, StringRef style) { 52 os << v.first << ":" << v.second; 53 } 54 }; 55 } // end namespace llvm 56 57 //===----------------------------------------------------------------------===// 58 // PatternEmitter 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 class PatternEmitter { 63 public: 64 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 65 66 // Emits the mlir::RewritePattern struct named `rewriteName`. 67 void emit(StringRef rewriteName); 68 69 private: 70 // Emits the code for matching ops. 71 void emitMatchLogic(DagNode tree); 72 73 // Emits the code for rewriting ops. 74 void emitRewriteLogic(); 75 76 //===--------------------------------------------------------------------===// 77 // Match utilities 78 //===--------------------------------------------------------------------===// 79 80 // Emits C++ statements for matching the op constrained by the given DAG 81 // `tree`. 82 void emitOpMatch(DagNode tree, int depth); 83 84 // Emits C++ statements for matching the `index`-th argument of the given DAG 85 // `tree` as an operand. 86 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 87 88 // Emits C++ statements for matching the `index`-th argument of the given DAG 89 // `tree` as an attribute. 90 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 91 92 //===--------------------------------------------------------------------===// 93 // Rewrite utilities 94 //===--------------------------------------------------------------------===// 95 96 // The entry point for handling a result pattern rooted at `resultTree`. This 97 // method dispatches to concrete handlers according to `resultTree`'s kind and 98 // returns a symbol representing the whole value pack. Callers are expected to 99 // further resolve the symbol according to the specific use case. 100 // 101 // `depth` is the nesting level of `resultTree`; 0 means top-level result 102 // pattern. For top-level result pattern, `resultIndex` indicates which result 103 // of the matched root op this pattern is intended to replace, which can be 104 // used to deduce the result type of the op generated from this result 105 // pattern. 106 std::string handleResultPattern(DagNode resultTree, int resultIndex, 107 int depth); 108 109 // Emits the C++ statement to replace the matched DAG with a value built via 110 // calling native C++ code. 111 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 112 113 // Returns the C++ expression referencing the old value serving as the 114 // replacement. 115 std::string handleReplaceWithValue(DagNode tree); 116 117 // Emits the C++ statement to build a new op out of the given DAG `tree` and 118 // returns the variable name that this op is assigned to. If the root op in 119 // DAG `tree` has a specified name, the created op will be assigned to a 120 // variable of the given name. Otherwise, a unique name will be used as the 121 // result value name. 122 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 123 124 // Returns the C++ expression to construct a constant attribute of the given 125 // `value` for the given attribute kind `attr`. 126 std::string handleConstantAttr(Attribute attr, StringRef value); 127 128 // Returns the C++ expression to build an argument from the given DAG `leaf`. 129 // `patArgName` is used to bound the argument to the source pattern. 130 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 131 132 //===--------------------------------------------------------------------===// 133 // General utilities 134 //===--------------------------------------------------------------------===// 135 136 // Collects all of the operations within the given dag tree. 137 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 138 139 // Returns a unique symbol for a local variable of the given `op`. 140 std::string getUniqueSymbol(const Operator *op); 141 142 //===--------------------------------------------------------------------===// 143 // Symbol utilities 144 //===--------------------------------------------------------------------===// 145 146 // Returns how many static values the given DAG `node` correspond to. 147 int getNodeValueCount(DagNode node); 148 149 private: 150 // Pattern instantiation location followed by the location of multiclass 151 // prototypes used. This is intended to be used as a whole to 152 // PrintFatalError() on errors. 153 ArrayRef<llvm::SMLoc> loc; 154 155 // Op's TableGen Record to wrapper object. 156 RecordOperatorMap *opMap; 157 158 // Handy wrapper for pattern being emitted. 159 Pattern pattern; 160 161 // Map for all bound symbols' info. 162 SymbolInfoMap symbolInfoMap; 163 164 // The next unused ID for newly created values. 165 unsigned nextValueId; 166 167 raw_ostream &os; 168 169 // Format contexts containing placeholder substitutions. 170 FmtContext fmtCtx; 171 172 // Number of op processed. 173 int opCounter = 0; 174 }; 175 } // end anonymous namespace 176 177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 178 raw_ostream &os) 179 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 180 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 181 fmtCtx.withBuilder("rewriter"); 182 } 183 184 std::string PatternEmitter::handleConstantAttr(Attribute attr, 185 StringRef value) { 186 if (!attr.isConstBuildable()) 187 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 188 " does not have the 'constBuilderCall' field"); 189 190 // TODO(jpienaar): Verify the constants here 191 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 192 } 193 194 // Helper function to match patterns. 195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 196 Operator &op = tree.getDialectOp(opMap); 197 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 198 << op.getOperationName() << "' at depth " << depth 199 << '\n'); 200 201 int indent = 4 + 2 * depth; 202 os.indent(indent) << formatv( 203 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 204 depth, op.getQualCppClassName()); 205 // Skip the operand matching at depth 0 as the pattern rewriter already does. 206 if (depth != 0) { 207 // Skip if there is no defining operation (e.g., arguments to function). 208 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 209 depth); 210 } 211 if (tree.getNumArgs() != op.getNumArgs()) { 212 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 213 "pattern vs. {2} in definition", 214 op.getOperationName(), tree.getNumArgs(), 215 op.getNumArgs())); 216 } 217 218 // If the operand's name is set, set to that variable. 219 auto name = tree.getSymbol(); 220 if (!name.empty()) 221 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 222 223 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 224 auto opArg = op.getArg(i); 225 226 // Handle nested DAG construct first 227 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 228 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 229 if (operand->isVariadic()) { 230 auto error = formatv("use nested DAG construct to match op {0}'s " 231 "variadic operand #{1} unsupported now", 232 op.getOperationName(), i); 233 PrintFatalError(loc, error); 234 } 235 } 236 os.indent(indent) << "{\n"; 237 238 os.indent(indent + 2) << formatv( 239 "auto *op{0} = " 240 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 241 depth + 1, depth, i); 242 emitOpMatch(argTree, depth + 1); 243 os.indent(indent + 2) 244 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 245 os.indent(indent) << "}\n"; 246 continue; 247 } 248 249 // Next handle DAG leaf: operand or attribute 250 if (opArg.is<NamedTypeConstraint *>()) { 251 emitOperandMatch(tree, i, depth, indent); 252 } else if (opArg.is<NamedAttribute *>()) { 253 emitAttributeMatch(tree, i, depth, indent); 254 } else { 255 PrintFatalError(loc, "unhandled case when matching op"); 256 } 257 } 258 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 259 << op.getOperationName() << "' at depth " << depth 260 << '\n'); 261 } 262 263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 264 int indent) { 265 Operator &op = tree.getDialectOp(opMap); 266 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 267 auto matcher = tree.getArgAsLeaf(index); 268 269 // If a constraint is specified, we need to generate C++ statements to 270 // check the constraint. 271 if (!matcher.isUnspecified()) { 272 if (!matcher.isOperandMatcher()) { 273 PrintFatalError( 274 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 275 op.getOperationName(), index + 1)); 276 } 277 278 // Only need to verify if the matcher's type is different from the one 279 // of op definition. 280 if (operand->constraint != matcher.getAsConstraint()) { 281 if (operand->isVariadic()) { 282 auto error = formatv( 283 "further constrain op {0}'s variadic operand #{1} unsupported now", 284 op.getOperationName(), index); 285 PrintFatalError(loc, error); 286 } 287 auto self = 288 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 289 depth, index); 290 os.indent(indent) << "if (!(" 291 << tgfmt(matcher.getConditionTemplate(), 292 &fmtCtx.withSelf(self)) 293 << ")) return matchFailure();\n"; 294 } 295 } 296 297 // Capture the value 298 auto name = tree.getArgName(index); 299 if (!name.empty()) { 300 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 301 name, depth, index); 302 } 303 } 304 305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 306 int indent) { 307 Operator &op = tree.getDialectOp(opMap); 308 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 309 const auto &attr = namedAttr->attr; 310 311 os.indent(indent) << "{\n"; 312 indent += 2; 313 os.indent(indent) << formatv( 314 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 315 attr.getStorageType(), namedAttr->name); 316 317 // TODO(antiagainst): This should use getter method to avoid duplication. 318 if (attr.hasDefaultValueInitializer()) { 319 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 320 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 321 attr.getDefaultValueInitializer()) 322 << ";\n"; 323 } else if (attr.isOptional()) { 324 // For a missing attribute that is optional according to definition, we 325 // should just capture a mlir::Attribute() to signal the missing state. 326 // That is precisely what getAttr() returns on missing attributes. 327 } else { 328 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 329 } 330 331 auto matcher = tree.getArgAsLeaf(index); 332 if (!matcher.isUnspecified()) { 333 if (!matcher.isAttrMatcher()) { 334 PrintFatalError( 335 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 336 op.getOperationName(), index + 1)); 337 } 338 339 // If a constraint is specified, we need to generate C++ statements to 340 // check the constraint. 341 os.indent(indent) << "if (!(" 342 << tgfmt(matcher.getConditionTemplate(), 343 &fmtCtx.withSelf("tblgen_attr")) 344 << ")) return matchFailure();\n"; 345 } 346 347 // Capture the value 348 auto name = tree.getArgName(index); 349 if (!name.empty()) { 350 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 351 } 352 353 indent -= 2; 354 os.indent(indent) << "}\n"; 355 } 356 357 void PatternEmitter::emitMatchLogic(DagNode tree) { 358 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 359 emitOpMatch(tree, 0); 360 361 for (auto &appliedConstraint : pattern.getConstraints()) { 362 auto &constraint = appliedConstraint.constraint; 363 auto &entities = appliedConstraint.entities; 364 365 auto condition = constraint.getConditionTemplate(); 366 auto cmd = "if (!({0})) return matchFailure();\n"; 367 368 if (isa<TypeConstraint>(constraint)) { 369 auto self = formatv("({0}->getType())", 370 symbolInfoMap.getValueAndRangeUse(entities.front())); 371 os.indent(4) << formatv(cmd, 372 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 373 } else if (isa<AttrConstraint>(constraint)) { 374 PrintFatalError( 375 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 376 } else { 377 // TODO(b/138794486): replace formatv arguments with the exact specified 378 // args. 379 if (entities.size() > 4) { 380 PrintFatalError(loc, "only support up to 4-entity constraints now"); 381 } 382 SmallVector<std::string, 4> names; 383 int i = 0; 384 for (int e = entities.size(); i < e; ++i) 385 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 386 std::string self = appliedConstraint.self; 387 if (!self.empty()) 388 self = symbolInfoMap.getValueAndRangeUse(self); 389 for (; i < 4; ++i) 390 names.push_back("<unused>"); 391 os.indent(4) << formatv(cmd, 392 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 393 names[1], names[2], names[3])); 394 } 395 } 396 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 397 } 398 399 void PatternEmitter::collectOps(DagNode tree, 400 llvm::SmallPtrSetImpl<const Operator *> &ops) { 401 // Check if this tree is an operation. 402 if (tree.isOperation()) { 403 const Operator &op = tree.getDialectOp(opMap); 404 LLVM_DEBUG(llvm::dbgs() 405 << "found operation " << op.getOperationName() << '\n'); 406 ops.insert(&op); 407 } 408 409 // Recurse the arguments of the tree. 410 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 411 if (auto child = tree.getArgAsNestedDag(i)) 412 collectOps(child, ops); 413 } 414 415 void PatternEmitter::emit(StringRef rewriteName) { 416 // Get the DAG tree for the source pattern. 417 DagNode sourceTree = pattern.getSourcePattern(); 418 419 const Operator &rootOp = pattern.getSourceRootOp(); 420 auto rootName = rootOp.getOperationName(); 421 422 // Collect the set of result operations. 423 llvm::SmallPtrSet<const Operator *, 4> resultOps; 424 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 425 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 426 collectOps(pattern.getResultPattern(i), resultOps); 427 } 428 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 429 430 // Emit RewritePattern for Pattern. 431 auto locs = pattern.getLocation(); 432 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 433 make_range(locs.rbegin(), locs.rend())); 434 os << formatv(R"(struct {0} : public RewritePattern { 435 {0}(MLIRContext *context) 436 : RewritePattern("{1}", {{)", 437 rewriteName, rootName); 438 // Sort result operators by name. 439 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 440 resultOps.end()); 441 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 442 return lhs->getOperationName() < rhs->getOperationName(); 443 }); 444 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 445 os << '"' << op->getOperationName() << '"'; 446 }); 447 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 448 449 // Emit matchAndRewrite() function. 450 os << R"( 451 PatternMatchResult matchAndRewrite(Operation *op0, 452 PatternRewriter &rewriter) const override { 453 )"; 454 455 // Register all symbols bound in the source pattern. 456 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 457 458 LLVM_DEBUG( 459 llvm::dbgs() << "start creating local variables for capturing matches\n"); 460 os.indent(4) << "// Variables for capturing values and attributes used for " 461 "creating ops\n"; 462 // Create local variables for storing the arguments and results bound 463 // to symbols. 464 for (const auto &symbolInfoPair : symbolInfoMap) { 465 StringRef symbol = symbolInfoPair.getKey(); 466 auto &info = symbolInfoPair.getValue(); 467 os.indent(4) << info.getVarDecl(symbol); 468 } 469 // TODO(jpienaar): capture ops with consistent numbering so that it can be 470 // reused for fused loc. 471 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 472 pattern.getSourcePattern().getNumOps()); 473 LLVM_DEBUG( 474 llvm::dbgs() << "done creating local variables for capturing matches\n"); 475 476 os.indent(4) << "// Match\n"; 477 os.indent(4) << "tblgen_ops[0] = op0;\n"; 478 emitMatchLogic(sourceTree); 479 os << "\n"; 480 481 os.indent(4) << "// Rewrite\n"; 482 emitRewriteLogic(); 483 484 os.indent(4) << "return matchSuccess();\n"; 485 os << " };\n"; 486 os << "};\n"; 487 } 488 489 void PatternEmitter::emitRewriteLogic() { 490 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 491 const Operator &rootOp = pattern.getSourceRootOp(); 492 int numExpectedResults = rootOp.getNumResults(); 493 int numResultPatterns = pattern.getNumResultPatterns(); 494 495 // First register all symbols bound to ops generated in result patterns. 496 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 497 498 // Only the last N static values generated are used to replace the matched 499 // root N-result op. We need to calculate the starting index (of the results 500 // of the matched op) each result pattern is to replace. 501 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 502 // If we don't need to replace any value at all, set the replacement starting 503 // index as the number of result patterns so we skip all of them when trying 504 // to replace the matched op's results. 505 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 506 for (int i = numResultPatterns - 1; i >= 0; --i) { 507 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 508 offsets[i] = offsets[i + 1] - numValues; 509 if (offsets[i] == 0) { 510 if (replStartIndex == -1) 511 replStartIndex = i; 512 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 513 auto error = formatv( 514 "cannot use the same multi-result op '{0}' to generate both " 515 "auxiliary values and values to be used for replacing the matched op", 516 pattern.getResultPattern(i).getSymbol()); 517 PrintFatalError(loc, error); 518 } 519 } 520 521 if (offsets.front() > 0) { 522 const char error[] = "no enough values generated to replace the matched op"; 523 PrintFatalError(loc, error); 524 } 525 526 os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"; 527 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 528 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 529 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 530 } 531 os << "}); (void)loc;\n"; 532 533 // Process auxiliary result patterns. 534 for (int i = 0; i < replStartIndex; ++i) { 535 DagNode resultTree = pattern.getResultPattern(i); 536 auto val = handleResultPattern(resultTree, offsets[i], 0); 537 // Normal op creation will be streamed to `os` by the above call; but 538 // NativeCodeCall will only be materialized to `os` if it is used. Here 539 // we are handling auxiliary patterns so we want the side effect even if 540 // NativeCodeCall is not replacing matched root op's results. 541 if (resultTree.isNativeCodeCall()) 542 os.indent(4) << val << ";\n"; 543 } 544 545 if (numExpectedResults == 0) { 546 assert(replStartIndex >= numResultPatterns && 547 "invalid auxiliary vs. replacement pattern division!"); 548 // No result to replace. Just erase the op. 549 os.indent(4) << "rewriter.eraseOp(op0);\n"; 550 } else { 551 // Process replacement result patterns. 552 os.indent(4) << "SmallVector<Value *, 4> tblgen_values;"; 553 for (int i = replStartIndex; i < numResultPatterns; ++i) { 554 DagNode resultTree = pattern.getResultPattern(i); 555 auto val = handleResultPattern(resultTree, offsets[i], 0); 556 os.indent(4) << "\n"; 557 // Resolve each symbol for all range use so that we can loop over them. 558 os << symbolInfoMap.getAllRangeUse( 559 val, " for (auto *v : {0}) {{ tblgen_values.push_back(v); }", 560 "\n"); 561 } 562 os.indent(4) << "\n"; 563 os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n"; 564 } 565 566 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 567 } 568 569 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 570 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 571 } 572 573 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 574 int resultIndex, int depth) { 575 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 576 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 577 LLVM_DEBUG(llvm::dbgs() << '\n'); 578 579 if (resultTree.isNativeCodeCall()) { 580 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 581 symbolInfoMap.bindValue(symbol); 582 return symbol; 583 } 584 585 if (resultTree.isReplaceWithValue()) { 586 return handleReplaceWithValue(resultTree); 587 } 588 589 // Normal op creation. 590 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 591 if (resultTree.getSymbol().empty()) { 592 // This is an op not explicitly bound to a symbol in the rewrite rule. 593 // Register the auto-generated symbol for it. 594 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 595 } 596 return symbol; 597 } 598 599 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 600 assert(tree.isReplaceWithValue()); 601 602 if (tree.getNumArgs() != 1) { 603 PrintFatalError( 604 loc, "replaceWithValue directive must take exactly one argument"); 605 } 606 607 if (!tree.getSymbol().empty()) { 608 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 609 } 610 611 return tree.getArgName(0); 612 } 613 614 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 615 StringRef patArgName) { 616 if (leaf.isConstantAttr()) { 617 auto constAttr = leaf.getAsConstantAttr(); 618 return handleConstantAttr(constAttr.getAttribute(), 619 constAttr.getConstantValue()); 620 } 621 if (leaf.isEnumAttrCase()) { 622 auto enumCase = leaf.getAsEnumAttrCase(); 623 if (enumCase.isStrCase()) 624 return handleConstantAttr(enumCase, enumCase.getSymbol()); 625 // This is an enum case backed by an IntegerAttr. We need to get its value 626 // to build the constant. 627 std::string val = std::to_string(enumCase.getValue()); 628 return handleConstantAttr(enumCase, val); 629 } 630 631 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 632 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 633 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 634 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 635 << "' (via symbol ref)\n"); 636 return argName; 637 } 638 if (leaf.isNativeCodeCall()) { 639 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 640 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 641 << "' (via NativeCodeCall)\n"); 642 return repl; 643 } 644 PrintFatalError(loc, "unhandled case when rewriting op"); 645 } 646 647 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 648 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 649 LLVM_DEBUG(tree.print(llvm::dbgs())); 650 LLVM_DEBUG(llvm::dbgs() << '\n'); 651 652 auto fmt = tree.getNativeCodeTemplate(); 653 // TODO(b/138794486): replace formatv arguments with the exact specified args. 654 SmallVector<std::string, 8> attrs(8); 655 if (tree.getNumArgs() > 8) { 656 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 657 Twine(tree.getNumArgs())); 658 } 659 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 660 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 661 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i 662 << " replacement: " << attrs[i] << "\n"); 663 } 664 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 665 attrs[5], attrs[6], attrs[7]); 666 } 667 668 int PatternEmitter::getNodeValueCount(DagNode node) { 669 if (node.isOperation()) { 670 // If the op is bound to a symbol in the rewrite rule, query its result 671 // count from the symbol info map. 672 auto symbol = node.getSymbol(); 673 if (!symbol.empty()) { 674 return symbolInfoMap.getStaticValueCount(symbol); 675 } 676 // Otherwise this is an unbound op; we will use all its results. 677 return pattern.getDialectOp(node).getNumResults(); 678 } 679 // TODO(antiagainst): This considers all NativeCodeCall as returning one 680 // value. Enhance if multi-value ones are needed. 681 return 1; 682 } 683 684 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 685 int depth) { 686 Operator &resultOp = tree.getDialectOp(opMap); 687 auto numOpArgs = resultOp.getNumArgs(); 688 689 if (numOpArgs != tree.getNumArgs()) { 690 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 691 "{1} in pattern vs. {2} in definition", 692 resultOp.getOperationName(), tree.getNumArgs(), 693 numOpArgs)); 694 } 695 696 // A map to collect all nested DAG child nodes' names, with operand index as 697 // the key. This includes both bound and unbound child nodes. 698 llvm::DenseMap<unsigned, std::string> childNodeNames; 699 700 // First go through all the child nodes who are nested DAG constructs to 701 // create ops for them and remember the symbol names for them, so that we can 702 // use the results in the current node. This happens in a recursive manner. 703 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 704 if (auto child = tree.getArgAsNestedDag(i)) { 705 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 706 } 707 } 708 709 // The name of the local variable holding this op. 710 std::string valuePackName; 711 // The symbol for holding the result of this pattern. Note that the result of 712 // this pattern is not necessarily the same as the variable created by this 713 // pattern because we can use `__N` suffix to refer only a specific result if 714 // the generated op is a multi-result op. 715 std::string resultValue; 716 if (tree.getSymbol().empty()) { 717 // No symbol is explicitly bound to this op in the pattern. Generate a 718 // unique name. 719 valuePackName = resultValue = getUniqueSymbol(&resultOp); 720 } else { 721 resultValue = tree.getSymbol(); 722 // Strip the index to get the name for the value pack and use it to name the 723 // local variable for the op. 724 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 725 } 726 727 // Create the local variable for this op. 728 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 729 valuePackName); 730 os.indent(4) << "{\n"; 731 732 // Now prepare operands used for building this op: 733 // * If the operand is non-variadic, we create a `Value*` local variable. 734 // * If the operand is variadic, we create a `SmallVector<Value*>` local 735 // variable. 736 737 int argIndex = 0; // The current index to this op's ODS argument 738 int valueIndex = 0; // An index for uniquing local variable names. 739 for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { 740 const auto &operand = resultOp.getOperand(argIndex); 741 std::string varName; 742 if (operand.isVariadic()) { 743 varName = formatv("tblgen_values_{0}", valueIndex++); 744 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 745 std::string range; 746 if (tree.isNestedDagArg(argIndex)) { 747 range = childNodeNames[argIndex]; 748 } else { 749 range = tree.getArgName(argIndex); 750 } 751 // Resolve the symbol for all range use so that we have a uniform way of 752 // capturing the values. 753 range = symbolInfoMap.getValueAndRangeUse(range); 754 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 755 varName); 756 } else { 757 varName = formatv("tblgen_value_{0}", valueIndex++); 758 os.indent(6) << formatv("Value *{0} = ", varName); 759 if (tree.isNestedDagArg(argIndex)) { 760 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 761 } else { 762 DagLeaf leaf = tree.getArgAsLeaf(argIndex); 763 auto symbol = 764 symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex)); 765 if (leaf.isNativeCodeCall()) { 766 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 767 } else { 768 os << symbol; 769 } 770 } 771 os << ";\n"; 772 } 773 774 // Update to use the newly created local variable for building the op later. 775 childNodeNames[argIndex] = varName; 776 } 777 778 // Then we create the builder call. 779 780 // Right now we don't have general type inference in MLIR. Except a few 781 // special cases listed below, we need to supply types for all results 782 // when building an op. 783 bool isSameOperandsAndResultType = 784 resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); 785 bool isBroadcastable = 786 resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 787 bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); 788 bool usePartialResults = valuePackName != resultValue; 789 790 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 791 usePartialResults || depth > 0 || resultIndex < 0) { 792 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 793 resultOp.getQualCppClassName()); 794 } else { 795 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 796 // generated from the source pattern root op. Then we can use the source 797 // pattern's value types to determine the value type of the generated op 798 // here. 799 800 // We need to specify the types for all results. 801 int numResults = resultOp.getNumResults(); 802 if (numResults != 0) { 803 os.indent(6) << "tblgen_types.clear();\n"; 804 for (int i = 0; i < numResults; ++i) { 805 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) " 806 "tblgen_types.push_back(v->getType());\n", 807 resultIndex + i); 808 } 809 } 810 811 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 812 resultOp.getQualCppClassName()); 813 if (numResults != 0) 814 os.indent(6) << ", tblgen_types"; 815 } 816 817 // Add operands for the builder all. 818 for (int i = 0; i < argIndex; ++i) { 819 const auto &operand = resultOp.getOperand(i); 820 // Start each operand on its own line. 821 (os << ",\n").indent(8); 822 if (!operand.name.empty()) { 823 os << "/*" << operand.name << "=*/"; 824 } 825 os << childNodeNames[i]; 826 // TODO(jpienaar): verify types 827 } 828 829 // Add attributes for the builder call. 830 for (; argIndex != numOpArgs; ++argIndex) { 831 // Start each attribute on its own line. 832 (os << ",\n").indent(8); 833 // The argument in the op definition. 834 auto opArgName = resultOp.getArgName(argIndex); 835 if (auto subTree = tree.getArgAsNestedDag(argIndex)) { 836 if (!subTree.isNativeCodeCall()) 837 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 838 "for creating attribute"); 839 os << formatv("/*{0}=*/{1}", opArgName, 840 handleReplaceWithNativeCodeCall(subTree)); 841 } else { 842 auto leaf = tree.getArgAsLeaf(argIndex); 843 // The argument in the result DAG pattern. 844 auto patArgName = tree.getArgName(argIndex); 845 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 846 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 847 auto argument = resultOp.getArg(argIndex); 848 if (!argument.is<NamedAttribute *>()) 849 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 850 if (!patArgName.empty()) 851 os << "/*" << patArgName << "=*/"; 852 } else { 853 os << "/*" << opArgName << "=*/"; 854 } 855 os << handleOpArgument(leaf, patArgName); 856 } 857 } 858 os << "\n );\n"; 859 os.indent(4) << "}\n"; 860 861 return resultValue; 862 } 863 864 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 865 emitSourceFileHeader("Rewriters", os); 866 867 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 868 auto numPatterns = patterns.size(); 869 870 // We put the map here because it can be shared among multiple patterns. 871 RecordOperatorMap recordOpMap; 872 873 std::vector<std::string> rewriterNames; 874 rewriterNames.reserve(numPatterns); 875 876 std::string baseRewriterName = "GeneratedConvert"; 877 int rewriterIndex = 0; 878 879 for (Record *p : patterns) { 880 std::string name; 881 if (p->isAnonymous()) { 882 // If no name is provided, ensure unique rewriter names simply by 883 // appending unique suffix. 884 name = baseRewriterName + llvm::utostr(rewriterIndex++); 885 } else { 886 name = p->getName(); 887 } 888 LLVM_DEBUG(llvm::dbgs() 889 << "=== start generating pattern '" << name << "' ===\n"); 890 PatternEmitter(p, &recordOpMap, os).emit(name); 891 LLVM_DEBUG(llvm::dbgs() 892 << "=== done generating pattern '" << name << "' ===\n"); 893 rewriterNames.push_back(std::move(name)); 894 } 895 896 // Emit function to add the generated matchers to the pattern list. 897 os << "void populateWithGenerated(MLIRContext *context, " 898 << "OwningRewritePatternList *patterns) {\n"; 899 for (const auto &name : rewriterNames) { 900 os << " patterns->insert<" << name << ">(context);\n"; 901 } 902 os << "}\n"; 903 } 904 905 static mlir::GenRegistration 906 genRewriters("gen-rewriters", "Generate pattern rewriters", 907 [](const RecordKeeper &records, raw_ostream &os) { 908 emitRewriters(records, os); 909 return false; 910 }); 911