1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace llvm; 43 using namespace mlir; 44 using namespace mlir::tblgen; 45 46 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 47 48 namespace llvm { 49 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 51 raw_ostream &os, StringRef style) { 52 os << v.first << ":" << v.second; 53 } 54 }; 55 } // end namespace llvm 56 57 //===----------------------------------------------------------------------===// 58 // PatternEmitter 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 class PatternEmitter { 63 public: 64 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 65 66 // Emits the mlir::RewritePattern struct named `rewriteName`. 67 void emit(StringRef rewriteName); 68 69 private: 70 // Emits the code for matching ops. 71 void emitMatchLogic(DagNode tree); 72 73 // Emits the code for rewriting ops. 74 void emitRewriteLogic(); 75 76 //===--------------------------------------------------------------------===// 77 // Match utilities 78 //===--------------------------------------------------------------------===// 79 80 // Emits C++ statements for matching the op constrained by the given DAG 81 // `tree`. 82 void emitOpMatch(DagNode tree, int depth); 83 84 // Emits C++ statements for matching the `argIndex`-th argument of the given 85 // DAG `tree` as an operand. 86 void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent); 87 88 // Emits C++ statements for matching the `argIndex`-th argument of the given 89 // DAG `tree` as an attribute. 90 void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent); 91 92 //===--------------------------------------------------------------------===// 93 // Rewrite utilities 94 //===--------------------------------------------------------------------===// 95 96 // The entry point for handling a result pattern rooted at `resultTree`. This 97 // method dispatches to concrete handlers according to `resultTree`'s kind and 98 // returns a symbol representing the whole value pack. Callers are expected to 99 // further resolve the symbol according to the specific use case. 100 // 101 // `depth` is the nesting level of `resultTree`; 0 means top-level result 102 // pattern. For top-level result pattern, `resultIndex` indicates which result 103 // of the matched root op this pattern is intended to replace, which can be 104 // used to deduce the result type of the op generated from this result 105 // pattern. 106 std::string handleResultPattern(DagNode resultTree, int resultIndex, 107 int depth); 108 109 // Emits the C++ statement to replace the matched DAG with a value built via 110 // calling native C++ code. 111 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 112 113 // Returns the C++ expression referencing the old value serving as the 114 // replacement. 115 std::string handleReplaceWithValue(DagNode tree); 116 117 // Emits the C++ statement to build a new op out of the given DAG `tree` and 118 // returns the variable name that this op is assigned to. If the root op in 119 // DAG `tree` has a specified name, the created op will be assigned to a 120 // variable of the given name. Otherwise, a unique name will be used as the 121 // result value name. 122 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 123 124 // Returns the C++ expression to construct a constant attribute of the given 125 // `value` for the given attribute kind `attr`. 126 std::string handleConstantAttr(Attribute attr, StringRef value); 127 128 // Returns the C++ expression to build an argument from the given DAG `leaf`. 129 // `patArgName` is used to bound the argument to the source pattern. 130 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 131 132 //===--------------------------------------------------------------------===// 133 // General utilities 134 //===--------------------------------------------------------------------===// 135 136 // Collects all of the operations within the given dag tree. 137 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 138 139 // Returns a unique symbol for a local variable of the given `op`. 140 std::string getUniqueSymbol(const Operator *op); 141 142 //===--------------------------------------------------------------------===// 143 // Symbol utilities 144 //===--------------------------------------------------------------------===// 145 146 // Returns how many static values the given DAG `node` correspond to. 147 int getNodeValueCount(DagNode node); 148 149 private: 150 // Pattern instantiation location followed by the location of multiclass 151 // prototypes used. This is intended to be used as a whole to 152 // PrintFatalError() on errors. 153 ArrayRef<llvm::SMLoc> loc; 154 155 // Op's TableGen Record to wrapper object. 156 RecordOperatorMap *opMap; 157 158 // Handy wrapper for pattern being emitted. 159 Pattern pattern; 160 161 // Map for all bound symbols' info. 162 SymbolInfoMap symbolInfoMap; 163 164 // The next unused ID for newly created values. 165 unsigned nextValueId; 166 167 raw_ostream &os; 168 169 // Format contexts containing placeholder substitutions. 170 FmtContext fmtCtx; 171 172 // Number of op processed. 173 int opCounter = 0; 174 }; 175 } // end anonymous namespace 176 177 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 178 raw_ostream &os) 179 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 180 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 181 fmtCtx.withBuilder("rewriter"); 182 } 183 184 std::string PatternEmitter::handleConstantAttr(Attribute attr, 185 StringRef value) { 186 if (!attr.isConstBuildable()) 187 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 188 " does not have the 'constBuilderCall' field"); 189 190 // TODO(jpienaar): Verify the constants here 191 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 192 } 193 194 // Helper function to match patterns. 195 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 196 Operator &op = tree.getDialectOp(opMap); 197 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 198 << op.getOperationName() << "' at depth " << depth 199 << '\n'); 200 201 int indent = 4 + 2 * depth; 202 os.indent(indent) << formatv( 203 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 204 depth, op.getQualCppClassName()); 205 // Skip the operand matching at depth 0 as the pattern rewriter already does. 206 if (depth != 0) { 207 // Skip if there is no defining operation (e.g., arguments to function). 208 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 209 depth); 210 } 211 if (tree.getNumArgs() != op.getNumArgs()) { 212 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 213 "pattern vs. {2} in definition", 214 op.getOperationName(), tree.getNumArgs(), 215 op.getNumArgs())); 216 } 217 218 // If the operand's name is set, set to that variable. 219 auto name = tree.getSymbol(); 220 if (!name.empty()) 221 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 222 223 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 224 auto opArg = op.getArg(i); 225 226 // Handle nested DAG construct first 227 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 228 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 229 if (operand->isVariadic()) { 230 auto error = formatv("use nested DAG construct to match op {0}'s " 231 "variadic operand #{1} unsupported now", 232 op.getOperationName(), i); 233 PrintFatalError(loc, error); 234 } 235 } 236 os.indent(indent) << "{\n"; 237 238 os.indent(indent + 2) << formatv( 239 "auto *op{0} = " 240 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 241 depth + 1, depth, i); 242 emitOpMatch(argTree, depth + 1); 243 os.indent(indent + 2) 244 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 245 os.indent(indent) << "}\n"; 246 continue; 247 } 248 249 // Next handle DAG leaf: operand or attribute 250 if (opArg.is<NamedTypeConstraint *>()) { 251 emitOperandMatch(tree, i, depth, indent); 252 } else if (opArg.is<NamedAttribute *>()) { 253 emitAttributeMatch(tree, i, depth, indent); 254 } else { 255 PrintFatalError(loc, "unhandled case when matching op"); 256 } 257 } 258 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 259 << op.getOperationName() << "' at depth " << depth 260 << '\n'); 261 } 262 263 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, 264 int indent) { 265 Operator &op = tree.getDialectOp(opMap); 266 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 267 auto matcher = tree.getArgAsLeaf(argIndex); 268 269 // If a constraint is specified, we need to generate C++ statements to 270 // check the constraint. 271 if (!matcher.isUnspecified()) { 272 if (!matcher.isOperandMatcher()) { 273 PrintFatalError( 274 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 275 op.getOperationName(), argIndex + 1)); 276 } 277 278 // Only need to verify if the matcher's type is different from the one 279 // of op definition. 280 if (operand->constraint != matcher.getAsConstraint()) { 281 if (operand->isVariadic()) { 282 auto error = formatv( 283 "further constrain op {0}'s variadic operand #{1} unsupported now", 284 op.getOperationName(), argIndex); 285 PrintFatalError(loc, error); 286 } 287 auto self = 288 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 289 depth, argIndex); 290 os.indent(indent) << "if (!(" 291 << tgfmt(matcher.getConditionTemplate(), 292 &fmtCtx.withSelf(self)) 293 << ")) return matchFailure();\n"; 294 } 295 } 296 297 // Capture the value 298 auto name = tree.getArgName(argIndex); 299 if (!name.empty()) { 300 // We need to subtract the number of attributes before this operand to get 301 // the index in the operand list. 302 auto numPrevAttrs = std::count_if( 303 op.arg_begin(), op.arg_begin() + argIndex, 304 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 305 306 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 307 name, depth, argIndex - numPrevAttrs); 308 } 309 } 310 311 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, 312 int indent) { 313 Operator &op = tree.getDialectOp(opMap); 314 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 315 const auto &attr = namedAttr->attr; 316 317 os.indent(indent) << "{\n"; 318 indent += 2; 319 os.indent(indent) << formatv( 320 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 321 attr.getStorageType(), namedAttr->name); 322 323 // TODO(antiagainst): This should use getter method to avoid duplication. 324 if (attr.hasDefaultValueInitializer()) { 325 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 326 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 327 attr.getDefaultValueInitializer()) 328 << ";\n"; 329 } else if (attr.isOptional()) { 330 // For a missing attribute that is optional according to definition, we 331 // should just capture a mlir::Attribute() to signal the missing state. 332 // That is precisely what getAttr() returns on missing attributes. 333 } else { 334 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 335 } 336 337 auto matcher = tree.getArgAsLeaf(argIndex); 338 if (!matcher.isUnspecified()) { 339 if (!matcher.isAttrMatcher()) { 340 PrintFatalError( 341 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 342 op.getOperationName(), argIndex + 1)); 343 } 344 345 // If a constraint is specified, we need to generate C++ statements to 346 // check the constraint. 347 os.indent(indent) << "if (!(" 348 << tgfmt(matcher.getConditionTemplate(), 349 &fmtCtx.withSelf("tblgen_attr")) 350 << ")) return matchFailure();\n"; 351 } 352 353 // Capture the value 354 auto name = tree.getArgName(argIndex); 355 if (!name.empty()) { 356 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 357 } 358 359 indent -= 2; 360 os.indent(indent) << "}\n"; 361 } 362 363 void PatternEmitter::emitMatchLogic(DagNode tree) { 364 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 365 emitOpMatch(tree, 0); 366 367 for (auto &appliedConstraint : pattern.getConstraints()) { 368 auto &constraint = appliedConstraint.constraint; 369 auto &entities = appliedConstraint.entities; 370 371 auto condition = constraint.getConditionTemplate(); 372 auto cmd = "if (!({0})) return matchFailure();\n"; 373 374 if (isa<TypeConstraint>(constraint)) { 375 auto self = formatv("({0}->getType())", 376 symbolInfoMap.getValueAndRangeUse(entities.front())); 377 os.indent(4) << formatv(cmd, 378 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 379 } else if (isa<AttrConstraint>(constraint)) { 380 PrintFatalError( 381 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 382 } else { 383 // TODO(b/138794486): replace formatv arguments with the exact specified 384 // args. 385 if (entities.size() > 4) { 386 PrintFatalError(loc, "only support up to 4-entity constraints now"); 387 } 388 SmallVector<std::string, 4> names; 389 int i = 0; 390 for (int e = entities.size(); i < e; ++i) 391 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 392 std::string self = appliedConstraint.self; 393 if (!self.empty()) 394 self = symbolInfoMap.getValueAndRangeUse(self); 395 for (; i < 4; ++i) 396 names.push_back("<unused>"); 397 os.indent(4) << formatv(cmd, 398 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 399 names[1], names[2], names[3])); 400 } 401 } 402 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 403 } 404 405 void PatternEmitter::collectOps(DagNode tree, 406 llvm::SmallPtrSetImpl<const Operator *> &ops) { 407 // Check if this tree is an operation. 408 if (tree.isOperation()) { 409 const Operator &op = tree.getDialectOp(opMap); 410 LLVM_DEBUG(llvm::dbgs() 411 << "found operation " << op.getOperationName() << '\n'); 412 ops.insert(&op); 413 } 414 415 // Recurse the arguments of the tree. 416 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 417 if (auto child = tree.getArgAsNestedDag(i)) 418 collectOps(child, ops); 419 } 420 421 void PatternEmitter::emit(StringRef rewriteName) { 422 // Get the DAG tree for the source pattern. 423 DagNode sourceTree = pattern.getSourcePattern(); 424 425 const Operator &rootOp = pattern.getSourceRootOp(); 426 auto rootName = rootOp.getOperationName(); 427 428 // Collect the set of result operations. 429 llvm::SmallPtrSet<const Operator *, 4> resultOps; 430 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 431 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 432 collectOps(pattern.getResultPattern(i), resultOps); 433 } 434 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 435 436 // Emit RewritePattern for Pattern. 437 auto locs = pattern.getLocation(); 438 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 439 make_range(locs.rbegin(), locs.rend())); 440 os << formatv(R"(struct {0} : public RewritePattern { 441 {0}(MLIRContext *context) 442 : RewritePattern("{1}", {{)", 443 rewriteName, rootName); 444 // Sort result operators by name. 445 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 446 resultOps.end()); 447 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 448 return lhs->getOperationName() < rhs->getOperationName(); 449 }); 450 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 451 os << '"' << op->getOperationName() << '"'; 452 }); 453 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 454 455 // Emit matchAndRewrite() function. 456 os << R"( 457 PatternMatchResult matchAndRewrite(Operation *op0, 458 PatternRewriter &rewriter) const override { 459 )"; 460 461 // Register all symbols bound in the source pattern. 462 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 463 464 LLVM_DEBUG( 465 llvm::dbgs() << "start creating local variables for capturing matches\n"); 466 os.indent(4) << "// Variables for capturing values and attributes used for " 467 "creating ops\n"; 468 // Create local variables for storing the arguments and results bound 469 // to symbols. 470 for (const auto &symbolInfoPair : symbolInfoMap) { 471 StringRef symbol = symbolInfoPair.getKey(); 472 auto &info = symbolInfoPair.getValue(); 473 os.indent(4) << info.getVarDecl(symbol); 474 } 475 // TODO(jpienaar): capture ops with consistent numbering so that it can be 476 // reused for fused loc. 477 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 478 pattern.getSourcePattern().getNumOps()); 479 LLVM_DEBUG( 480 llvm::dbgs() << "done creating local variables for capturing matches\n"); 481 482 os.indent(4) << "// Match\n"; 483 os.indent(4) << "tblgen_ops[0] = op0;\n"; 484 emitMatchLogic(sourceTree); 485 os << "\n"; 486 487 os.indent(4) << "// Rewrite\n"; 488 emitRewriteLogic(); 489 490 os.indent(4) << "return matchSuccess();\n"; 491 os << " };\n"; 492 os << "};\n"; 493 } 494 495 void PatternEmitter::emitRewriteLogic() { 496 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 497 const Operator &rootOp = pattern.getSourceRootOp(); 498 int numExpectedResults = rootOp.getNumResults(); 499 int numResultPatterns = pattern.getNumResultPatterns(); 500 501 // First register all symbols bound to ops generated in result patterns. 502 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 503 504 // Only the last N static values generated are used to replace the matched 505 // root N-result op. We need to calculate the starting index (of the results 506 // of the matched op) each result pattern is to replace. 507 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 508 // If we don't need to replace any value at all, set the replacement starting 509 // index as the number of result patterns so we skip all of them when trying 510 // to replace the matched op's results. 511 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 512 for (int i = numResultPatterns - 1; i >= 0; --i) { 513 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 514 offsets[i] = offsets[i + 1] - numValues; 515 if (offsets[i] == 0) { 516 if (replStartIndex == -1) 517 replStartIndex = i; 518 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 519 auto error = formatv( 520 "cannot use the same multi-result op '{0}' to generate both " 521 "auxiliary values and values to be used for replacing the matched op", 522 pattern.getResultPattern(i).getSymbol()); 523 PrintFatalError(loc, error); 524 } 525 } 526 527 if (offsets.front() > 0) { 528 const char error[] = "no enough values generated to replace the matched op"; 529 PrintFatalError(loc, error); 530 } 531 532 os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"; 533 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 534 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 535 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 536 } 537 os << "}); (void)loc;\n"; 538 539 // Process auxiliary result patterns. 540 for (int i = 0; i < replStartIndex; ++i) { 541 DagNode resultTree = pattern.getResultPattern(i); 542 auto val = handleResultPattern(resultTree, offsets[i], 0); 543 // Normal op creation will be streamed to `os` by the above call; but 544 // NativeCodeCall will only be materialized to `os` if it is used. Here 545 // we are handling auxiliary patterns so we want the side effect even if 546 // NativeCodeCall is not replacing matched root op's results. 547 if (resultTree.isNativeCodeCall()) 548 os.indent(4) << val << ";\n"; 549 } 550 551 if (numExpectedResults == 0) { 552 assert(replStartIndex >= numResultPatterns && 553 "invalid auxiliary vs. replacement pattern division!"); 554 // No result to replace. Just erase the op. 555 os.indent(4) << "rewriter.eraseOp(op0);\n"; 556 } else { 557 // Process replacement result patterns. 558 os.indent(4) << "SmallVector<Value *, 4> tblgen_values;"; 559 for (int i = replStartIndex; i < numResultPatterns; ++i) { 560 DagNode resultTree = pattern.getResultPattern(i); 561 auto val = handleResultPattern(resultTree, offsets[i], 0); 562 os.indent(4) << "\n"; 563 // Resolve each symbol for all range use so that we can loop over them. 564 os << symbolInfoMap.getAllRangeUse( 565 val, " for (auto *v : {0}) {{ tblgen_values.push_back(v); }", 566 "\n"); 567 } 568 os.indent(4) << "\n"; 569 os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n"; 570 } 571 572 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 573 } 574 575 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 576 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 577 } 578 579 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 580 int resultIndex, int depth) { 581 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 582 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 583 LLVM_DEBUG(llvm::dbgs() << '\n'); 584 585 if (resultTree.isNativeCodeCall()) { 586 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 587 symbolInfoMap.bindValue(symbol); 588 return symbol; 589 } 590 591 if (resultTree.isReplaceWithValue()) { 592 return handleReplaceWithValue(resultTree); 593 } 594 595 // Normal op creation. 596 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 597 if (resultTree.getSymbol().empty()) { 598 // This is an op not explicitly bound to a symbol in the rewrite rule. 599 // Register the auto-generated symbol for it. 600 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 601 } 602 return symbol; 603 } 604 605 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 606 assert(tree.isReplaceWithValue()); 607 608 if (tree.getNumArgs() != 1) { 609 PrintFatalError( 610 loc, "replaceWithValue directive must take exactly one argument"); 611 } 612 613 if (!tree.getSymbol().empty()) { 614 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 615 } 616 617 return tree.getArgName(0); 618 } 619 620 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 621 StringRef patArgName) { 622 if (leaf.isConstantAttr()) { 623 auto constAttr = leaf.getAsConstantAttr(); 624 return handleConstantAttr(constAttr.getAttribute(), 625 constAttr.getConstantValue()); 626 } 627 if (leaf.isEnumAttrCase()) { 628 auto enumCase = leaf.getAsEnumAttrCase(); 629 if (enumCase.isStrCase()) 630 return handleConstantAttr(enumCase, enumCase.getSymbol()); 631 // This is an enum case backed by an IntegerAttr. We need to get its value 632 // to build the constant. 633 std::string val = std::to_string(enumCase.getValue()); 634 return handleConstantAttr(enumCase, val); 635 } 636 637 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 638 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 639 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 640 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 641 << "' (via symbol ref)\n"); 642 return argName; 643 } 644 if (leaf.isNativeCodeCall()) { 645 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 646 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 647 << "' (via NativeCodeCall)\n"); 648 return repl; 649 } 650 PrintFatalError(loc, "unhandled case when rewriting op"); 651 } 652 653 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 654 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 655 LLVM_DEBUG(tree.print(llvm::dbgs())); 656 LLVM_DEBUG(llvm::dbgs() << '\n'); 657 658 auto fmt = tree.getNativeCodeTemplate(); 659 // TODO(b/138794486): replace formatv arguments with the exact specified args. 660 SmallVector<std::string, 8> attrs(8); 661 if (tree.getNumArgs() > 8) { 662 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 663 Twine(tree.getNumArgs())); 664 } 665 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 666 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 667 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i 668 << " replacement: " << attrs[i] << "\n"); 669 } 670 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 671 attrs[5], attrs[6], attrs[7]); 672 } 673 674 int PatternEmitter::getNodeValueCount(DagNode node) { 675 if (node.isOperation()) { 676 // If the op is bound to a symbol in the rewrite rule, query its result 677 // count from the symbol info map. 678 auto symbol = node.getSymbol(); 679 if (!symbol.empty()) { 680 return symbolInfoMap.getStaticValueCount(symbol); 681 } 682 // Otherwise this is an unbound op; we will use all its results. 683 return pattern.getDialectOp(node).getNumResults(); 684 } 685 // TODO(antiagainst): This considers all NativeCodeCall as returning one 686 // value. Enhance if multi-value ones are needed. 687 return 1; 688 } 689 690 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 691 int depth) { 692 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 693 LLVM_DEBUG(tree.print(llvm::dbgs())); 694 LLVM_DEBUG(llvm::dbgs() << '\n'); 695 696 Operator &resultOp = tree.getDialectOp(opMap); 697 auto numOpArgs = resultOp.getNumArgs(); 698 699 if (numOpArgs != tree.getNumArgs()) { 700 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 701 "{1} in pattern vs. {2} in definition", 702 resultOp.getOperationName(), tree.getNumArgs(), 703 numOpArgs)); 704 } 705 706 // A map to collect all nested DAG child nodes' names, with operand index as 707 // the key. This includes both bound and unbound child nodes. 708 llvm::DenseMap<unsigned, std::string> childNodeNames; 709 710 // First go through all the child nodes who are nested DAG constructs to 711 // create ops for them and remember the symbol names for them, so that we can 712 // use the results in the current node. This happens in a recursive manner. 713 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 714 if (auto child = tree.getArgAsNestedDag(i)) { 715 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 716 } 717 } 718 719 // The name of the local variable holding this op. 720 std::string valuePackName; 721 // The symbol for holding the result of this pattern. Note that the result of 722 // this pattern is not necessarily the same as the variable created by this 723 // pattern because we can use `__N` suffix to refer only a specific result if 724 // the generated op is a multi-result op. 725 std::string resultValue; 726 if (tree.getSymbol().empty()) { 727 // No symbol is explicitly bound to this op in the pattern. Generate a 728 // unique name. 729 valuePackName = resultValue = getUniqueSymbol(&resultOp); 730 } else { 731 resultValue = tree.getSymbol(); 732 // Strip the index to get the name for the value pack and use it to name the 733 // local variable for the op. 734 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 735 } 736 737 // Create the local variable for this op. 738 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 739 valuePackName); 740 os.indent(4) << "{\n"; 741 742 // Now prepare operands used for building this op: 743 // * If the operand is non-variadic, we create a `Value*` local variable. 744 // * If the operand is variadic, we create a `SmallVector<Value*>` local 745 // variable. 746 747 int valueIndex = 0; // An index for uniquing local variable names. 748 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 749 const auto *operand = 750 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 751 if (!operand) { 752 // We do not need special handling for attributes. 753 continue; 754 } 755 std::string varName; 756 if (operand->isVariadic()) { 757 varName = formatv("tblgen_values_{0}", valueIndex++); 758 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 759 std::string range; 760 if (tree.isNestedDagArg(argIndex)) { 761 range = childNodeNames[argIndex]; 762 } else { 763 range = tree.getArgName(argIndex); 764 } 765 // Resolve the symbol for all range use so that we have a uniform way of 766 // capturing the values. 767 range = symbolInfoMap.getValueAndRangeUse(range); 768 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 769 varName); 770 } else { 771 varName = formatv("tblgen_value_{0}", valueIndex++); 772 os.indent(6) << formatv("Value *{0} = ", varName); 773 if (tree.isNestedDagArg(argIndex)) { 774 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 775 } else { 776 DagLeaf leaf = tree.getArgAsLeaf(argIndex); 777 auto symbol = 778 symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex)); 779 if (leaf.isNativeCodeCall()) { 780 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 781 } else { 782 os << symbol; 783 } 784 } 785 os << ";\n"; 786 } 787 788 // Update to use the newly created local variable for building the op later. 789 childNodeNames[argIndex] = varName; 790 } 791 792 // Then we create the builder call. 793 794 // Right now we don't have general type inference in MLIR. Except a few 795 // special cases listed below, we need to supply types for all results 796 // when building an op. 797 bool isSameOperandsAndResultType = 798 resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); 799 bool isBroadcastable = 800 resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 801 bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); 802 bool usePartialResults = valuePackName != resultValue; 803 804 if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr || 805 usePartialResults || depth > 0 || resultIndex < 0) { 806 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 807 resultOp.getQualCppClassName()); 808 } else { 809 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 810 // generated from the source pattern root op. Then we can use the source 811 // pattern's value types to determine the value type of the generated op 812 // here. 813 814 // We need to specify the types for all results. 815 int numResults = resultOp.getNumResults(); 816 if (numResults != 0) { 817 os.indent(6) << "tblgen_types.clear();\n"; 818 for (int i = 0; i < numResults; ++i) { 819 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) " 820 "tblgen_types.push_back(v->getType());\n", 821 resultIndex + i); 822 } 823 } 824 825 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 826 resultOp.getQualCppClassName()); 827 if (numResults != 0) 828 os.indent(6) << ", tblgen_types"; 829 } 830 831 // Add arguments for the builder call. 832 for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) { 833 // Start each argment on its own line. 834 (os << ",\n").indent(8); 835 836 Argument opArg = resultOp.getArg(argIndex); 837 // Handle the case of operand first. 838 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 839 if (!operand->name.empty()) { 840 os << "/*" << operand->name << "=*/"; 841 } 842 os << childNodeNames[argIndex]; 843 // TODO(jpienaar): verify types 844 continue; 845 } 846 847 // The argument in the op definition. 848 auto opArgName = resultOp.getArgName(argIndex); 849 if (auto subTree = tree.getArgAsNestedDag(argIndex)) { 850 if (!subTree.isNativeCodeCall()) 851 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 852 "for creating attribute"); 853 os << formatv("/*{0}=*/{1}", opArgName, 854 handleReplaceWithNativeCodeCall(subTree)); 855 } else { 856 auto leaf = tree.getArgAsLeaf(argIndex); 857 // The argument in the result DAG pattern. 858 auto patArgName = tree.getArgName(argIndex); 859 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 860 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 861 if (!opArg.is<NamedAttribute *>()) 862 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 863 if (!patArgName.empty()) 864 os << "/*" << patArgName << "=*/"; 865 } else { 866 os << "/*" << opArgName << "=*/"; 867 } 868 os << handleOpArgument(leaf, patArgName); 869 } 870 } 871 os << "\n );\n"; 872 os.indent(4) << "}\n"; 873 874 return resultValue; 875 } 876 877 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 878 emitSourceFileHeader("Rewriters", os); 879 880 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 881 auto numPatterns = patterns.size(); 882 883 // We put the map here because it can be shared among multiple patterns. 884 RecordOperatorMap recordOpMap; 885 886 std::vector<std::string> rewriterNames; 887 rewriterNames.reserve(numPatterns); 888 889 std::string baseRewriterName = "GeneratedConvert"; 890 int rewriterIndex = 0; 891 892 for (Record *p : patterns) { 893 std::string name; 894 if (p->isAnonymous()) { 895 // If no name is provided, ensure unique rewriter names simply by 896 // appending unique suffix. 897 name = baseRewriterName + llvm::utostr(rewriterIndex++); 898 } else { 899 name = p->getName(); 900 } 901 LLVM_DEBUG(llvm::dbgs() 902 << "=== start generating pattern '" << name << "' ===\n"); 903 PatternEmitter(p, &recordOpMap, os).emit(name); 904 LLVM_DEBUG(llvm::dbgs() 905 << "=== done generating pattern '" << name << "' ===\n"); 906 rewriterNames.push_back(std::move(name)); 907 } 908 909 // Emit function to add the generated matchers to the pattern list. 910 os << "void populateWithGenerated(MLIRContext *context, " 911 << "OwningRewritePatternList *patterns) {\n"; 912 for (const auto &name : rewriterNames) { 913 os << " patterns->insert<" << name << ">(context);\n"; 914 } 915 os << "}\n"; 916 } 917 918 static mlir::GenRegistration 919 genRewriters("gen-rewriters", "Generate pattern rewriters", 920 [](const RecordKeeper &records, raw_ostream &os) { 921 emitRewriters(records, os); 922 return false; 923 }); 924