1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/Attribute.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "mlir/TableGen/Pattern.h" 28 #include "mlir/TableGen/Predicate.h" 29 #include "mlir/TableGen/Type.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatAdapters.h" 35 #include "llvm/Support/PrettyStackTrace.h" 36 #include "llvm/Support/Signals.h" 37 #include "llvm/TableGen/Error.h" 38 #include "llvm/TableGen/Main.h" 39 #include "llvm/TableGen/Record.h" 40 #include "llvm/TableGen/TableGenBackend.h" 41 42 using namespace mlir; 43 using namespace mlir::tblgen; 44 45 using llvm::formatv; 46 using llvm::Record; 47 using llvm::RecordKeeper; 48 49 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 50 51 namespace llvm { 52 template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 53 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 54 raw_ostream &os, StringRef style) { 55 os << v.first << ":" << v.second; 56 } 57 }; 58 } // end namespace llvm 59 60 //===----------------------------------------------------------------------===// 61 // PatternEmitter 62 //===----------------------------------------------------------------------===// 63 64 namespace { 65 class PatternEmitter { 66 public: 67 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 68 69 // Emits the mlir::RewritePattern struct named `rewriteName`. 70 void emit(StringRef rewriteName); 71 72 private: 73 // Emits the code for matching ops. 74 void emitMatchLogic(DagNode tree); 75 76 // Emits the code for rewriting ops. 77 void emitRewriteLogic(); 78 79 //===--------------------------------------------------------------------===// 80 // Match utilities 81 //===--------------------------------------------------------------------===// 82 83 // Emits C++ statements for matching the op constrained by the given DAG 84 // `tree`. 85 void emitOpMatch(DagNode tree, int depth); 86 87 // Emits C++ statements for matching the `argIndex`-th argument of the given 88 // DAG `tree` as an operand. 89 void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent); 90 91 // Emits C++ statements for matching the `argIndex`-th argument of the given 92 // DAG `tree` as an attribute. 93 void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent); 94 95 //===--------------------------------------------------------------------===// 96 // Rewrite utilities 97 //===--------------------------------------------------------------------===// 98 99 // The entry point for handling a result pattern rooted at `resultTree`. This 100 // method dispatches to concrete handlers according to `resultTree`'s kind and 101 // returns a symbol representing the whole value pack. Callers are expected to 102 // further resolve the symbol according to the specific use case. 103 // 104 // `depth` is the nesting level of `resultTree`; 0 means top-level result 105 // pattern. For top-level result pattern, `resultIndex` indicates which result 106 // of the matched root op this pattern is intended to replace, which can be 107 // used to deduce the result type of the op generated from this result 108 // pattern. 109 std::string handleResultPattern(DagNode resultTree, int resultIndex, 110 int depth); 111 112 // Emits the C++ statement to replace the matched DAG with a value built via 113 // calling native C++ code. 114 std::string handleReplaceWithNativeCodeCall(DagNode resultTree); 115 116 // Returns the C++ expression referencing the old value serving as the 117 // replacement. 118 std::string handleReplaceWithValue(DagNode tree); 119 120 // Emits the C++ statement to build a new op out of the given DAG `tree` and 121 // returns the variable name that this op is assigned to. If the root op in 122 // DAG `tree` has a specified name, the created op will be assigned to a 123 // variable of the given name. Otherwise, a unique name will be used as the 124 // result value name. 125 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 126 127 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 128 129 // Emits a local variable for each value and attribute to be used for creating 130 // an op. 131 void createSeparateLocalVarsForOpArgs(DagNode node, 132 ChildNodeIndexNameMap &childNodeNames); 133 134 // Emits the concrete arguments used to call a op's builder. 135 void supplyValuesForOpArgs(DagNode node, 136 const ChildNodeIndexNameMap &childNodeNames); 137 138 // Emits the local variables for holding all values as a whole and all named 139 // attributes as a whole to be used for creating an op. 140 void createAggregateLocalVarsForOpArgs( 141 DagNode node, const ChildNodeIndexNameMap &childNodeNames); 142 143 // Returns the C++ expression to construct a constant attribute of the given 144 // `value` for the given attribute kind `attr`. 145 std::string handleConstantAttr(Attribute attr, StringRef value); 146 147 // Returns the C++ expression to build an argument from the given DAG `leaf`. 148 // `patArgName` is used to bound the argument to the source pattern. 149 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 150 151 //===--------------------------------------------------------------------===// 152 // General utilities 153 //===--------------------------------------------------------------------===// 154 155 // Collects all of the operations within the given dag tree. 156 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 157 158 // Returns a unique symbol for a local variable of the given `op`. 159 std::string getUniqueSymbol(const Operator *op); 160 161 //===--------------------------------------------------------------------===// 162 // Symbol utilities 163 //===--------------------------------------------------------------------===// 164 165 // Returns how many static values the given DAG `node` correspond to. 166 int getNodeValueCount(DagNode node); 167 168 private: 169 // Pattern instantiation location followed by the location of multiclass 170 // prototypes used. This is intended to be used as a whole to 171 // PrintFatalError() on errors. 172 ArrayRef<llvm::SMLoc> loc; 173 174 // Op's TableGen Record to wrapper object. 175 RecordOperatorMap *opMap; 176 177 // Handy wrapper for pattern being emitted. 178 Pattern pattern; 179 180 // Map for all bound symbols' info. 181 SymbolInfoMap symbolInfoMap; 182 183 // The next unused ID for newly created values. 184 unsigned nextValueId; 185 186 raw_ostream &os; 187 188 // Format contexts containing placeholder substitutions. 189 FmtContext fmtCtx; 190 191 // Number of op processed. 192 int opCounter = 0; 193 }; 194 } // end anonymous namespace 195 196 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 197 raw_ostream &os) 198 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 199 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 200 fmtCtx.withBuilder("rewriter"); 201 } 202 203 std::string PatternEmitter::handleConstantAttr(Attribute attr, 204 StringRef value) { 205 if (!attr.isConstBuildable()) 206 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 207 " does not have the 'constBuilderCall' field"); 208 209 // TODO(jpienaar): Verify the constants here 210 return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value); 211 } 212 213 // Helper function to match patterns. 214 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 215 Operator &op = tree.getDialectOp(opMap); 216 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 217 << op.getOperationName() << "' at depth " << depth 218 << '\n'); 219 220 int indent = 4 + 2 * depth; 221 os.indent(indent) << formatv( 222 "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", 223 depth, op.getQualCppClassName()); 224 // Skip the operand matching at depth 0 as the pattern rewriter already does. 225 if (depth != 0) { 226 // Skip if there is no defining operation (e.g., arguments to function). 227 os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", 228 depth); 229 } 230 if (tree.getNumArgs() != op.getNumArgs()) { 231 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 232 "pattern vs. {2} in definition", 233 op.getOperationName(), tree.getNumArgs(), 234 op.getNumArgs())); 235 } 236 237 // If the operand's name is set, set to that variable. 238 auto name = tree.getSymbol(); 239 if (!name.empty()) 240 os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); 241 242 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 243 auto opArg = op.getArg(i); 244 245 // Handle nested DAG construct first 246 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 247 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 248 if (operand->isVariadic()) { 249 auto error = formatv("use nested DAG construct to match op {0}'s " 250 "variadic operand #{1} unsupported now", 251 op.getOperationName(), i); 252 PrintFatalError(loc, error); 253 } 254 } 255 os.indent(indent) << "{\n"; 256 257 os.indent(indent + 2) << formatv( 258 "auto *op{0} = " 259 "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n", 260 depth + 1, depth, i); 261 emitOpMatch(argTree, depth + 1); 262 os.indent(indent + 2) 263 << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); 264 os.indent(indent) << "}\n"; 265 continue; 266 } 267 268 // Next handle DAG leaf: operand or attribute 269 if (opArg.is<NamedTypeConstraint *>()) { 270 emitOperandMatch(tree, i, depth, indent); 271 } else if (opArg.is<NamedAttribute *>()) { 272 emitAttributeMatch(tree, i, depth, indent); 273 } else { 274 PrintFatalError(loc, "unhandled case when matching op"); 275 } 276 } 277 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 278 << op.getOperationName() << "' at depth " << depth 279 << '\n'); 280 } 281 282 void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, 283 int indent) { 284 Operator &op = tree.getDialectOp(opMap); 285 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 286 auto matcher = tree.getArgAsLeaf(argIndex); 287 288 // If a constraint is specified, we need to generate C++ statements to 289 // check the constraint. 290 if (!matcher.isUnspecified()) { 291 if (!matcher.isOperandMatcher()) { 292 PrintFatalError( 293 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 294 op.getOperationName(), argIndex + 1)); 295 } 296 297 // Only need to verify if the matcher's type is different from the one 298 // of op definition. 299 if (operand->constraint != matcher.getAsConstraint()) { 300 if (operand->isVariadic()) { 301 auto error = formatv( 302 "further constrain op {0}'s variadic operand #{1} unsupported now", 303 op.getOperationName(), argIndex); 304 PrintFatalError(loc, error); 305 } 306 auto self = 307 formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", 308 depth, argIndex); 309 os.indent(indent) << "if (!(" 310 << tgfmt(matcher.getConditionTemplate(), 311 &fmtCtx.withSelf(self)) 312 << ")) return matchFailure();\n"; 313 } 314 } 315 316 // Capture the value 317 auto name = tree.getArgName(argIndex); 318 if (!name.empty()) { 319 // We need to subtract the number of attributes before this operand to get 320 // the index in the operand list. 321 auto numPrevAttrs = std::count_if( 322 op.arg_begin(), op.arg_begin() + argIndex, 323 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 324 325 os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", 326 name, depth, argIndex - numPrevAttrs); 327 } 328 } 329 330 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, 331 int indent) { 332 Operator &op = tree.getDialectOp(opMap); 333 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 334 const auto &attr = namedAttr->attr; 335 336 os.indent(indent) << "{\n"; 337 indent += 2; 338 os.indent(indent) << formatv( 339 "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, 340 attr.getStorageType(), namedAttr->name); 341 342 // TODO(antiagainst): This should use getter method to avoid duplication. 343 if (attr.hasDefaultValueInitializer()) { 344 os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " 345 << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 346 attr.getDefaultValueInitializer()) 347 << ";\n"; 348 } else if (attr.isOptional()) { 349 // For a missing attribute that is optional according to definition, we 350 // should just capture a mlir::Attribute() to signal the missing state. 351 // That is precisely what getAttr() returns on missing attributes. 352 } else { 353 os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; 354 } 355 356 auto matcher = tree.getArgAsLeaf(argIndex); 357 if (!matcher.isUnspecified()) { 358 if (!matcher.isAttrMatcher()) { 359 PrintFatalError( 360 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 361 op.getOperationName(), argIndex + 1)); 362 } 363 364 // If a constraint is specified, we need to generate C++ statements to 365 // check the constraint. 366 os.indent(indent) << "if (!(" 367 << tgfmt(matcher.getConditionTemplate(), 368 &fmtCtx.withSelf("tblgen_attr")) 369 << ")) return matchFailure();\n"; 370 } 371 372 // Capture the value 373 auto name = tree.getArgName(argIndex); 374 if (!name.empty()) { 375 os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); 376 } 377 378 indent -= 2; 379 os.indent(indent) << "}\n"; 380 } 381 382 void PatternEmitter::emitMatchLogic(DagNode tree) { 383 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 384 emitOpMatch(tree, 0); 385 386 for (auto &appliedConstraint : pattern.getConstraints()) { 387 auto &constraint = appliedConstraint.constraint; 388 auto &entities = appliedConstraint.entities; 389 390 auto condition = constraint.getConditionTemplate(); 391 auto cmd = "if (!({0})) return matchFailure();\n"; 392 393 if (isa<TypeConstraint>(constraint)) { 394 auto self = formatv("({0}->getType())", 395 symbolInfoMap.getValueAndRangeUse(entities.front())); 396 os.indent(4) << formatv(cmd, 397 tgfmt(condition, &fmtCtx.withSelf(self.str()))); 398 } else if (isa<AttrConstraint>(constraint)) { 399 PrintFatalError( 400 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 401 } else { 402 // TODO(b/138794486): replace formatv arguments with the exact specified 403 // args. 404 if (entities.size() > 4) { 405 PrintFatalError(loc, "only support up to 4-entity constraints now"); 406 } 407 SmallVector<std::string, 4> names; 408 int i = 0; 409 for (int e = entities.size(); i < e; ++i) 410 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 411 std::string self = appliedConstraint.self; 412 if (!self.empty()) 413 self = symbolInfoMap.getValueAndRangeUse(self); 414 for (; i < 4; ++i) 415 names.push_back("<unused>"); 416 os.indent(4) << formatv(cmd, 417 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 418 names[1], names[2], names[3])); 419 } 420 } 421 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 422 } 423 424 void PatternEmitter::collectOps(DagNode tree, 425 llvm::SmallPtrSetImpl<const Operator *> &ops) { 426 // Check if this tree is an operation. 427 if (tree.isOperation()) { 428 const Operator &op = tree.getDialectOp(opMap); 429 LLVM_DEBUG(llvm::dbgs() 430 << "found operation " << op.getOperationName() << '\n'); 431 ops.insert(&op); 432 } 433 434 // Recurse the arguments of the tree. 435 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 436 if (auto child = tree.getArgAsNestedDag(i)) 437 collectOps(child, ops); 438 } 439 440 void PatternEmitter::emit(StringRef rewriteName) { 441 // Get the DAG tree for the source pattern. 442 DagNode sourceTree = pattern.getSourcePattern(); 443 444 const Operator &rootOp = pattern.getSourceRootOp(); 445 auto rootName = rootOp.getOperationName(); 446 447 // Collect the set of result operations. 448 llvm::SmallPtrSet<const Operator *, 4> resultOps; 449 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 450 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 451 collectOps(pattern.getResultPattern(i), resultOps); 452 } 453 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 454 455 // Emit RewritePattern for Pattern. 456 auto locs = pattern.getLocation(); 457 os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", 458 make_range(locs.rbegin(), locs.rend())); 459 os << formatv(R"(struct {0} : public RewritePattern { 460 {0}(MLIRContext *context) 461 : RewritePattern("{1}", {{)", 462 rewriteName, rootName); 463 // Sort result operators by name. 464 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 465 resultOps.end()); 466 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 467 return lhs->getOperationName() < rhs->getOperationName(); 468 }); 469 interleaveComma(sortedResultOps, os, [&](const Operator *op) { 470 os << '"' << op->getOperationName() << '"'; 471 }); 472 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 473 474 // Emit matchAndRewrite() function. 475 os << R"( 476 PatternMatchResult matchAndRewrite(Operation *op0, 477 PatternRewriter &rewriter) const override { 478 )"; 479 480 // Register all symbols bound in the source pattern. 481 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 482 483 LLVM_DEBUG( 484 llvm::dbgs() << "start creating local variables for capturing matches\n"); 485 os.indent(4) << "// Variables for capturing values and attributes used for " 486 "creating ops\n"; 487 // Create local variables for storing the arguments and results bound 488 // to symbols. 489 for (const auto &symbolInfoPair : symbolInfoMap) { 490 StringRef symbol = symbolInfoPair.getKey(); 491 auto &info = symbolInfoPair.getValue(); 492 os.indent(4) << info.getVarDecl(symbol); 493 } 494 // TODO(jpienaar): capture ops with consistent numbering so that it can be 495 // reused for fused loc. 496 os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n", 497 pattern.getSourcePattern().getNumOps()); 498 LLVM_DEBUG( 499 llvm::dbgs() << "done creating local variables for capturing matches\n"); 500 501 os.indent(4) << "// Match\n"; 502 os.indent(4) << "tblgen_ops[0] = op0;\n"; 503 emitMatchLogic(sourceTree); 504 os << "\n"; 505 506 os.indent(4) << "// Rewrite\n"; 507 emitRewriteLogic(); 508 509 os.indent(4) << "return matchSuccess();\n"; 510 os << " };\n"; 511 os << "};\n"; 512 } 513 514 void PatternEmitter::emitRewriteLogic() { 515 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 516 const Operator &rootOp = pattern.getSourceRootOp(); 517 int numExpectedResults = rootOp.getNumResults(); 518 int numResultPatterns = pattern.getNumResultPatterns(); 519 520 // First register all symbols bound to ops generated in result patterns. 521 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 522 523 // Only the last N static values generated are used to replace the matched 524 // root N-result op. We need to calculate the starting index (of the results 525 // of the matched op) each result pattern is to replace. 526 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 527 // If we don't need to replace any value at all, set the replacement starting 528 // index as the number of result patterns so we skip all of them when trying 529 // to replace the matched op's results. 530 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 531 for (int i = numResultPatterns - 1; i >= 0; --i) { 532 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 533 offsets[i] = offsets[i + 1] - numValues; 534 if (offsets[i] == 0) { 535 if (replStartIndex == -1) 536 replStartIndex = i; 537 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 538 auto error = formatv( 539 "cannot use the same multi-result op '{0}' to generate both " 540 "auxiliary values and values to be used for replacing the matched op", 541 pattern.getResultPattern(i).getSymbol()); 542 PrintFatalError(loc, error); 543 } 544 } 545 546 if (offsets.front() > 0) { 547 const char error[] = "no enough values generated to replace the matched op"; 548 PrintFatalError(loc, error); 549 } 550 551 os.indent(4) << "auto loc = rewriter.getFusedLoc({"; 552 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 553 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 554 } 555 os << "}); (void)loc;\n"; 556 557 // Process auxiliary result patterns. 558 for (int i = 0; i < replStartIndex; ++i) { 559 DagNode resultTree = pattern.getResultPattern(i); 560 auto val = handleResultPattern(resultTree, offsets[i], 0); 561 // Normal op creation will be streamed to `os` by the above call; but 562 // NativeCodeCall will only be materialized to `os` if it is used. Here 563 // we are handling auxiliary patterns so we want the side effect even if 564 // NativeCodeCall is not replacing matched root op's results. 565 if (resultTree.isNativeCodeCall()) 566 os.indent(4) << val << ";\n"; 567 } 568 569 if (numExpectedResults == 0) { 570 assert(replStartIndex >= numResultPatterns && 571 "invalid auxiliary vs. replacement pattern division!"); 572 // No result to replace. Just erase the op. 573 os.indent(4) << "rewriter.eraseOp(op0);\n"; 574 } else { 575 // Process replacement result patterns. 576 os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n"; 577 for (int i = replStartIndex; i < numResultPatterns; ++i) { 578 DagNode resultTree = pattern.getResultPattern(i); 579 auto val = handleResultPattern(resultTree, offsets[i], 0); 580 os.indent(4) << "\n"; 581 // Resolve each symbol for all range use so that we can loop over them. 582 os << symbolInfoMap.getAllRangeUse( 583 val, " for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }", 584 "\n"); 585 } 586 os.indent(4) << "\n"; 587 os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n"; 588 } 589 590 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 591 } 592 593 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 594 return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); 595 } 596 597 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 598 int resultIndex, int depth) { 599 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 600 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 601 LLVM_DEBUG(llvm::dbgs() << '\n'); 602 603 if (resultTree.isNativeCodeCall()) { 604 auto symbol = handleReplaceWithNativeCodeCall(resultTree); 605 symbolInfoMap.bindValue(symbol); 606 return symbol; 607 } 608 609 if (resultTree.isReplaceWithValue()) { 610 return handleReplaceWithValue(resultTree); 611 } 612 613 // Normal op creation. 614 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 615 if (resultTree.getSymbol().empty()) { 616 // This is an op not explicitly bound to a symbol in the rewrite rule. 617 // Register the auto-generated symbol for it. 618 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 619 } 620 return symbol; 621 } 622 623 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { 624 assert(tree.isReplaceWithValue()); 625 626 if (tree.getNumArgs() != 1) { 627 PrintFatalError( 628 loc, "replaceWithValue directive must take exactly one argument"); 629 } 630 631 if (!tree.getSymbol().empty()) { 632 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 633 } 634 635 return tree.getArgName(0); 636 } 637 638 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 639 StringRef patArgName) { 640 if (leaf.isConstantAttr()) { 641 auto constAttr = leaf.getAsConstantAttr(); 642 return handleConstantAttr(constAttr.getAttribute(), 643 constAttr.getConstantValue()); 644 } 645 if (leaf.isEnumAttrCase()) { 646 auto enumCase = leaf.getAsEnumAttrCase(); 647 if (enumCase.isStrCase()) 648 return handleConstantAttr(enumCase, enumCase.getSymbol()); 649 // This is an enum case backed by an IntegerAttr. We need to get its value 650 // to build the constant. 651 std::string val = std::to_string(enumCase.getValue()); 652 return handleConstantAttr(enumCase, val); 653 } 654 655 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 656 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 657 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 658 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 659 << "' (via symbol ref)\n"); 660 return argName; 661 } 662 if (leaf.isNativeCodeCall()) { 663 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 664 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 665 << "' (via NativeCodeCall)\n"); 666 return repl; 667 } 668 PrintFatalError(loc, "unhandled case when rewriting op"); 669 } 670 671 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { 672 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 673 LLVM_DEBUG(tree.print(llvm::dbgs())); 674 LLVM_DEBUG(llvm::dbgs() << '\n'); 675 676 auto fmt = tree.getNativeCodeTemplate(); 677 // TODO(b/138794486): replace formatv arguments with the exact specified args. 678 SmallVector<std::string, 8> attrs(8); 679 if (tree.getNumArgs() > 8) { 680 PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + 681 Twine(tree.getNumArgs())); 682 } 683 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 684 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 685 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i 686 << " replacement: " << attrs[i] << "\n"); 687 } 688 return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], 689 attrs[5], attrs[6], attrs[7]); 690 } 691 692 int PatternEmitter::getNodeValueCount(DagNode node) { 693 if (node.isOperation()) { 694 // If the op is bound to a symbol in the rewrite rule, query its result 695 // count from the symbol info map. 696 auto symbol = node.getSymbol(); 697 if (!symbol.empty()) { 698 return symbolInfoMap.getStaticValueCount(symbol); 699 } 700 // Otherwise this is an unbound op; we will use all its results. 701 return pattern.getDialectOp(node).getNumResults(); 702 } 703 // TODO(antiagainst): This considers all NativeCodeCall as returning one 704 // value. Enhance if multi-value ones are needed. 705 return 1; 706 } 707 708 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 709 int depth) { 710 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 711 LLVM_DEBUG(tree.print(llvm::dbgs())); 712 LLVM_DEBUG(llvm::dbgs() << '\n'); 713 714 Operator &resultOp = tree.getDialectOp(opMap); 715 auto numOpArgs = resultOp.getNumArgs(); 716 717 if (numOpArgs != tree.getNumArgs()) { 718 PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " 719 "{1} in pattern vs. {2} in definition", 720 resultOp.getOperationName(), tree.getNumArgs(), 721 numOpArgs)); 722 } 723 724 // A map to collect all nested DAG child nodes' names, with operand index as 725 // the key. This includes both bound and unbound child nodes. 726 ChildNodeIndexNameMap childNodeNames; 727 728 // First go through all the child nodes who are nested DAG constructs to 729 // create ops for them and remember the symbol names for them, so that we can 730 // use the results in the current node. This happens in a recursive manner. 731 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { 732 if (auto child = tree.getArgAsNestedDag(i)) { 733 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 734 } 735 } 736 737 // The name of the local variable holding this op. 738 std::string valuePackName; 739 // The symbol for holding the result of this pattern. Note that the result of 740 // this pattern is not necessarily the same as the variable created by this 741 // pattern because we can use `__N` suffix to refer only a specific result if 742 // the generated op is a multi-result op. 743 std::string resultValue; 744 if (tree.getSymbol().empty()) { 745 // No symbol is explicitly bound to this op in the pattern. Generate a 746 // unique name. 747 valuePackName = resultValue = getUniqueSymbol(&resultOp); 748 } else { 749 resultValue = tree.getSymbol(); 750 // Strip the index to get the name for the value pack and use it to name the 751 // local variable for the op. 752 valuePackName = SymbolInfoMap::getValuePackName(resultValue); 753 } 754 755 // Create the local variable for this op. 756 os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(), 757 valuePackName); 758 os.indent(4) << "{\n"; 759 760 // Right now ODS don't have general type inference support. Except a few 761 // special cases listed below, DRR needs to supply types for all results 762 // when building an op. 763 bool isSameOperandsAndResultType = 764 resultOp.getTrait("OpTrait::SameOperandsAndResultType"); 765 bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType"); 766 767 if (isSameOperandsAndResultType || useFirstAttr) { 768 // We know how to deduce the result type for ops with these traits and we've 769 // generated builders taking aggregrate parameters. Use those builders to 770 // create the ops. 771 772 // First prepare local variables for op arguments used in builder call. 773 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 774 // Then create the op. 775 os.indent(6) << formatv( 776 "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", 777 valuePackName, resultOp.getQualCppClassName()); 778 os.indent(4) << "}\n"; 779 return resultValue; 780 } 781 782 bool isBroadcastable = 783 resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult"); 784 bool usePartialResults = valuePackName != resultValue; 785 786 if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) { 787 // For these cases (broadcastable ops, op results used both as auxiliary 788 // values and replacement values, ops in nested patterns, auxiliary ops), we 789 // still need to supply the result types when building the op. But because 790 // we don't generate a builder automatically with ODS for them, it's the 791 // developer's responsiblity to make sure such a builder (with result type 792 // deduction ability) exists. We go through the separate-parameter builder 793 // here given that it's easier for developers to write compared to 794 // aggregate-parameter builders. 795 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 796 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, 797 resultOp.getQualCppClassName()); 798 supplyValuesForOpArgs(tree, childNodeNames); 799 os << "\n );\n"; 800 os.indent(4) << "}\n"; 801 return resultValue; 802 } 803 804 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 805 // generated from the source pattern root op. Then we can use the source 806 // pattern's value types to determine the value type of the generated op 807 // here. 808 809 // First prepare local variables for op arguments used in builder call. 810 createAggregateLocalVarsForOpArgs(tree, childNodeNames); 811 812 // Then prepare the result types. We need to specify the types for all 813 // results. 814 os.indent(6) << formatv( 815 "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n"); 816 int numResults = resultOp.getNumResults(); 817 if (numResults != 0) { 818 for (int i = 0; i < numResults; ++i) 819 os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{" 820 "tblgen_types.push_back(v->getType()); }\n", 821 resultIndex + i); 822 } 823 os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " 824 "tblgen_values, tblgen_attrs);\n", 825 valuePackName, resultOp.getQualCppClassName()); 826 os.indent(4) << "}\n"; 827 return resultValue; 828 } 829 830 void PatternEmitter::createSeparateLocalVarsForOpArgs( 831 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 832 Operator &resultOp = node.getDialectOp(opMap); 833 834 // Now prepare operands used for building this op: 835 // * If the operand is non-variadic, we create a `Value*` local variable. 836 // * If the operand is variadic, we create a `SmallVector<Value*>` local 837 // variable. 838 839 int valueIndex = 0; // An index for uniquing local variable names. 840 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 841 const auto *operand = 842 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 843 if (!operand) { 844 // We do not need special handling for attributes. 845 continue; 846 } 847 848 std::string varName; 849 if (operand->isVariadic()) { 850 varName = formatv("tblgen_values_{0}", valueIndex++); 851 os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); 852 std::string range; 853 if (node.isNestedDagArg(argIndex)) { 854 range = childNodeNames[argIndex]; 855 } else { 856 range = node.getArgName(argIndex); 857 } 858 // Resolve the symbol for all range use so that we have a uniform way of 859 // capturing the values. 860 range = symbolInfoMap.getValueAndRangeUse(range); 861 os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, 862 varName); 863 } else { 864 varName = formatv("tblgen_value_{0}", valueIndex++); 865 os.indent(6) << formatv("Value *{0} = ", varName); 866 if (node.isNestedDagArg(argIndex)) { 867 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 868 } else { 869 DagLeaf leaf = node.getArgAsLeaf(argIndex); 870 auto symbol = 871 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 872 if (leaf.isNativeCodeCall()) { 873 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 874 } else { 875 os << symbol; 876 } 877 } 878 os << ";\n"; 879 } 880 881 // Update to use the newly created local variable for building the op later. 882 childNodeNames[argIndex] = varName; 883 } 884 } 885 886 void PatternEmitter::supplyValuesForOpArgs( 887 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 888 Operator &resultOp = node.getDialectOp(opMap); 889 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 890 argIndex != numOpArgs; ++argIndex) { 891 // Start each argment on its own line. 892 (os << ",\n").indent(8); 893 894 Argument opArg = resultOp.getArg(argIndex); 895 // Handle the case of operand first. 896 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 897 if (!operand->name.empty()) 898 os << "/*" << operand->name << "=*/"; 899 os << childNodeNames.lookup(argIndex); 900 continue; 901 } 902 903 // The argument in the op definition. 904 auto opArgName = resultOp.getArgName(argIndex); 905 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 906 if (!subTree.isNativeCodeCall()) 907 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 908 "for creating attribute"); 909 os << formatv("/*{0}=*/{1}", opArgName, 910 handleReplaceWithNativeCodeCall(subTree)); 911 } else { 912 auto leaf = node.getArgAsLeaf(argIndex); 913 // The argument in the result DAG pattern. 914 auto patArgName = node.getArgName(argIndex); 915 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 916 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 917 if (!opArg.is<NamedAttribute *>()) 918 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 919 if (!patArgName.empty()) 920 os << "/*" << patArgName << "=*/"; 921 } else { 922 os << "/*" << opArgName << "=*/"; 923 } 924 os << handleOpArgument(leaf, patArgName); 925 } 926 } 927 } 928 929 void PatternEmitter::createAggregateLocalVarsForOpArgs( 930 DagNode node, const ChildNodeIndexNameMap &childNodeNames) { 931 Operator &resultOp = node.getDialectOp(opMap); 932 933 os.indent(6) << formatv( 934 "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n"); 935 os.indent(6) << formatv( 936 "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n"); 937 938 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 939 if (const auto *attr = 940 resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) { 941 const char *addAttrCmd = "if ({1}) {{" 942 " tblgen_attrs.emplace_back(rewriter." 943 "getIdentifier(\"{0}\"), {1}); }\n"; 944 // The argument in the op definition. 945 auto opArgName = resultOp.getArgName(argIndex); 946 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 947 if (!subTree.isNativeCodeCall()) 948 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 949 "for creating attribute"); 950 os.indent(6) << formatv(addAttrCmd, opArgName, 951 handleReplaceWithNativeCodeCall(subTree)); 952 } else { 953 auto leaf = node.getArgAsLeaf(argIndex); 954 // The argument in the result DAG pattern. 955 auto patArgName = node.getArgName(argIndex); 956 os.indent(6) << formatv(addAttrCmd, opArgName, 957 handleOpArgument(leaf, patArgName)); 958 } 959 continue; 960 } 961 962 const auto *operand = 963 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 964 std::string varName; 965 if (operand->isVariadic()) { 966 std::string range; 967 if (node.isNestedDagArg(argIndex)) { 968 range = childNodeNames.lookup(argIndex); 969 } else { 970 range = node.getArgName(argIndex); 971 } 972 // Resolve the symbol for all range use so that we have a uniform way of 973 // capturing the values. 974 range = symbolInfoMap.getValueAndRangeUse(range); 975 os.indent(6) << formatv( 976 "for (auto *v : {0}) tblgen_values.push_back(v);\n", range); 977 } else { 978 os.indent(6) << formatv("tblgen_values.push_back(", varName); 979 if (node.isNestedDagArg(argIndex)) { 980 os << symbolInfoMap.getValueAndRangeUse( 981 childNodeNames.lookup(argIndex)); 982 } else { 983 DagLeaf leaf = node.getArgAsLeaf(argIndex); 984 auto symbol = 985 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 986 if (leaf.isNativeCodeCall()) { 987 os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); 988 } else { 989 os << symbol; 990 } 991 } 992 os << ");\n"; 993 } 994 } 995 } 996 997 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 998 emitSourceFileHeader("Rewriters", os); 999 1000 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1001 auto numPatterns = patterns.size(); 1002 1003 // We put the map here because it can be shared among multiple patterns. 1004 RecordOperatorMap recordOpMap; 1005 1006 std::vector<std::string> rewriterNames; 1007 rewriterNames.reserve(numPatterns); 1008 1009 std::string baseRewriterName = "GeneratedConvert"; 1010 int rewriterIndex = 0; 1011 1012 for (Record *p : patterns) { 1013 std::string name; 1014 if (p->isAnonymous()) { 1015 // If no name is provided, ensure unique rewriter names simply by 1016 // appending unique suffix. 1017 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1018 } else { 1019 name = p->getName(); 1020 } 1021 LLVM_DEBUG(llvm::dbgs() 1022 << "=== start generating pattern '" << name << "' ===\n"); 1023 PatternEmitter(p, &recordOpMap, os).emit(name); 1024 LLVM_DEBUG(llvm::dbgs() 1025 << "=== done generating pattern '" << name << "' ===\n"); 1026 rewriterNames.push_back(std::move(name)); 1027 } 1028 1029 // Emit function to add the generated matchers to the pattern list. 1030 os << "void populateWithGenerated(MLIRContext *context, " 1031 << "OwningRewritePatternList *patterns) {\n"; 1032 for (const auto &name : rewriterNames) { 1033 os << " patterns->insert<" << name << ">(context);\n"; 1034 } 1035 os << "}\n"; 1036 } 1037 1038 static mlir::GenRegistration 1039 genRewriters("gen-rewriters", "Generate pattern rewriters", 1040 [](const RecordKeeper &records, raw_ostream &os) { 1041 emitRewriters(records, os); 1042 return false; 1043 }); 1044