1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace llvm; 43 using namespace mlir; 44 using namespace mlir::tblgen; 45 46 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 47 48 namespace llvm { 49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 51 raw_ostream &os, StringRef style) { 52 os << v.first << ":" << v.second; 53 } 54 }; 55 } // end namespace llvm 56 57 //===----------------------------------------------------------------------===// 58 // PatternEmitter 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 class PatternEmitter { 63 public: 64 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 65 66 // Emits the mlir::RewritePattern struct named `rewriteName`. 67 void emit(StringRef rewriteName); 68 69 private: 70 // Emits the code for matching ops. 71 void emitMatchLogic(DagNode tree); 72 73 // Emits the code for rewriting ops. 74 void emitRewriteLogic(); 75 76 //===--------------------------------------------------------------------===// 77 // Match utilities 78 //===--------------------------------------------------------------------===// 79 80 // Emits C++ statements for matching the op constrained by the given DAG 81 // `tree`. 82 void emitOpMatch(DagNode tree, int depth); 83 84 // Emits C++ statements for matching the `index`-th argument of the given DAG 85 // `tree` as an operand. 86 void emitOperandMatch(DagNode tree, int index, int depth, int indent); 87 88 // Emits C++ statements for matching the `index`-th argument of the given DAG 89 // `tree` as an attribute. 90 void emitAttributeMatch(DagNode tree, int index, int depth, int indent); 91 92 //===--------------------------------------------------------------------===// 93 // Rewrite utilities 94 //===--------------------------------------------------------------------===// 95 96 // The entry point for handling a result pattern rooted at `resultTree`. This 97 // method dispatches to concrete handlers according to `resultTree`'s kind and 98 // returns a symbol representing the whole value pack. Callers are expected to 99 // further resolve the symbol according to the specific use case. 100 // 101 // `depth` is the nesting level of `resultTree`; 0 means top-level result 102 // pattern. For top-level result pattern, `resultIndex` indicates which result 103 // of the matched root op this pattern is intended to replace, which can be 104 // used to deduce the result type of the op generated from this result 105 // pattern. 106 std::string handleResultPattern(DagNode resultTree, int resultIndex, 107 int depth); 108 109 // Emits the C++ statement to replace the matched DAG with a value built via 110 // calling native C++ code. 111 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 112 113 // Returns the C++ expression referencing the old value serving as the 114 // replacement. 115 std::string handleReplaceWithValue(DagNode tree); 116 117 // Emits the C++ statement to build a new op out of the given DAG `tree` and 118 // returns the variable name that this op is assigned to. If the root op in 119 // DAG `tree` has a specified name, the created op will be assigned to a 120 // variable of the given name. Otherwise, a unique name will be used as the 121 // result value name. 122 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 123 124 // Returns the C++ expression to construct a constant attribute of the given 125 // `value` for the given attribute kind `attr`. 126 std::string handleConstantAttr(Attribute attr, StringRef value); 127 128 // Returns the C++ expression to build an argument from the given DAG `leaf`. 129 // `patArgName` is used to bound the argument to the source pattern. 130 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 131 132 //===--------------------------------------------------------------------===// 133 // General utilities 134 //===--------------------------------------------------------------------===// 135 136 // Collects all of the operations within the given dag tree. 137 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 138 139 // Returns a unique symbol for a local variable of the given `op`. 140 std::string getUniqueSymbol(const Operator *op); 141 142 //===--------------------------------------------------------------------===// 143 // Symbol utilities 144 //===--------------------------------------------------------------------===// 145 146 // Returns how many static values the given DAG `node` correspond to. 147 int getNodeValueCount(DagNode node); 148 149 private: 150 // Pattern instantiation location followed by the location of multiclass 151 // prototypes used. This is intended to be used as a whole to 152 // PrintFatalError() on errors. 153 ArrayRef<llvm::SMLoc> loc; 154 155 // Op's TableGen Record to wrapper object. 156 RecordOperatorMap *opMap; 157 158 // Handy wrapper for pattern being emitted. 159 Pattern pattern; 160 161 // Map for all bound symbols' info. 162 SymbolInfoMap symbolInfoMap; 163 164 // The next unused ID for newly created values. 165 unsigned nextValueId; 166 167 raw_ostream &os; 168 169 // Format contexts containing placeholder substitutions. 170 FmtContext fmtCtx; 171 172 // Number of op processed. 173 int opCounter = 0; 174 }; 175 } // end anonymous namespace 176 177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 178 raw_ostream &os) 179 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 180 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 181 fmtCtx.withBuilder("rewriter"); 182 } 183 184 std::string PatternEmitter::handleConstantAttr(Attribute attr, 185 StringRef value) { 186 if (!attr.isConstBuildable()) 187 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 188 " does not have the 'constBuilderCall' field"); 189 190 // TODO(jpienaar): Verify the constants here 191 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 192 } 193 194 // Helper function to match patterns. 195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 196 Operator &op = tree.getDialectOp(opMap); 197 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 198 << op.getOperationName() << "' at depth " << depth 199 << '\n'); 200 201 int indent = 4 + 2 * depth; 202 os.indent(indent) << formatv( 203 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 204 depth, op.getQualCppClassName()); 205 // Skip the operand matching at depth 0 as the pattern rewriter already does. 206 if (depth != 0) { 207 // Skip if there is no defining operation (e.g., arguments to function). 208 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 209 depth); 210 } 211 if (tree.getNumArgs() != op.getNumArgs()) { 212 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 213 "pattern vs. {2} in definition", 214 op.getOperationName(), tree.getNumArgs(), 215 op.getNumArgs())); 216 } 217 218 // If the operand's name is set, set to that variable. 219 auto name = tree.getSymbol(); 220 if (!name.empty()) 221 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 222 223 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 224 auto opArg = op.getArg(i); 225 226 // Handle nested DAG construct first 227 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 228 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 229 if (operand->isVariadic()) { 230 auto error = formatv("use nested DAG construct to match op {0}'s " 231 "variadic operand #{1} unsupported now", 232 op.getOperationName(), i); 233 PrintFatalError(loc, error); 234 } 235 } 236 os.indent(indent) << "{\n"; 237 238 os.indent(indent + 2) << formatv( 239 "auto *op{0} = " 240 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 241 depth + 1, depth, i); 242 emitOpMatch(argTree, depth + 1); 243 os.indent(indent + 2) 244 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 245 os.indent(indent) << "}\n"; 246 continue; 247 } 248 249 // Next handle DAG leaf: operand or attribute 250 if (opArg.is<NamedTypeConstraint *>()) { 251 emitOperandMatch(tree, i, depth, indent); 252 } else if (opArg.is<NamedAttribute *>()) { 253 emitAttributeMatch(tree, i, depth, indent); 254 } else { 255 PrintFatalError(loc, "unhandled case when matching op"); 256 } 257 } 258 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 259 << op.getOperationName() << "' at depth " << depth 260 << '\n'); 261 } 262 263 void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, 264 int indent) { 265 Operator &op = tree.getDialectOp(opMap); 266 auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); 267 auto matcher = tree.getArgAsLeaf(index); 268 269 // If a constraint is specified, we need to generate C++ statements to 270 // check the constraint. 271 if (!matcher.isUnspecified()) { 272 if (!matcher.isOperandMatcher()) { 273 PrintFatalError( 274 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 275 op.getOperationName(), index + 1)); 276 } 277 278 // Only need to verify if the matcher's type is different from the one 279 // of op definition. 280 if (operand->constraint != matcher.getAsConstraint()) { 281 if (operand->isVariadic()) { 282 auto error = formatv( 283 "further constrain op {0}'s variadic operand #{1} unsupported now", 284 op.getOperationName(), index); 285 PrintFatalError(loc, error); 286 } 287 auto self = 288 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 289 depth, index); 290 os.indent(indent) << "if (!(" 291 << tgfmt(matcher.getConditionTemplate(), 292 &fmtCtx.withSelf(self)) 293 << ")) return matchFailure();\n"; 294 } 295 } 296 297 // Capture the value 298 auto name = tree.getArgName(index); 299 if (!name.empty()) { 300 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 301 name, depth, index); 302 } 303 } 304 305 void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, 306 int indent) { 307 Operator &op = tree.getDialectOp(opMap); 308 auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); 309 const auto &attr = namedAttr->attr; 310 311 os.indent(indent) << "{\n"; 312 indent += 2; 313 os.indent(indent) << formatv( 314 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 315 attr.getStorageType(), namedAttr->name); 316 317 // TODO(antiagainst): This should use getter method to avoid duplication. 318 if (attr.hasDefaultValueInitializer()) { 319 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 320 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 321 attr.getDefaultValueInitializer()) 322 << ";\n"; 323 } else if (attr.isOptional()) { 324 // For a missing attribute that is optional according to definition, we 325 // should just capture a mlir::Attribute() to signal the missing state. 326 // That is precisely what getAttr() returns on missing attributes. 327 } else { 328 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 329 } 330 331 auto matcher = tree.getArgAsLeaf(index); 332 if (!matcher.isUnspecified()) { 333 if (!matcher.isAttrMatcher()) { 334 PrintFatalError( 335 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 336 op.getOperationName(), index + 1)); 337 } 338 339 // If a constraint is specified, we need to generate C++ statements to 340 // check the constraint. 341 os.indent(indent) << "if (!(" 342 << tgfmt(matcher.getConditionTemplate(), 343 &fmtCtx.withSelf("tblgen_attr")) 344 << ")) return matchFailure();\n"; 345 } 346 347 // Capture the value 348 auto name = tree.getArgName(index); 349 if (!name.empty()) { 350 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 351 } 352 353 indent -= 2; 354 os.indent(indent) << "}\n"; 355 } 356 357 void PatternEmitter::emitMatchLogic(DagNode tree) { 358 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 359 emitOpMatch(tree, 0); 360 361 for (auto &appliedConstraint : pattern.getConstraints()) { 362 auto &constraint = appliedConstraint.constraint; 363 auto &entities = appliedConstraint.entities; 364 365 auto condition = constraint.getConditionTemplate(); 366 auto cmd = "if (!({0})) return matchFailure();\n"; 367 368 if (isa<TypeConstraint>(constraint)) { 369 auto self = formatv("({0}->getType())", 370 symbolInfoMap.getValueAndRangeUse(entities.front())); 371 os.indent(4) << formatv(cmd, 372 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 373 } else if (isa<AttrConstraint>(constraint)) { 374 PrintFatalError( 375 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 376 } else { 377 // TODO(b/138794486): replace formatv arguments with the exact specified 378 // args. 379 if (entities.size() > 4) { 380 PrintFatalError(loc, "only support up to 4-entity constraints now"); 381 } 382 SmallVector<std::string, 4> names; 383 int i = 0; 384 for (int e = entities.size(); i < e; ++i) 385 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 386 std::string self = appliedConstraint.self; 387 if (!self.empty()) 388 self = symbolInfoMap.getValueAndRangeUse(self); 389 for (; i < 4; ++i) 390 names.push_back("<unused>"); 391 os.indent(4) << formatv(cmd, 392 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 393 names[1], names[2], names[3])); 394 } 395 } 396 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 397 } 398 399 void PatternEmitter::collectOps(DagNode tree, 400 llvm::SmallPtrSetImpl<const Operator *> &ops) { 401 // Check if this tree is an operation. 402 if (tree.isOperation()) { 403 const Operator &op = tree.getDialectOp(opMap); 404 LLVM_DEBUG(llvm::dbgs() 405 << "found operation " << op.getOperationName() << '\n'); 406 ops.insert(&op); 407 } 408 409 // Recurse the arguments of the tree. 410 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 411 if (auto child = tree.getArgAsNestedDag(i)) 412 collectOps(child, ops); 413 } 414 415 void PatternEmitter::emit(StringRef rewriteName) { 416 // Get the DAG tree for the source pattern. 417 DagNode sourceTree = pattern.getSourcePattern(); 418 419 const Operator &rootOp = pattern.getSourceRootOp(); 420 auto rootName = rootOp.getOperationName(); 421 422 // Collect the set of result operations. 423 llvm::SmallPtrSet<const Operator *, 4> resultOps; 424 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 425 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 426 collectOps(pattern.getResultPattern(i), resultOps); 427 } 428 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 429 430 // Emit RewritePattern for Pattern. 431 auto locs = pattern.getLocation(); 432 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 433 make_range(locs.rbegin(), locs.rend())); 434 os << formatv(R"(struct {0} : public RewritePattern { 435 {0}(MLIRContext *context) 436 : RewritePattern("{1}", {{)", 437 rewriteName, rootName); 438 // Sort result operators by name. 439 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 440 resultOps.end()); 441 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 442 return lhs->getOperationName() < rhs->getOperationName(); 443 }); 444 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 445 os << '"' << op->getOperationName() << '"'; 446 }); 447 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 448 449 // Emit matchAndRewrite() function. 450 os << R"( 451 PatternMatchResult matchAndRewrite(Operation *op0, 452 PatternRewriter &rewriter) const override { 453 )"; 454 455 // Register all symbols bound in the source pattern. 456 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 457 458 LLVM_DEBUG( 459 llvm::dbgs() << "start creating local variables for capturing matches\n"); 460 os.indent(4) << "// Variables for capturing values and attributes used for " 461 "creating ops\n"; 462 // Create local variables for storing the arguments and results bound 463 // to symbols. 464 for (const auto &symbolInfoPair : symbolInfoMap) { 465 StringRef symbol = symbolInfoPair.getKey(); 466 auto &info = symbolInfoPair.getValue(); 467 os.indent(4) << info.getVarDecl(symbol); 468 } 469 // TODO(jpienaar): capture ops with consistent numbering so that it can be 470 // reused for fused loc. 471 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 472 pattern.getSourcePattern().getNumOps()); 473 LLVM_DEBUG( 474 llvm::dbgs() << "done creating local variables for capturing matches\n"); 475 476 os.indent(4) << "// Match\n"; 477 os.indent(4) << "tblgen_ops[0] = op0;\n"; 478 emitMatchLogic(sourceTree); 479 os << "\n"; 480 481 os.indent(4) << "// Rewrite\n"; 482 emitRewriteLogic(); 483 484 os.indent(4) << "return matchSuccess();\n"; 485 os << " };\n"; 486 os << "};\n"; 487 } 488 489 void PatternEmitter::emitRewriteLogic() { 490 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 491 const Operator &rootOp = pattern.getSourceRootOp(); 492 int numExpectedResults = rootOp.getNumResults(); 493 int numResultPatterns = pattern.getNumResultPatterns(); 494 495 // First register all symbols bound to ops generated in result patterns. 496 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 497 498 // Only the last N static values generated are used to replace the matched 499 // root N-result op. We need to calculate the starting index (of the results 500 // of the matched op) each result pattern is to replace. 501 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 502 // If we don't need to replace any value at all, set the replacement starting 503 // index as the number of result patterns so we skip all of them when trying 504 // to replace the matched op's results. 505 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 506 for (int i = numResultPatterns - 1; i >= 0; --i) { 507 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 508 offsets[i] = offsets[i + 1] - numValues; 509 if (offsets[i] == 0) { 510 if (replStartIndex == -1) 511 replStartIndex = i; 512 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 513 auto error = formatv( 514 "cannot use the same multi-result op '{0}' to generate both " 515 "auxiliary values and values to be used for replacing the matched op", 516 pattern.getResultPattern(i).getSymbol()); 517 PrintFatalError(loc, error); 518 } 519 } 520 521 if (offsets.front() > 0) { 522 const char error[] = "no enough values generated to replace the matched op"; 523 PrintFatalError(loc, error); 524 } 525 526 os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"; 527 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 528 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 529 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 530 } 531 os << "}); (void)loc;\n"; 532 533 // Process auxiliary result patterns. 534 for (int i = 0; i < replStartIndex; ++i) { 535 DagNode resultTree = pattern.getResultPattern(i); 536 auto val = handleResultPattern(resultTree, offsets[i], 0); 537 // Normal op creation will be streamed to `os` by the above call; but 538 // NativeCodeCall will only be materialized to `os` if it is used. Here 539 // we are handling auxiliary patterns so we want the side effect even if 540 // NativeCodeCall is not replacing matched root op's results. 541 if (resultTree.isNativeCodeCall()) 542 os.indent(4) << val << ";\n"; 543 } 544 545 // Process replacement result patterns. 546 os.indent(4) << "SmallVector<Value *, 4> tblgen_values;"; 547 for (int i = replStartIndex; i < numResultPatterns; ++i) { 548 DagNode resultTree = pattern.getResultPattern(i); 549 auto val = handleResultPattern(resultTree, offsets[i], 0); 550 os.indent(4) << "\n"; 551 // Resolve each symbol for all range use so that we can loop over them. 552 os << symbolInfoMap.getAllRangeUse( 553 val, " for (auto *v : {0}) tblgen_values.push_back(v);", "\n"); 554 } 555 os.indent(4) << "\n"; 556 os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n"; 557 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 558 } 559 560 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 561 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 562 } 563 564 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 565 int resultIndex, int depth) { 566 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 567 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 568 LLVM_DEBUG(llvm::dbgs() << '\n'); 569 570 if (resultTree.isNativeCodeCall()) { 571 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 572 symbolInfoMap.bindValue(symbol); 573 return symbol; 574 } 575 576 if (resultTree.isReplaceWithValue()) { 577 return handleReplaceWithValue(resultTree); 578 } 579 580 // Normal op creation. 581 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 582 if (resultTree.getSymbol().empty()) { 583 // This is an op not explicitly bound to a symbol in the rewrite rule. 584 // Register the auto-generated symbol for it. 585 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 586 } 587 return symbol; 588 } 589 590 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 591 assert(tree.isReplaceWithValue()); 592 593 if (tree.getNumArgs() != 1) { 594 PrintFatalError( 595 loc, "replaceWithValue directive must take exactly one argument"); 596 } 597 598 if (!tree.getSymbol().empty()) { 599 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 600 } 601 602 return tree.getArgName(0); 603 } 604 605 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 606 StringRef patArgName) { 607 if (leaf.isConstantAttr()) { 608 auto constAttr = leaf.getAsConstantAttr(); 609 return handleConstantAttr(constAttr.getAttribute(), 610 constAttr.getConstantValue()); 611 } 612 if (leaf.isEnumAttrCase()) { 613 auto enumCase = leaf.getAsEnumAttrCase(); 614 if (enumCase.isStrCase()) 615 return handleConstantAttr(enumCase, enumCase.getSymbol()); 616 // This is an enum case backed by an IntegerAttr. We need to get its value 617 // to build the constant. 618 std::string val = std::to_string(enumCase.getValue()); 619 return handleConstantAttr(enumCase, val); 620 } 621 622 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 623 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 624 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 625 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 626 << "' (via symbol ref)\n"); 627 return argName; 628 } 629 if (leaf.isNativeCodeCall()) { 630 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 631 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 632 << "' (via NativeCodeCall)\n"); 633 return repl; 634 } 635 PrintFatalError(loc, "unhandled case when rewriting op"); 636 } 637 638 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 639 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 640 LLVM_DEBUG(tree.print(llvm::dbgs())); 641 LLVM_DEBUG(llvm::dbgs() << '\n'); 642 643 auto fmt = tree.getNativeCodeTemplate(); 644 // TODO(b/138794486): replace formatv arguments with the exact specified args. 645 SmallVector<std::string, 8> attrs(8); 646 if (tree.getNumArgs() > 8) { 647 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 648 Twine(tree.getNumArgs())); 649 } 650 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 651 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 652 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i 653 << " replacement: " << attrs[i] << "\n"); 654 } 655 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 656 attrs[5], attrs[6], attrs[7]); 657 } 658 659 int PatternEmitter::getNodeValueCount(DagNode node) { 660 if (node.isOperation()) { 661 // If the op is bound to a symbol in the rewrite rule, query its result 662 // count from the symbol info map. 663 auto symbol = node.getSymbol(); 664 if (!symbol.empty()) { 665 return symbolInfoMap.getStaticValueCount(symbol); 666 } 667 // Otherwise this is an unbound op; we will use all its results. 668 return pattern.getDialectOp(node).getNumResults(); 669 } 670 // TODO(antiagainst): This considers all NativeCodeCall as returning one 671 // value. Enhance if multi-value ones are needed. 672 return 1; 673 } 674 675 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 676 int depth) { 677 Operator &resultOp = tree.getDialectOp(opMap); 678 auto numOpArgs = resultOp.getNumArgs(); 679 680 if (numOpArgs != tree.getNumArgs()) { 681 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 682 "{1} in pattern vs. {2} in definition", 683 resultOp.getOperationName(), tree.getNumArgs(), 684 numOpArgs)); 685 } 686 687 // A map to collect all nested DAG child nodes' names, with operand index as 688 // the key. This includes both bound and unbound child nodes. 689 llvm::DenseMap<unsigned, std::string> childNodeNames; 690 691 // First go through all the child nodes who are nested DAG constructs to 692 // create ops for them and remember the symbol names for them, so that we can 693 // use the results in the current node. This happens in a recursive manner. 694 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 695 if (auto child = tree.getArgAsNestedDag(i)) { 696 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 697 } 698 } 699 700 // The name of the local variable holding this op. 701 std::string valuePackName; 702 // The symbol for holding the result of this pattern. Note that the result of 703 // this pattern is not necessarily the same as the variable created by this 704 // pattern because we can use `__N` suffix to refer only a specific result if 705 // the generated op is a multi-result op. 706 std::string resultValue; 707 if (tree.getSymbol().empty()) { 708 // No symbol is explicitly bound to this op in the pattern. Generate a 709 // unique name. 710 valuePackName = resultValue = getUniqueSymbol(&resultOp); 711 } else { 712 resultValue = tree.getSymbol(); 713 // Strip the index to get the name for the value pack and use it to name the 714 // local variable for the op. 715 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 716 } 717 718 // Create the local variable for this op. 719 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 720 valuePackName); 721 os.indent(4) << "{\n"; 722 723 // Now prepare operands used for building this op: 724 // * If the operand is non-variadic, we create a `Value*` local variable. 725 // * If the operand is variadic, we create a `SmallVector<Value*>` local 726 // variable. 727 728 int argIndex = 0; // The current index to this op's ODS argument 729 int valueIndex = 0; // An index for uniquing local variable names. 730 for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { 731 const auto &operand = resultOp.getOperand(argIndex); 732 std::string varName; 733 if (operand.isVariadic()) { 734 varName = formatv("tblgen_values_{0}", valueIndex++); 735 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 736 std::string range; 737 if (tree.isNestedDagArg(argIndex)) { 738 range = childNodeNames[argIndex]; 739 } else { 740 range = tree.getArgName(argIndex); 741 } 742 // Resolve the symbol for all range use so that we have a uniform way of 743 // capturing the values. 744 range = symbolInfoMap.getValueAndRangeUse(range); 745 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 746 varName); 747 } else { 748 varName = formatv("tblgen_value_{0}", valueIndex++); 749 os.indent(6) << formatv("Value *{0} = ", varName); 750 if (tree.isNestedDagArg(argIndex)) { 751 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 752 } else { 753 DagLeaf leaf = tree.getArgAsLeaf(argIndex); 754 auto symbol = 755 symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex)); 756 if (leaf.isNativeCodeCall()) { 757 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 758 } else { 759 os << symbol; 760 } 761 } 762 os << ";\n"; 763 } 764 765 // Update to use the newly created local variable for building the op later. 766 childNodeNames[argIndex] = varName; 767 } 768 769 // Then we create the builder call. 770 771 // Right now we don't have general type inference in MLIR. Except a few 772 // special cases listed below, we need to supply types for all results 773 // when building an op. 774 bool isSameOperandsAndResultType = 775 resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); 776 bool isBroadcastable = 777 resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 778 bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); 779 bool usePartialResults = valuePackName != resultValue; 780 781 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 782 usePartialResults || depth > 0 || resultIndex < 0) { 783 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 784 resultOp.getQualCppClassName()); 785 } else { 786 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 787 // generated from the source pattern root op. Then we can use the source 788 // pattern's value types to determine the value type of the generated op 789 // here. 790 791 // We need to specify the types for all results. 792 int numResults = resultOp.getNumResults(); 793 if (numResults != 0) { 794 os.indent(6) << "tblgen_types.clear();\n"; 795 for (int i = 0; i < numResults; ++i) { 796 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) " 797 "tblgen_types.push_back(v->getType());\n", 798 resultIndex + i); 799 } 800 } 801 802 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 803 resultOp.getQualCppClassName()); 804 if (numResults != 0) 805 os.indent(6) << ", tblgen_types"; 806 } 807 808 // Add operands for the builder all. 809 for (int i = 0; i < argIndex; ++i) { 810 const auto &operand = resultOp.getOperand(i); 811 // Start each operand on its own line. 812 (os << ",\n").indent(8); 813 if (!operand.name.empty()) { 814 os << "/*" << operand.name << "=*/"; 815 } 816 os << childNodeNames[i]; 817 // TODO(jpienaar): verify types 818 } 819 820 // Add attributes for the builder call. 821 for (; argIndex != numOpArgs; ++argIndex) { 822 // Start each attribute on its own line. 823 (os << ",\n").indent(8); 824 // The argument in the op definition. 825 auto opArgName = resultOp.getArgName(argIndex); 826 if (auto subTree = tree.getArgAsNestedDag(argIndex)) { 827 if (!subTree.isNativeCodeCall()) 828 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 829 "for creating attribute"); 830 os << formatv("/*{0}=*/{1}", opArgName, 831 handleReplaceWithNativeCodeCall(subTree)); 832 } else { 833 auto leaf = tree.getArgAsLeaf(argIndex); 834 // The argument in the result DAG pattern. 835 auto patArgName = tree.getArgName(argIndex); 836 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 837 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 838 auto argument = resultOp.getArg(argIndex); 839 if (!argument.is<NamedAttribute *>()) 840 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 841 if (!patArgName.empty()) 842 os << "/*" << patArgName << "=*/"; 843 } else { 844 os << "/*" << opArgName << "=*/"; 845 } 846 os << handleOpArgument(leaf, patArgName); 847 } 848 } 849 os << "\n );\n"; 850 os.indent(4) << "}\n"; 851 852 return resultValue; 853 } 854 855 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 856 emitSourceFileHeader("Rewriters", os); 857 858 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 859 auto numPatterns = patterns.size(); 860 861 // We put the map here because it can be shared among multiple patterns. 862 RecordOperatorMap recordOpMap; 863 864 std::vector<std::string> rewriterNames; 865 rewriterNames.reserve(numPatterns); 866 867 std::string baseRewriterName = "GeneratedConvert"; 868 int rewriterIndex = 0; 869 870 for (Record *p : patterns) { 871 std::string name; 872 if (p->isAnonymous()) { 873 // If no name is provided, ensure unique rewriter names simply by 874 // appending unique suffix. 875 name = baseRewriterName + llvm::utostr(rewriterIndex++); 876 } else { 877 name = p->getName(); 878 } 879 LLVM_DEBUG(llvm::dbgs() 880 << "=== start generating pattern '" << name << "' ===\n"); 881 PatternEmitter(p, &recordOpMap, os).emit(name); 882 LLVM_DEBUG(llvm::dbgs() 883 << "=== done generating pattern '" << name << "' ===\n"); 884 rewriterNames.push_back(std::move(name)); 885 } 886 887 // Emit function to add the generated matchers to the pattern list. 888 os << "void populateWithGenerated(MLIRContext *context, " 889 << "OwningRewritePatternList *patterns) {\n"; 890 for (const auto &name : rewriterNames) { 891 os << " patterns->insert<" << name << ">(context);\n"; 892 } 893 os << "}\n"; 894 } 895 896 static mlir::GenRegistration 897 genRewriters("gen-rewriters", "Generate pattern rewriters", 898 [](const RecordKeeper &records, raw_ostream &os) { 899 emitRewriters(records, os); 900 return false; 901 }); 902