1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 using llvm::formatv; 46 using llvm::Record; 47 using llvm::RecordKeeper; 48 49 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 50 51 namespace llvm { 52 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 53 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 54 raw_ostream &os, StringRef style) { 55 os << v.first << ":" << v.second; 56 } 57 }; 58 } // end namespace llvm 59 60 //===----------------------------------------------------------------------===// 61 // PatternEmitter 62 //===----------------------------------------------------------------------===// 63 64 namespace { 65 class PatternEmitter { 66 public: 67 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 68 69 // Emits the mlir::RewritePattern struct named `rewriteName`. 70 void emit(StringRef rewriteName); 71 72 private: 73 // Emits the code for matching ops. 74 void emitMatchLogic(DagNode tree); 75 76 // Emits the code for rewriting ops. 77 void emitRewriteLogic(); 78 79 //===--------------------------------------------------------------------===// 80 // Match utilities 81 //===--------------------------------------------------------------------===// 82 83 // Emits C++ statements for matching the op constrained by the given DAG 84 // `tree`. 85 void emitOpMatch(DagNode tree, int depth); 86 87 // Emits C++ statements for matching the `argIndex`-th argument of the given 88 // DAG `tree` as an operand. 89 void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent); 90 91 // Emits C++ statements for matching the `argIndex`-th argument of the given 92 // DAG `tree` as an attribute. 93 void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent); 94 95 //===--------------------------------------------------------------------===// 96 // Rewrite utilities 97 //===--------------------------------------------------------------------===// 98 99 // The entry point for handling a result pattern rooted at `resultTree`. This 100 // method dispatches to concrete handlers according to `resultTree`'s kind and 101 // returns a symbol representing the whole value pack. Callers are expected to 102 // further resolve the symbol according to the specific use case. 103 // 104 // `depth` is the nesting level of `resultTree`; 0 means top-level result 105 // pattern. For top-level result pattern, `resultIndex` indicates which result 106 // of the matched root op this pattern is intended to replace, which can be 107 // used to deduce the result type of the op generated from this result 108 // pattern. 109 std::string handleResultPattern(DagNode resultTree, int resultIndex, 110 int depth); 111 112 // Emits the C++ statement to replace the matched DAG with a value built via 113 // calling native C++ code. 114 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 115 116 // Returns the C++ expression referencing the old value serving as the 117 // replacement. 118 std::string handleReplaceWithValue(DagNode tree); 119 120 // Emits the C++ statement to build a new op out of the given DAG `tree` and 121 // returns the variable name that this op is assigned to. If the root op in 122 // DAG `tree` has a specified name, the created op will be assigned to a 123 // variable of the given name. Otherwise, a unique name will be used as the 124 // result value name. 125 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 126 127 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 128 129 // Emits a local variable for each value and attribute to be used for creating 130 // an op. 131 void createSeparateLocalVarsForOpArgs(DagNode node, 132 ChildNodeIndexNameMap &childNodeNames); 133 134 // Emits the concrete arguments used to call a op's builder. 135 void supplyValuesForOpArgs(DagNode node, 136 const ChildNodeIndexNameMap &childNodeNames); 137 138 // Emits the local variables for holding all values as a whole and all named 139 // attributes as a whole to be used for creating an op. 140 void createAggregateLocalVarsForOpArgs( 141 DagNode node, const ChildNodeIndexNameMap &childNodeNames); 142 143 // Returns the C++ expression to construct a constant attribute of the given 144 // `value` for the given attribute kind `attr`. 145 std::string handleConstantAttr(Attribute attr, StringRef value); 146 147 // Returns the C++ expression to build an argument from the given DAG `leaf`. 148 // `patArgName` is used to bound the argument to the source pattern. 149 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 150 151 //===--------------------------------------------------------------------===// 152 // General utilities 153 //===--------------------------------------------------------------------===// 154 155 // Collects all of the operations within the given dag tree. 156 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 157 158 // Returns a unique symbol for a local variable of the given `op`. 159 std::string getUniqueSymbol(const Operator *op); 160 161 //===--------------------------------------------------------------------===// 162 // Symbol utilities 163 //===--------------------------------------------------------------------===// 164 165 // Returns how many static values the given DAG `node` correspond to. 166 int getNodeValueCount(DagNode node); 167 168 private: 169 // Pattern instantiation location followed by the location of multiclass 170 // prototypes used. This is intended to be used as a whole to 171 // PrintFatalError() on errors. 172 ArrayRef<llvm::SMLoc> loc; 173 174 // Op's TableGen Record to wrapper object. 175 RecordOperatorMap *opMap; 176 177 // Handy wrapper for pattern being emitted. 178 Pattern pattern; 179 180 // Map for all bound symbols' info. 181 SymbolInfoMap symbolInfoMap; 182 183 // The next unused ID for newly created values. 184 unsigned nextValueId; 185 186 raw_ostream &os; 187 188 // Format contexts containing placeholder substitutions. 189 FmtContext fmtCtx; 190 191 // Number of op processed. 192 int opCounter = 0; 193 }; 194 } // end anonymous namespace 195 196 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 197 raw_ostream &os) 198 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 199 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 200 fmtCtx.withBuilder("rewriter"); 201 } 202 203 std::string PatternEmitter::handleConstantAttr(Attribute attr, 204 StringRef value) { 205 if (!attr.isConstBuildable()) 206 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 207 " does not have the 'constBuilderCall' field"); 208 209 // TODO(jpienaar): Verify the constants here 210 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 211 } 212 213 // Helper function to match patterns. 214 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 215 Operator &op = tree.getDialectOp(opMap); 216 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 217 << op.getOperationName() << "' at depth " << depth 218 << '\n'); 219 220 int indent = 4 + 2 * depth; 221 os.indent(indent) << formatv( 222 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 223 depth, op.getQualCppClassName()); 224 // Skip the operand matching at depth 0 as the pattern rewriter already does. 225 if (depth != 0) { 226 // Skip if there is no defining operation (e.g., arguments to function). 227 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 228 depth); 229 } 230 if (tree.getNumArgs() != op.getNumArgs()) { 231 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 232 "pattern vs. {2} in definition", 233 op.getOperationName(), tree.getNumArgs(), 234 op.getNumArgs())); 235 } 236 237 // If the operand's name is set, set to that variable. 238 auto name = tree.getSymbol(); 239 if (!name.empty()) 240 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 241 242 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 243 auto opArg = op.getArg(i); 244 245 // Handle nested DAG construct first 246 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 247 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 248 if (operand->isVariadic()) { 249 auto error = formatv("use nested DAG construct to match op {0}'s " 250 "variadic operand #{1} unsupported now", 251 op.getOperationName(), i); 252 PrintFatalError(loc, error); 253 } 254 } 255 os.indent(indent) << "{\n"; 256 257 os.indent(indent + 2) << formatv( 258 "auto *op{0} = " 259 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 260 depth + 1, depth, i); 261 emitOpMatch(argTree, depth + 1); 262 os.indent(indent + 2) 263 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 264 os.indent(indent) << "}\n"; 265 continue; 266 } 267 268 // Next handle DAG leaf: operand or attribute 269 if (opArg.is<NamedTypeConstraint *>()) { 270 emitOperandMatch(tree, i, depth, indent); 271 } else if (opArg.is<NamedAttribute *>()) { 272 emitAttributeMatch(tree, i, depth, indent); 273 } else { 274 PrintFatalError(loc, "unhandled case when matching op"); 275 } 276 } 277 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 278 << op.getOperationName() << "' at depth " << depth 279 << '\n'); 280 } 281 282 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, 283 int indent) { 284 Operator &op = tree.getDialectOp(opMap); 285 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 286 auto matcher = tree.getArgAsLeaf(argIndex); 287 288 // If a constraint is specified, we need to generate C++ statements to 289 // check the constraint. 290 if (!matcher.isUnspecified()) { 291 if (!matcher.isOperandMatcher()) { 292 PrintFatalError( 293 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 294 op.getOperationName(), argIndex + 1)); 295 } 296 297 // Only need to verify if the matcher's type is different from the one 298 // of op definition. 299 if (operand->constraint != matcher.getAsConstraint()) { 300 if (operand->isVariadic()) { 301 auto error = formatv( 302 "further constrain op {0}'s variadic operand #{1} unsupported now", 303 op.getOperationName(), argIndex); 304 PrintFatalError(loc, error); 305 } 306 auto self = 307 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 308 depth, argIndex); 309 os.indent(indent) << "if (!(" 310 << tgfmt(matcher.getConditionTemplate(), 311 &fmtCtx.withSelf(self)) 312 << ")) return matchFailure();\n"; 313 } 314 } 315 316 // Capture the value 317 auto name = tree.getArgName(argIndex); 318 // `$_` is a special symbol to ignore op argument matching. 319 if (!name.empty() && name != "_") { 320 // We need to subtract the number of attributes before this operand to get 321 // the index in the operand list. 322 auto numPrevAttrs = std::count_if( 323 op.arg_begin(), op.arg_begin() + argIndex, 324 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 325 326 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 327 name, depth, argIndex - numPrevAttrs); 328 } 329 } 330 331 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, 332 int indent) { 333 334 Operator &op = tree.getDialectOp(opMap); 335 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 336 const auto &attr = namedAttr->attr; 337 338 os.indent(indent) << "{\n"; 339 indent += 2; 340 os.indent(indent) << formatv( 341 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 342 attr.getStorageType(), namedAttr->name); 343 344 // TODO(antiagainst): This should use getter method to avoid duplication. 345 if (attr.hasDefaultValue()) { 346 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 347 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 348 attr.getDefaultValue()) 349 << ";\n"; 350 } else if (attr.isOptional()) { 351 // For a missing attribute that is optional according to definition, we 352 // should just capture a mlir::Attribute() to signal the missing state. 353 // That is precisely what getAttr() returns on missing attributes. 354 } else { 355 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 356 } 357 358 auto matcher = tree.getArgAsLeaf(argIndex); 359 if (!matcher.isUnspecified()) { 360 if (!matcher.isAttrMatcher()) { 361 PrintFatalError( 362 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 363 op.getOperationName(), argIndex + 1)); 364 } 365 366 // If a constraint is specified, we need to generate C++ statements to 367 // check the constraint. 368 os.indent(indent) << "if (!(" 369 << tgfmt(matcher.getConditionTemplate(), 370 &fmtCtx.withSelf("tblgen_attr")) 371 << ")) return matchFailure();\n"; 372 } 373 374 // Capture the value 375 auto name = tree.getArgName(argIndex); 376 // `$_` is a special symbol to ignore op argument matching. 377 if (!name.empty() && name != "_") { 378 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 379 } 380 381 indent -= 2; 382 os.indent(indent) << "}\n"; 383 } 384 385 void PatternEmitter::emitMatchLogic(DagNode tree) { 386 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 387 emitOpMatch(tree, 0); 388 389 for (auto &appliedConstraint : pattern.getConstraints()) { 390 auto &constraint = appliedConstraint.constraint; 391 auto &entities = appliedConstraint.entities; 392 393 auto condition = constraint.getConditionTemplate(); 394 auto cmd = "if (!({0})) return matchFailure();\n"; 395 396 if (isa<TypeConstraint>(constraint)) { 397 auto self = formatv("({0}->getType())", 398 symbolInfoMap.getValueAndRangeUse(entities.front())); 399 os.indent(4) << formatv(cmd, 400 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 401 } else if (isa<AttrConstraint>(constraint)) { 402 PrintFatalError( 403 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 404 } else { 405 // TODO(b/138794486): replace formatv arguments with the exact specified 406 // args. 407 if (entities.size() > 4) { 408 PrintFatalError(loc, "only support up to 4-entity constraints now"); 409 } 410 SmallVector<std::string, 4> names; 411 int i = 0; 412 for (int e = entities.size(); i < e; ++i) 413 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 414 std::string self = appliedConstraint.self; 415 if (!self.empty()) 416 self = symbolInfoMap.getValueAndRangeUse(self); 417 for (; i < 4; ++i) 418 names.push_back("<unused>"); 419 os.indent(4) << formatv(cmd, 420 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 421 names[1], names[2], names[3])); 422 } 423 } 424 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 425 } 426 427 void PatternEmitter::collectOps(DagNode tree, 428 llvm::SmallPtrSetImpl<const Operator *> &ops) { 429 // Check if this tree is an operation. 430 if (tree.isOperation()) { 431 const Operator &op = tree.getDialectOp(opMap); 432 LLVM_DEBUG(llvm::dbgs() 433 << "found operation " << op.getOperationName() << '\n'); 434 ops.insert(&op); 435 } 436 437 // Recurse the arguments of the tree. 438 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 439 if (auto child = tree.getArgAsNestedDag(i)) 440 collectOps(child, ops); 441 } 442 443 void PatternEmitter::emit(StringRef rewriteName) { 444 // Get the DAG tree for the source pattern. 445 DagNode sourceTree = pattern.getSourcePattern(); 446 447 const Operator &rootOp = pattern.getSourceRootOp(); 448 auto rootName = rootOp.getOperationName(); 449 450 // Collect the set of result operations. 451 llvm::SmallPtrSet<const Operator *, 4> resultOps; 452 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 453 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 454 collectOps(pattern.getResultPattern(i), resultOps); 455 } 456 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 457 458 // Emit RewritePattern for Pattern. 459 auto locs = pattern.getLocation(); 460 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 461 make_range(locs.rbegin(), locs.rend())); 462 os << formatv(R"(struct {0} : public RewritePattern { 463 {0}(MLIRContext *context) 464 : RewritePattern("{1}", {{)", 465 rewriteName, rootName); 466 // Sort result operators by name. 467 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 468 resultOps.end()); 469 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 470 return lhs->getOperationName() < rhs->getOperationName(); 471 }); 472 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 473 os << '"' << op->getOperationName() << '"'; 474 }); 475 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 476 477 // Emit matchAndRewrite() function. 478 os << R"( 479 PatternMatchResult matchAndRewrite(Operation *op0, 480 PatternRewriter &rewriter) const override { 481 )"; 482 483 // Register all symbols bound in the source pattern. 484 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 485 486 LLVM_DEBUG( 487 llvm::dbgs() << "start creating local variables for capturing matches\n"); 488 os.indent(4) << "// Variables for capturing values and attributes used for " 489 "creating ops\n"; 490 // Create local variables for storing the arguments and results bound 491 // to symbols. 492 for (const auto &symbolInfoPair : symbolInfoMap) { 493 StringRef symbol = symbolInfoPair.getKey(); 494 auto &info = symbolInfoPair.getValue(); 495 os.indent(4) << info.getVarDecl(symbol); 496 } 497 // TODO(jpienaar): capture ops with consistent numbering so that it can be 498 // reused for fused loc. 499 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 500 pattern.getSourcePattern().getNumOps()); 501 LLVM_DEBUG( 502 llvm::dbgs() << "done creating local variables for capturing matches\n"); 503 504 os.indent(4) << "// Match\n"; 505 os.indent(4) << "tblgen_ops[0] = op0;\n"; 506 emitMatchLogic(sourceTree); 507 os << "\n"; 508 509 os.indent(4) << "// Rewrite\n"; 510 emitRewriteLogic(); 511 512 os.indent(4) << "return matchSuccess();\n"; 513 os << " };\n"; 514 os << "};\n"; 515 } 516 517 void PatternEmitter::emitRewriteLogic() { 518 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 519 const Operator &rootOp = pattern.getSourceRootOp(); 520 int numExpectedResults = rootOp.getNumResults(); 521 int numResultPatterns = pattern.getNumResultPatterns(); 522 523 // First register all symbols bound to ops generated in result patterns. 524 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 525 526 // Only the last N static values generated are used to replace the matched 527 // root N-result op. We need to calculate the starting index (of the results 528 // of the matched op) each result pattern is to replace. 529 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 530 // If we don't need to replace any value at all, set the replacement starting 531 // index as the number of result patterns so we skip all of them when trying 532 // to replace the matched op's results. 533 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 534 for (int i = numResultPatterns - 1; i >= 0; --i) { 535 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 536 offsets[i] = offsets[i + 1] - numValues; 537 if (offsets[i] == 0) { 538 if (replStartIndex == -1) 539 replStartIndex = i; 540 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 541 auto error = formatv( 542 "cannot use the same multi-result op '{0}' to generate both " 543 "auxiliary values and values to be used for replacing the matched op", 544 pattern.getResultPattern(i).getSymbol()); 545 PrintFatalError(loc, error); 546 } 547 } 548 549 if (offsets.front() > 0) { 550 const char error[] = "no enough values generated to replace the matched op"; 551 PrintFatalError(loc, error); 552 } 553 554 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 555 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 556 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 557 } 558 os << "}); (void)loc;\n"; 559 560 // Process auxiliary result patterns. 561 for (int i = 0; i < replStartIndex; ++i) { 562 DagNode resultTree = pattern.getResultPattern(i); 563 auto val = handleResultPattern(resultTree, offsets[i], 0); 564 // Normal op creation will be streamed to `os` by the above call; but 565 // NativeCodeCall will only be materialized to `os` if it is used. Here 566 // we are handling auxiliary patterns so we want the side effect even if 567 // NativeCodeCall is not replacing matched root op's results. 568 if (resultTree.isNativeCodeCall()) 569 os.indent(4) << val << ";\n"; 570 } 571 572 if (numExpectedResults == 0) { 573 assert(replStartIndex >= numResultPatterns && 574 "invalid auxiliary vs. replacement pattern division!"); 575 // No result to replace. Just erase the op. 576 os.indent(4) << "rewriter.eraseOp(op0);\n"; 577 } else { 578 // Process replacement result patterns. 579 os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n"; 580 for (int i = replStartIndex; i < numResultPatterns; ++i) { 581 DagNode resultTree = pattern.getResultPattern(i); 582 auto val = handleResultPattern(resultTree, offsets[i], 0); 583 os.indent(4) << "\n"; 584 // Resolve each symbol for all range use so that we can loop over them. 585 os << symbolInfoMap.getAllRangeUse( 586 val, " for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }", 587 "\n"); 588 } 589 os.indent(4) << "\n"; 590 os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n"; 591 } 592 593 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 594 } 595 596 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 597 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 598 } 599 600 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 601 int resultIndex, int depth) { 602 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 603 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 604 LLVM_DEBUG(llvm::dbgs() << '\n'); 605 606 if (resultTree.isNativeCodeCall()) { 607 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 608 symbolInfoMap.bindValue(symbol); 609 return symbol; 610 } 611 612 if (resultTree.isReplaceWithValue()) { 613 return handleReplaceWithValue(resultTree); 614 } 615 616 // Normal op creation. 617 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 618 if (resultTree.getSymbol().empty()) { 619 // This is an op not explicitly bound to a symbol in the rewrite rule. 620 // Register the auto-generated symbol for it. 621 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 622 } 623 return symbol; 624 } 625 626 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 627 assert(tree.isReplaceWithValue()); 628 629 if (tree.getNumArgs() != 1) { 630 PrintFatalError( 631 loc, "replaceWithValue directive must take exactly one argument"); 632 } 633 634 if (!tree.getSymbol().empty()) { 635 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 636 } 637 638 return tree.getArgName(0); 639 } 640 641 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 642 StringRef patArgName) { 643 if (leaf.isConstantAttr()) { 644 auto constAttr = leaf.getAsConstantAttr(); 645 return handleConstantAttr(constAttr.getAttribute(), 646 constAttr.getConstantValue()); 647 } 648 if (leaf.isEnumAttrCase()) { 649 auto enumCase = leaf.getAsEnumAttrCase(); 650 if (enumCase.isStrCase()) 651 return handleConstantAttr(enumCase, enumCase.getSymbol()); 652 // This is an enum case backed by an IntegerAttr. We need to get its value 653 // to build the constant. 654 std::string val = std::to_string(enumCase.getValue()); 655 return handleConstantAttr(enumCase, val); 656 } 657 658 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 659 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 660 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 661 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 662 << "' (via symbol ref)\n"); 663 return argName; 664 } 665 if (leaf.isNativeCodeCall()) { 666 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 667 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 668 << "' (via NativeCodeCall)\n"); 669 return repl; 670 } 671 PrintFatalError(loc, "unhandled case when rewriting op"); 672 } 673 674 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 675 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 676 LLVM_DEBUG(tree.print(llvm::dbgs())); 677 LLVM_DEBUG(llvm::dbgs() << '\n'); 678 679 auto fmt = tree.getNativeCodeTemplate(); 680 // TODO(b/138794486): replace formatv arguments with the exact specified args. 681 SmallVector<std::string, 8> attrs(8); 682 if (tree.getNumArgs() > 8) { 683 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 684 Twine(tree.getNumArgs())); 685 } 686 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 687 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 688 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 689 << " replacement: " << attrs[i] << "\n"); 690 } 691 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 692 attrs[5], attrs[6], attrs[7]); 693 } 694 695 int PatternEmitter::getNodeValueCount(DagNode node) { 696 if (node.isOperation()) { 697 // If the op is bound to a symbol in the rewrite rule, query its result 698 // count from the symbol info map. 699 auto symbol = node.getSymbol(); 700 if (!symbol.empty()) { 701 return symbolInfoMap.getStaticValueCount(symbol); 702 } 703 // Otherwise this is an unbound op; we will use all its results. 704 return pattern.getDialectOp(node).getNumResults(); 705 } 706 // TODO(antiagainst): This considers all NativeCodeCall as returning one 707 // value. Enhance if multi-value ones are needed. 708 return 1; 709 } 710 711 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 712 int depth) { 713 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 714 LLVM_DEBUG(tree.print(llvm::dbgs())); 715 LLVM_DEBUG(llvm::dbgs() << '\n'); 716 717 Operator &resultOp = tree.getDialectOp(opMap); 718 auto numOpArgs = resultOp.getNumArgs(); 719 720 if (numOpArgs != tree.getNumArgs()) { 721 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 722 "{1} in pattern vs. {2} in definition", 723 resultOp.getOperationName(), tree.getNumArgs(), 724 numOpArgs)); 725 } 726 727 // A map to collect all nested DAG child nodes' names, with operand index as 728 // the key. This includes both bound and unbound child nodes. 729 ChildNodeIndexNameMap childNodeNames; 730 731 // First go through all the child nodes who are nested DAG constructs to 732 // create ops for them and remember the symbol names for them, so that we can 733 // use the results in the current node. This happens in a recursive manner. 734 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 735 if (auto child = tree.getArgAsNestedDag(i)) { 736 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 737 } 738 } 739 740 // The name of the local variable holding this op. 741 std::string valuePackName; 742 // The symbol for holding the result of this pattern. Note that the result of 743 // this pattern is not necessarily the same as the variable created by this 744 // pattern because we can use `__N` suffix to refer only a specific result if 745 // the generated op is a multi-result op. 746 std::string resultValue; 747 if (tree.getSymbol().empty()) { 748 // No symbol is explicitly bound to this op in the pattern. Generate a 749 // unique name. 750 valuePackName = resultValue = getUniqueSymbol(&resultOp); 751 } else { 752 resultValue = tree.getSymbol(); 753 // Strip the index to get the name for the value pack and use it to name the 754 // local variable for the op. 755 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 756 } 757 758 // Create the local variable for this op. 759 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 760 valuePackName); 761 os.indent(4) << "{\n"; 762 763 // Right now ODS don't have general type inference support. Except a few 764 // special cases listed below, DRR needs to supply types for all results 765 // when building an op. 766 bool isSameOperandsAndResultType = 767 resultOp.getTrait("OpTrait::SameOperandsAndResultType"); 768 bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType"); 769 770 if (isSameOperandsAndResultType || useFirstAttr) { 771 // We know how to deduce the result type for ops with these traits and we've 772 // generated builders taking aggregate parameters. Use those builders to 773 // create the ops. 774 775 // First prepare local variables for op arguments used in builder call. 776 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 777 // Then create the op. 778 os.indent(6) << formatv( 779 "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", 780 valuePackName, resultOp.getQualCppClassName()); 781 os.indent(4) << "}\n"; 782 return resultValue; 783 } 784 785 bool isBroadcastable = 786 resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 787 bool usePartialResults = valuePackName != resultValue; 788 789 if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) { 790 // For these cases (broadcastable ops, op results used both as auxiliary 791 // values and replacement values, ops in nested patterns, auxiliary ops), we 792 // still need to supply the result types when building the op. But because 793 // we don't generate a builder automatically with ODS for them, it's the 794 // developer's responsiblity to make sure such a builder (with result type 795 // deduction ability) exists. We go through the separate-parameter builder 796 // here given that it's easier for developers to write compared to 797 // aggregate-parameter builders. 798 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 799 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 800 resultOp.getQualCppClassName()); 801 supplyValuesForOpArgs(tree, childNodeNames); 802 os << "\n );\n"; 803 os.indent(4) << "}\n"; 804 return resultValue; 805 } 806 807 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 808 // generated from the source pattern root op. Then we can use the source 809 // pattern's value types to determine the value type of the generated op 810 // here. 811 812 // First prepare local variables for op arguments used in builder call. 813 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 814 815 // Then prepare the result types. We need to specify the types for all 816 // results. 817 os.indent(6) << formatv( 818 "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"); 819 int numResults = resultOp.getNumResults(); 820 if (numResults != 0) { 821 for (int i = 0; i < numResults; ++i) 822 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{" 823 "tblgen_types.push_back(v->getType()); }\n", 824 resultIndex + i); 825 } 826 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " 827 "tblgen_values, tblgen_attrs);\n", 828 valuePackName, resultOp.getQualCppClassName()); 829 os.indent(4) << "}\n"; 830 return resultValue; 831 } 832 833 void PatternEmitter::createSeparateLocalVarsForOpArgs( 834 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 835 Operator &resultOp = node.getDialectOp(opMap); 836 837 // Now prepare operands used for building this op: 838 // * If the operand is non-variadic, we create a `Value*` local variable. 839 // * If the operand is variadic, we create a `SmallVector<Value*>` local 840 // variable. 841 842 int valueIndex = 0; // An index for uniquing local variable names. 843 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 844 const auto *operand = 845 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 846 if (!operand) { 847 // We do not need special handling for attributes. 848 continue; 849 } 850 851 std::string varName; 852 if (operand->isVariadic()) { 853 varName = formatv("tblgen_values_{0}", valueIndex++); 854 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 855 std::string range; 856 if (node.isNestedDagArg(argIndex)) { 857 range = childNodeNames[argIndex]; 858 } else { 859 range = node.getArgName(argIndex); 860 } 861 // Resolve the symbol for all range use so that we have a uniform way of 862 // capturing the values. 863 range = symbolInfoMap.getValueAndRangeUse(range); 864 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 865 varName); 866 } else { 867 varName = formatv("tblgen_value_{0}", valueIndex++); 868 os.indent(6) << formatv("Value *{0} = ", varName); 869 if (node.isNestedDagArg(argIndex)) { 870 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 871 } else { 872 DagLeaf leaf = node.getArgAsLeaf(argIndex); 873 auto symbol = 874 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 875 if (leaf.isNativeCodeCall()) { 876 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 877 } else { 878 os << symbol; 879 } 880 } 881 os << ";\n"; 882 } 883 884 // Update to use the newly created local variable for building the op later. 885 childNodeNames[argIndex] = varName; 886 } 887 } 888 889 void PatternEmitter::supplyValuesForOpArgs( 890 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 891 Operator &resultOp = node.getDialectOp(opMap); 892 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 893 argIndex != numOpArgs; ++argIndex) { 894 // Start each argument on its own line. 895 (os << ",\n").indent(8); 896 897 Argument opArg = resultOp.getArg(argIndex); 898 // Handle the case of operand first. 899 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 900 if (!operand->name.empty()) 901 os << "/*" << operand->name << "=*/"; 902 os << childNodeNames.lookup(argIndex); 903 continue; 904 } 905 906 // The argument in the op definition. 907 auto opArgName = resultOp.getArgName(argIndex); 908 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 909 if (!subTree.isNativeCodeCall()) 910 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 911 "for creating attribute"); 912 os << formatv("/*{0}=*/{1}", opArgName, 913 handleReplaceWithNativeCodeCall(subTree)); 914 } else { 915 auto leaf = node.getArgAsLeaf(argIndex); 916 // The argument in the result DAG pattern. 917 auto patArgName = node.getArgName(argIndex); 918 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 919 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 920 if (!opArg.is<NamedAttribute *>()) 921 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 922 if (!patArgName.empty()) 923 os << "/*" << patArgName << "=*/"; 924 } else { 925 os << "/*" << opArgName << "=*/"; 926 } 927 os << handleOpArgument(leaf, patArgName); 928 } 929 } 930 } 931 932 void PatternEmitter::createAggregateLocalVarsForOpArgs( 933 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 934 Operator &resultOp = node.getDialectOp(opMap); 935 936 os.indent(6) << formatv( 937 "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n"); 938 os.indent(6) << formatv( 939 "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n"); 940 941 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 942 if (const auto *attr = 943 resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) { 944 const char *addAttrCmd = "if ({1}) {{" 945 " tblgen_attrs.emplace_back(rewriter." 946 "getIdentifier(\"{0}\"), {1}); }\n"; 947 // The argument in the op definition. 948 auto opArgName = resultOp.getArgName(argIndex); 949 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 950 if (!subTree.isNativeCodeCall()) 951 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 952 "for creating attribute"); 953 os.indent(6) << formatv(addAttrCmd, opArgName, 954 handleReplaceWithNativeCodeCall(subTree)); 955 } else { 956 auto leaf = node.getArgAsLeaf(argIndex); 957 // The argument in the result DAG pattern. 958 auto patArgName = node.getArgName(argIndex); 959 os.indent(6) << formatv(addAttrCmd, opArgName, 960 handleOpArgument(leaf, patArgName)); 961 } 962 continue; 963 } 964 965 const auto *operand = 966 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 967 std::string varName; 968 if (operand->isVariadic()) { 969 std::string range; 970 if (node.isNestedDagArg(argIndex)) { 971 range = childNodeNames.lookup(argIndex); 972 } else { 973 range = node.getArgName(argIndex); 974 } 975 // Resolve the symbol for all range use so that we have a uniform way of 976 // capturing the values. 977 range = symbolInfoMap.getValueAndRangeUse(range); 978 os.indent(6) << formatv( 979 "for (auto *v : {0}) tblgen_values.push_back(v);\n", range); 980 } else { 981 os.indent(6) << formatv("tblgen_values.push_back(", varName); 982 if (node.isNestedDagArg(argIndex)) { 983 os << symbolInfoMap.getValueAndRangeUse( 984 childNodeNames.lookup(argIndex)); 985 } else { 986 DagLeaf leaf = node.getArgAsLeaf(argIndex); 987 auto symbol = 988 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 989 if (leaf.isNativeCodeCall()) { 990 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 991 } else { 992 os << symbol; 993 } 994 } 995 os << ");\n"; 996 } 997 } 998 } 999 1000 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1001 emitSourceFileHeader("Rewriters", os); 1002 1003 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1004 auto numPatterns = patterns.size(); 1005 1006 // We put the map here because it can be shared among multiple patterns. 1007 RecordOperatorMap recordOpMap; 1008 1009 std::vector<std::string> rewriterNames; 1010 rewriterNames.reserve(numPatterns); 1011 1012 std::string baseRewriterName = "GeneratedConvert"; 1013 int rewriterIndex = 0; 1014 1015 for (Record *p : patterns) { 1016 std::string name; 1017 if (p->isAnonymous()) { 1018 // If no name is provided, ensure unique rewriter names simply by 1019 // appending unique suffix. 1020 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1021 } else { 1022 name = p->getName(); 1023 } 1024 LLVM_DEBUG(llvm::dbgs() 1025 << "=== start generating pattern '" << name << "' ===\n"); 1026 PatternEmitter(p, &recordOpMap, os).emit(name); 1027 LLVM_DEBUG(llvm::dbgs() 1028 << "=== done generating pattern '" << name << "' ===\n"); 1029 rewriterNames.push_back(std::move(name)); 1030 } 1031 1032 // Emit function to add the generated matchers to the pattern list. 1033 os << "void populateWithGenerated(MLIRContext *context, " 1034 << "OwningRewritePatternList *patterns) {\n"; 1035 for (const auto &name : rewriterNames) { 1036 os << " patterns->insert<" << name << ">(context);\n"; 1037 } 1038 os << "}\n"; 1039 } 1040 1041 static mlir::GenRegistration 1042 genRewriters("gen-rewriters", "Generate pattern rewriters", 1043 [](const RecordKeeper &records, raw_ostream &os) { 1044 emitRewriters(records, os); 1045 return false; 1046 }); 1047