1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/CodeGenHelpers.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/GenInfo.h" 18 #include "mlir/TableGen/Operator.h" 19 #include "mlir/TableGen/Pattern.h" 20 #include "mlir/TableGen/Predicate.h" 21 #include "mlir/TableGen/Type.h" 22 #include "llvm/ADT/FunctionExtras.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringSet.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/FormatAdapters.h" 29 #include "llvm/Support/PrettyStackTrace.h" 30 #include "llvm/Support/Signals.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Main.h" 33 #include "llvm/TableGen/Record.h" 34 #include "llvm/TableGen/TableGenBackend.h" 35 36 using namespace mlir; 37 using namespace mlir::tblgen; 38 39 using llvm::formatv; 40 using llvm::Record; 41 using llvm::RecordKeeper; 42 43 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 44 45 namespace llvm { 46 template <> 47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 48 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 49 raw_ostream &os, StringRef style) { 50 os << v.first << ":" << v.second; 51 } 52 }; 53 } // namespace llvm 54 55 //===----------------------------------------------------------------------===// 56 // 
PatternEmitter 57 //===----------------------------------------------------------------------===// 58 59 namespace { 60 61 class StaticMatcherHelper; 62 63 class PatternEmitter { 64 public: 65 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os, 66 StaticMatcherHelper &helper); 67 68 // Emits the mlir::RewritePattern struct named `rewriteName`. 69 void emit(StringRef rewriteName); 70 71 // Emits the static function of DAG matcher. 72 void emitStaticMatcher(DagNode tree, std::string funcName); 73 74 private: 75 // Emits the code for matching ops. 76 void emitMatchLogic(DagNode tree, StringRef opName); 77 78 // Emits the code for rewriting ops. 79 void emitRewriteLogic(); 80 81 //===--------------------------------------------------------------------===// 82 // Match utilities 83 //===--------------------------------------------------------------------===// 84 85 // Emits C++ statements for matching the DAG structure. 86 void emitMatch(DagNode tree, StringRef name, int depth); 87 88 // Emit C++ function call to static DAG matcher. 89 void emitStaticMatchCall(DagNode tree, StringRef name); 90 91 // Emit C++ function call to static type/attribute constraint function. 92 void emitStaticVerifierCall(StringRef funcName, StringRef opName, 93 StringRef arg, StringRef failureStr); 94 95 // Emits C++ statements for matching using a native code call. 96 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 97 98 // Emits C++ statements for matching the op constrained by the given DAG 99 // `tree` returning the op's variable name. 100 void emitOpMatch(DagNode tree, StringRef opName, int depth); 101 102 // Emits C++ statements for matching the `argIndex`-th argument of the given 103 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the 104 // bound name and the constraint of the operand respectively. 
105 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, 106 DagLeaf operandMatcher, StringRef argName, 107 int argIndex); 108 109 // Emits C++ statements for matching the operands which can be matched in 110 // either order. 111 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 112 StringRef opName, int argIndex, int &operandIndex, 113 int depth); 114 115 // Emits C++ statements for matching the `argIndex`-th argument of the given 116 // DAG `tree` as an attribute. 117 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 118 int depth); 119 120 // Emits C++ for checking a match with a corresponding match failure 121 // diagnostic. 122 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 123 const llvm::formatv_object_base &failureFmt); 124 125 // Emits C++ for checking a match with a corresponding match failure 126 // diagnostics. 127 void emitMatchCheck(StringRef opName, const std::string &matchStr, 128 const std::string &failureStr); 129 130 //===--------------------------------------------------------------------===// 131 // Rewrite utilities 132 //===--------------------------------------------------------------------===// 133 134 // The entry point for handling a result pattern rooted at `resultTree`. This 135 // method dispatches to concrete handlers according to `resultTree`'s kind and 136 // returns a symbol representing the whole value pack. Callers are expected to 137 // further resolve the symbol according to the specific use case. 138 // 139 // `depth` is the nesting level of `resultTree`; 0 means top-level result 140 // pattern. For top-level result pattern, `resultIndex` indicates which result 141 // of the matched root op this pattern is intended to replace, which can be 142 // used to deduce the result type of the op generated from this result 143 // pattern. 
144 std::string handleResultPattern(DagNode resultTree, int resultIndex, 145 int depth); 146 147 // Emits the C++ statement to replace the matched DAG with a value built via 148 // calling native C++ code. 149 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 150 151 // Returns the symbol of the old value serving as the replacement. 152 StringRef handleReplaceWithValue(DagNode tree); 153 154 // Trailing directives are used at the end of DAG node argument lists to 155 // specify additional behaviour for op matchers and creators, etc. 156 struct TrailingDirectives { 157 // DAG node containing the `location` directive. Null if there is none. 158 DagNode location; 159 160 // DAG node containing the `returnType` directive. Null if there is none. 161 DagNode returnType; 162 163 // Number of found trailing directives. 164 int numDirectives; 165 }; 166 167 // Collect any trailing directives. 168 TrailingDirectives getTrailingDirectives(DagNode tree); 169 170 // Returns the location value to use. 171 std::string getLocation(TrailingDirectives &tail); 172 173 // Returns the location value to use. 174 std::string handleLocationDirective(DagNode tree); 175 176 // Emit return type argument. 177 std::string handleReturnTypeArg(DagNode returnType, int i, int depth); 178 179 // Emits the C++ statement to build a new op out of the given DAG `tree` and 180 // returns the variable name that this op is assigned to. If the root op in 181 // DAG `tree` has a specified name, the created op will be assigned to a 182 // variable of the given name. Otherwise, a unique name will be used as the 183 // result value name. 184 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 185 186 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 187 188 // Emits a local variable for each value and attribute to be used for creating 189 // an op. 
190 void createSeparateLocalVarsForOpArgs(DagNode node, 191 ChildNodeIndexNameMap &childNodeNames); 192 193 // Emits the concrete arguments used to call an op's builder. 194 void supplyValuesForOpArgs(DagNode node, 195 const ChildNodeIndexNameMap &childNodeNames, 196 int depth); 197 198 // Emits the local variables for holding all values as a whole and all named 199 // attributes as a whole to be used for creating an op. 200 void createAggregateLocalVarsForOpArgs( 201 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 202 203 // Returns the C++ expression to construct a constant attribute of the given 204 // `value` for the given attribute kind `attr`. 205 std::string handleConstantAttr(Attribute attr, const Twine &value); 206 207 // Returns the C++ expression to build an argument from the given DAG `leaf`. 208 // `patArgName` is used to bound the argument to the source pattern. 209 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 210 211 //===--------------------------------------------------------------------===// 212 // General utilities 213 //===--------------------------------------------------------------------===// 214 215 // Collects all of the operations within the given dag tree. 216 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 217 218 // Returns a unique symbol for a local variable of the given `op`. 219 std::string getUniqueSymbol(const Operator *op); 220 221 //===--------------------------------------------------------------------===// 222 // Symbol utilities 223 //===--------------------------------------------------------------------===// 224 225 // Returns how many static values the given DAG `node` correspond to. 226 int getNodeValueCount(DagNode node); 227 228 private: 229 // Pattern instantiation location followed by the location of multiclass 230 // prototypes used. This is intended to be used as a whole to 231 // PrintFatalError() on errors. 
232 ArrayRef<SMLoc> loc; 233 234 // Op's TableGen Record to wrapper object. 235 RecordOperatorMap *opMap; 236 237 // Handy wrapper for pattern being emitted. 238 Pattern pattern; 239 240 // Map for all bound symbols' info. 241 SymbolInfoMap symbolInfoMap; 242 243 StaticMatcherHelper &staticMatcherHelper; 244 245 // The next unused ID for newly created values. 246 unsigned nextValueId = 0; 247 248 raw_indented_ostream os; 249 250 // Format contexts containing placeholder substitutions. 251 FmtContext fmtCtx; 252 }; 253 254 // Tracks DagNode's reference multiple times across patterns. Enables generating 255 // static matcher functions for DagNode's referenced multiple times rather than 256 // inlining them. 257 class StaticMatcherHelper { 258 public: 259 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper, 260 RecordOperatorMap &mapper); 261 262 // Determine if we should inline the match logic or delegate to a static 263 // function. 264 bool useStaticMatcher(DagNode node) { 265 return refStats[node] > kStaticMatcherThreshold; 266 } 267 268 // Get the name of the static DAG matcher function corresponding to the node. 269 std::string getMatcherName(DagNode node) { 270 assert(useStaticMatcher(node)); 271 return matcherNames[node]; 272 } 273 274 // Get the name of static type/attribute verification function. 275 StringRef getVerifierName(DagLeaf leaf); 276 277 // Collect the `Record`s, i.e., the DRR, so that we can get the information of 278 // the duplicated DAGs. 279 void addPattern(Record *record); 280 281 // Emit all static functions of DAG Matcher. 282 void populateStaticMatchers(raw_ostream &os); 283 284 // Emit all static functions for Constraints. 
285 void populateStaticConstraintFunctions(raw_ostream &os); 286 287 private: 288 static constexpr unsigned kStaticMatcherThreshold = 1; 289 290 // Consider two patterns as down below, 291 // DagNode_Root_A DagNode_Root_B 292 // \ \ 293 // DagNode_C DagNode_C 294 // \ \ 295 // DagNode_D DagNode_D 296 // 297 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of 298 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced 299 // multiple times so we'll have static matchers for both of them. When we're 300 // emitting the match logic for DagNode_C, we will check if DagNode_D has the 301 // static matcher generated. If so, then we'll generate a call to the 302 // function, inline otherwise. In this case, inlining is not what we want. As 303 // a result, generate the static matcher in topological order to ensure all 304 // the dependent static matchers are generated and we can avoid accidentally 305 // inlining. 306 // 307 // The topological order of all the DagNodes among all patterns. 308 SmallVector<std::pair<DagNode, Record *>> topologicalOrder; 309 310 RecordOperatorMap &opMap; 311 312 // Records of the static function name of each DagNode 313 DenseMap<DagNode, std::string> matcherNames; 314 315 // After collecting all the DagNode in each pattern, `refStats` records the 316 // number of users for each DagNode. We will generate the static matcher for a 317 // DagNode while the number of users exceeds a certain threshold. 318 DenseMap<DagNode, unsigned> refStats; 319 320 // Number of static matcher generated. This is used to generate a unique name 321 // for each DagNode. 322 int staticMatcherCounter = 0; 323 324 // The DagLeaf which contains type or attr constraint. 325 SetVector<DagLeaf> constraints; 326 327 // Static type/attribute verification function emitter. 
328 StaticVerifierFunctionEmitter staticVerifierEmitter; 329 }; 330 331 } // namespace 332 333 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 334 raw_ostream &os, StaticMatcherHelper &helper) 335 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 336 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { 337 fmtCtx.withBuilder("rewriter"); 338 } 339 340 std::string PatternEmitter::handleConstantAttr(Attribute attr, 341 const Twine &value) { 342 if (!attr.isConstBuildable()) 343 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 344 " does not have the 'constBuilderCall' field"); 345 346 // TODO: Verify the constants here 347 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 348 } 349 350 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { 351 os << formatv( 352 "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " 353 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " 354 "*, 4> &tblgen_ops", 355 funcName); 356 357 // We pass the reference of the variables that need to be captured. Hence we 358 // need to collect all the symbols in the tree first. 359 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true); 360 symbolInfoMap.assignUniqueAlternativeNames(); 361 for (const auto &info : symbolInfoMap) 362 os << formatv(", {0}", info.second.getArgDecl(info.first)); 363 364 os << ") {\n"; 365 os.indent(); 366 os << "(void)tblgen_ops;\n"; 367 368 // Note that a static matcher is considered at least one step from the match 369 // entry. 370 emitMatch(tree, "op0", /*depth=*/1); 371 372 os << "return ::mlir::success();\n"; 373 os.unindent(); 374 os << "}\n\n"; 375 } 376 377 // Helper function to match patterns. 
378 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 379 if (tree.isNativeCodeCall()) { 380 emitNativeCodeMatch(tree, name, depth); 381 return; 382 } 383 384 if (tree.isOperation()) { 385 emitOpMatch(tree, name, depth); 386 return; 387 } 388 389 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 390 } 391 392 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { 393 std::string funcName = staticMatcherHelper.getMatcherName(tree); 394 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName, 395 opName); 396 397 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in 398 // one pass. 399 400 // In general, bound symbol should have the unique name in the pattern but 401 // for the operand, binding same symbol to multiple operands imply a 402 // constraint at the same time. In this case, we will rename those operands 403 // with different names. As a result, we need to collect all the symbolInfos 404 // from the DagNode then get the updated name of the local variables from the 405 // global symbolInfoMap. 
406 407 // Collect all the bound symbols in the Dag 408 SymbolInfoMap localSymbolMap(loc); 409 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true); 410 411 for (const auto &info : localSymbolMap) { 412 auto name = info.first; 413 auto symboInfo = info.second; 414 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo); 415 os << formatv(", {0}", ret->second.getVarName(name)); 416 } 417 418 os << "))) {\n"; 419 os.scope().os << "return ::mlir::failure();\n"; 420 os << "}\n"; 421 } 422 423 void PatternEmitter::emitStaticVerifierCall(StringRef funcName, 424 StringRef opName, StringRef arg, 425 StringRef failureStr) { 426 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n", 427 funcName, opName, arg, failureStr); 428 os.scope().os << "return ::mlir::failure();\n"; 429 os << "}\n"; 430 } 431 432 // Helper function to match patterns. 433 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 434 int depth) { 435 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 436 LLVM_DEBUG(tree.print(llvm::dbgs())); 437 LLVM_DEBUG(llvm::dbgs() << '\n'); 438 439 // The order of generating static matcher follows the topological order so 440 // that for every dependent DagNode already have their static matcher 441 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 442 // is when we are generating the static matcher for a DagNode itself. In this 443 // case, we need to emit the function body rather than a function call. 444 if (staticMatcherHelper.useStaticMatcher(tree) && 445 !staticMatcherHelper.getMatcherName(tree).empty()) { 446 emitStaticMatchCall(tree, opName); 447 448 // NativeCodeCall will never be at depth 0 so that we don't need to catch 449 // the root operation as emitOpMatch(); 450 451 return; 452 } 453 454 // TODO(suderman): iterate through arguments, determine their types, output 455 // names. 
456 SmallVector<std::string, 8> capture; 457 458 raw_indented_ostream::DelimitedScope scope(os); 459 460 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 461 std::string argName = formatv("arg{0}_{1}", depth, i); 462 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 463 if (argTree.isEither()) 464 PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); 465 466 os << "::mlir::Value " << argName << ";\n"; 467 } else { 468 auto leaf = tree.getArgAsLeaf(i); 469 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 470 os << "::mlir::Attribute " << argName << ";\n"; 471 } else { 472 os << "::mlir::Value " << argName << ";\n"; 473 } 474 } 475 476 capture.push_back(std::move(argName)); 477 } 478 479 auto tail = getTrailingDirectives(tree); 480 if (tail.returnType) 481 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 482 auto locToUse = getLocation(tail); 483 484 auto fmt = tree.getNativeCodeTemplate(); 485 if (fmt.count("$_self") != 1) 486 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 487 "passing the defining Operation"); 488 489 auto nativeCodeCall = std::string( 490 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), 491 static_cast<ArrayRef<std::string>>(capture))); 492 493 emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall), 494 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall)); 495 496 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 497 auto name = tree.getArgName(i); 498 if (!name.empty() && name != "_") { 499 os << formatv("{0} = {1};\n", name, capture[i]); 500 } 501 } 502 503 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 504 std::string argName = capture[i]; 505 506 // Handle nested DAG construct first 507 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 508 PrintFatalError( 509 loc, formatv("Matching nested tree in NativeCodecall not support for " 510 "{0} as arg {1}", 511 argName, i)); 
512 } 513 514 DagLeaf leaf = tree.getArgAsLeaf(i); 515 516 // The parameter for native function doesn't bind any constraints. 517 if (leaf.isUnspecified()) 518 continue; 519 520 auto constraint = leaf.getAsConstraint(); 521 522 std::string self; 523 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 524 self = argName; 525 else 526 self = formatv("{0}.getType()", argName); 527 StringRef verifier = staticMatcherHelper.getVerifierName(leaf); 528 emitStaticVerifierCall( 529 verifier, opName, self, 530 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 531 "constraint: " 532 "'{2}'\"", 533 i, tree.getNativeCodeTemplate(), 534 escapeString(constraint.getSummary())) 535 .str()); 536 } 537 538 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 539 } 540 541 // Helper function to match patterns. 542 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 543 Operator &op = tree.getDialectOp(opMap); 544 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 545 << op.getOperationName() << "' at depth " << depth 546 << '\n'); 547 548 auto getCastedName = [depth]() -> std::string { 549 return formatv("castedOp{0}", depth); 550 }; 551 552 // The order of generating static matcher follows the topological order so 553 // that for every dependent DagNode already have their static matcher 554 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 555 // is when we are generating the static matcher for a DagNode itself. In this 556 // case, we need to emit the function body rather than a function call. 557 if (staticMatcherHelper.useStaticMatcher(tree) && 558 !staticMatcherHelper.getMatcherName(tree).empty()) { 559 emitStaticMatchCall(tree, opName); 560 // In the codegen of rewriter, we suppose that castedOp0 will capture the 561 // root operation. Manually add it if the root DagNode is a static matcher. 
562 if (depth == 0) 563 os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " 564 "(void){2};\n", 565 opName, op.getQualCppClassName(), getCastedName()); 566 return; 567 } 568 569 std::string castedName = getCastedName(); 570 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " 571 "(void){0};\n", 572 castedName, opName, op.getQualCppClassName()); 573 574 // Skip the operand matching at depth 0 as the pattern rewriter already does. 575 if (depth != 0) 576 emitMatchCheck(opName, /*matchStr=*/castedName, 577 formatv("\"{0} is not {1} type\"", castedName, 578 op.getQualCppClassName())); 579 580 // If the operand's name is set, set to that variable. 581 auto name = tree.getSymbol(); 582 if (!name.empty()) 583 os << formatv("{0} = {1};\n", name, castedName); 584 585 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; 586 ++i, ++opArgIdx) { 587 auto opArg = op.getArg(opArgIdx); 588 std::string argName = formatv("op{0}", depth + 1); 589 590 // Handle nested DAG construct first 591 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 592 if (argTree.isEither()) { 593 emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand, 594 depth); 595 ++opArgIdx; 596 continue; 597 } 598 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 599 if (operand->isVariableLength()) { 600 auto error = formatv("use nested DAG construct to match op {0}'s " 601 "variadic operand #{1} unsupported now", 602 op.getOperationName(), opArgIdx); 603 PrintFatalError(loc, error); 604 } 605 } 606 607 os << "{\n"; 608 609 // Attributes don't count for getODSOperands. 610 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 
611 os.indent() << formatv( 612 "auto *{0} = " 613 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 614 argName, castedName, nextOperand); 615 // Null check of operand's definingOp 616 emitMatchCheck( 617 castedName, /*matchStr=*/argName, 618 formatv("\"There's no operation that defines operand {0} of {1}\"", 619 nextOperand++, castedName)); 620 emitMatch(argTree, argName, depth + 1); 621 os << formatv("tblgen_ops.push_back({0});\n", argName); 622 os.unindent() << "}\n"; 623 continue; 624 } 625 626 // Next handle DAG leaf: operand or attribute 627 if (opArg.is<NamedTypeConstraint *>()) { 628 auto operandName = 629 formatv("{0}.getODSOperands({1})", castedName, nextOperand); 630 emitOperandMatch(tree, castedName, operandName.str(), 631 /*operandMatcher=*/tree.getArgAsLeaf(i), 632 /*argName=*/tree.getArgName(i), opArgIdx); 633 ++nextOperand; 634 } else if (opArg.is<NamedAttribute *>()) { 635 emitAttributeMatch(tree, opName, opArgIdx, depth); 636 } else { 637 PrintFatalError(loc, "unhandled case when matching op"); 638 } 639 } 640 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 641 << op.getOperationName() << "' at depth " << depth 642 << '\n'); 643 } 644 645 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 646 StringRef operandName, 647 DagLeaf operandMatcher, StringRef argName, 648 int argIndex) { 649 Operator &op = tree.getDialectOp(opMap); 650 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 651 652 // If a constraint is specified, we need to generate C++ statements to 653 // check the constraint. 654 if (!operandMatcher.isUnspecified()) { 655 if (!operandMatcher.isOperandMatcher()) 656 PrintFatalError( 657 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 658 op.getOperationName(), argIndex + 1)); 659 660 // Only need to verify if the matcher's type is different from the one 661 // of op definition. 
662 Constraint constraint = operandMatcher.getAsConstraint(); 663 if (operand->constraint != constraint) { 664 if (operand->isVariableLength()) { 665 auto error = formatv( 666 "further constrain op {0}'s variadic operand #{1} unsupported now", 667 op.getOperationName(), argIndex); 668 PrintFatalError(loc, error); 669 } 670 auto self = formatv("(*{0}.begin()).getType()", operandName); 671 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher); 672 emitStaticVerifierCall( 673 verifier, opName, self.str(), 674 formatv( 675 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"", 676 operand - op.operand_begin(), op.getOperationName(), 677 escapeString(constraint.getSummary())) 678 .str()); 679 } 680 } 681 682 // Capture the value 683 // `$_` is a special symbol to ignore op argument matching. 684 if (!argName.empty() && argName != "_") { 685 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex); 686 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); 687 } 688 } 689 690 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 691 StringRef opName, int argIndex, 692 int &operandIndex, int depth) { 693 constexpr int numEitherArgs = 2; 694 if (eitherArgTree.getNumArgs() != numEitherArgs) 695 PrintFatalError(loc, "`either` only supports grouping two operands"); 696 697 Operator &op = tree.getDialectOp(opMap); 698 699 std::string codeBuffer; 700 llvm::raw_string_ostream tblgenOps(codeBuffer); 701 702 std::string lambda = formatv("eitherLambda{0}", depth); 703 os << formatv( 704 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n", 705 lambda); 706 707 os.indent(); 708 709 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { 710 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) { 711 if (argTree.isEither()) 712 PrintFatalError(loc, "either cannot be nested"); 713 714 std::string argName = formatv("local_op_{0}", i).str(); 715 716 os << formatv("auto {0} = 
(*v{1}.begin()).getDefiningOp();\n", argName, 717 i); 718 719 // Indent emitMatchCheck and emitMatch because they declare local 720 // variables. 721 os << "{\n"; 722 os.indent(); 723 724 emitMatchCheck( 725 opName, /*matchStr=*/argName, 726 formatv("\"There's no operation that defines operand {0} of {1}\"", 727 operandIndex++, opName)); 728 emitMatch(argTree, argName, depth + 1); 729 730 os.unindent() << "}\n"; 731 732 // `tblgen_ops` is used to collect the matched operations. In either, we 733 // need to queue the operation only if the matching success. Thus we emit 734 // the code at the end. 735 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); 736 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) { 737 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), 738 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), 739 /*argName=*/eitherArgTree.getArgName(i), argIndex); 740 ++operandIndex; 741 } else { 742 PrintFatalError(loc, "either can only be applied on operand"); 743 } 744 } 745 746 os << tblgenOps.str(); 747 os << "return ::mlir::success();\n"; 748 os.unindent() << "};\n"; 749 750 os << "{\n"; 751 os.indent(); 752 753 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName, 754 operandIndex - 2); 755 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName, 756 operandIndex - 1); 757 758 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " 759 "::mlir::failed({0}(eitherOperand1, " 760 "eitherOperand0)))\n", 761 lambda); 762 os.indent() << "return ::mlir::failure();\n"; 763 764 os.unindent().unindent() << "}\n"; 765 } 766 767 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 768 int argIndex, int depth) { 769 Operator &op = tree.getDialectOp(opMap); 770 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 771 const auto &attr = namedAttr->attr; 772 773 os << "{\n"; 774 os.indent() << formatv("auto tblgen_attr = 
{0}->getAttrOfType<{1}>(\"{2}\");" 775 "(void)tblgen_attr;\n", 776 opName, attr.getStorageType(), namedAttr->name); 777 778 // TODO: This should use getter method to avoid duplication. 779 if (attr.hasDefaultValue()) { 780 os << "if (!tblgen_attr) tblgen_attr = " 781 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 782 attr.getDefaultValue())) 783 << ";\n"; 784 } else if (attr.isOptional()) { 785 // For a missing attribute that is optional according to definition, we 786 // should just capture a mlir::Attribute() to signal the missing state. 787 // That is precisely what getAttr() returns on missing attributes. 788 } else { 789 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 790 formatv("\"expected op '{0}' to have attribute '{1}' " 791 "of type '{2}'\"", 792 op.getOperationName(), namedAttr->name, 793 attr.getStorageType())); 794 } 795 796 auto matcher = tree.getArgAsLeaf(argIndex); 797 if (!matcher.isUnspecified()) { 798 if (!matcher.isAttrMatcher()) { 799 PrintFatalError( 800 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 801 op.getOperationName(), argIndex + 1)); 802 } 803 804 // If a constraint is specified, we need to generate function call to its 805 // static verifier. 806 StringRef verifier = staticMatcherHelper.getVerifierName(matcher); 807 if (attr.isOptional()) { 808 // Avoid dereferencing null attribute. This is using a simple heuristic to 809 // avoid common cases of attempting to dereference null attribute. This 810 // will return where there is no check if attribute is null unless the 811 // attribute's value is not used. 812 // FIXME: This could be improved as some null dereferences could slip 813 // through. 
814 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") && 815 StringRef(matcher.getConditionTemplate()).contains("$_self")) { 816 os << "if (!tblgen_attr) return ::mlir::failure();\n"; 817 } 818 } 819 emitStaticVerifierCall( 820 verifier, opName, "tblgen_attr", 821 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 822 "'{2}'\"", 823 op.getOperationName(), namedAttr->name, 824 escapeString(matcher.getAsConstraint().getSummary())) 825 .str()); 826 } 827 828 // Capture the value 829 auto name = tree.getArgName(argIndex); 830 // `$_` is a special symbol to ignore op argument matching. 831 if (!name.empty() && name != "_") { 832 os << formatv("{0} = tblgen_attr;\n", name); 833 } 834 835 os.unindent() << "}\n"; 836 } 837 838 void PatternEmitter::emitMatchCheck( 839 StringRef opName, const FmtObjectBase &matchFmt, 840 const llvm::formatv_object_base &failureFmt) { 841 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 842 } 843 844 void PatternEmitter::emitMatchCheck(StringRef opName, 845 const std::string &matchStr, 846 const std::string &failureStr) { 847 848 os << "if (!(" << matchStr << "))"; 849 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 850 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 851 << failureStr << ";\n});"; 852 } 853 854 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 855 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 856 int depth = 0; 857 emitMatch(tree, opName, depth); 858 859 for (auto &appliedConstraint : pattern.getConstraints()) { 860 auto &constraint = appliedConstraint.constraint; 861 auto &entities = appliedConstraint.entities; 862 863 auto condition = constraint.getConditionTemplate(); 864 if (isa<TypeConstraint>(constraint)) { 865 if (entities.size() != 1) 866 PrintFatalError(loc, "type constraint requires exactly one argument"); 867 868 auto self = formatv("({0}.getType())", 869 
symbolInfoMap.getValueAndRangeUse(entities.front())); 870 emitMatchCheck( 871 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 872 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"", 873 entities.front(), escapeString(constraint.getSummary()))); 874 875 } else if (isa<AttrConstraint>(constraint)) { 876 PrintFatalError( 877 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 878 } else { 879 // TODO: replace formatv arguments with the exact specified 880 // args. 881 if (entities.size() > 4) { 882 PrintFatalError(loc, "only support up to 4-entity constraints now"); 883 } 884 SmallVector<std::string, 4> names; 885 int i = 0; 886 for (int e = entities.size(); i < e; ++i) 887 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 888 std::string self = appliedConstraint.self; 889 if (!self.empty()) 890 self = symbolInfoMap.getValueAndRangeUse(self); 891 for (; i < 4; ++i) 892 names.push_back("<unused>"); 893 emitMatchCheck(opName, 894 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 895 names[1], names[2], names[3]), 896 formatv("\"entities '{0}' failed to satisfy constraint: " 897 "'{1}'\"", 898 llvm::join(entities, ", "), 899 escapeString(constraint.getSummary()))); 900 } 901 } 902 903 // Some of the operands could be bound to the same symbol name, we need 904 // to enforce equality constraint on those. 905 // TODO: we should be able to emit equality checks early 906 // and short circuit unnecessary work if vars are not equal. 
907 for (auto symbolInfoIt = symbolInfoMap.begin(); 908 symbolInfoIt != symbolInfoMap.end();) { 909 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 910 auto startRange = range.first; 911 auto endRange = range.second; 912 913 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 914 for (++startRange; startRange != endRange; ++startRange) { 915 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 916 emitMatchCheck( 917 opName, 918 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 919 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 920 secondOperand)); 921 } 922 923 symbolInfoIt = endRange; 924 } 925 926 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 927 } 928 929 void PatternEmitter::collectOps(DagNode tree, 930 llvm::SmallPtrSetImpl<const Operator *> &ops) { 931 // Check if this tree is an operation. 932 if (tree.isOperation()) { 933 const Operator &op = tree.getDialectOp(opMap); 934 LLVM_DEBUG(llvm::dbgs() 935 << "found operation " << op.getOperationName() << '\n'); 936 ops.insert(&op); 937 } 938 939 // Recurse the arguments of the tree. 940 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 941 if (auto child = tree.getArgAsNestedDag(i)) 942 collectOps(child, ops); 943 } 944 945 void PatternEmitter::emit(StringRef rewriteName) { 946 // Get the DAG tree for the source pattern. 947 DagNode sourceTree = pattern.getSourcePattern(); 948 949 const Operator &rootOp = pattern.getSourceRootOp(); 950 auto rootName = rootOp.getOperationName(); 951 952 // Collect the set of result operations. 
953 llvm::SmallPtrSet<const Operator *, 4> resultOps; 954 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 955 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 956 collectOps(pattern.getResultPattern(i), resultOps); 957 } 958 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 959 960 // Emit RewritePattern for Pattern. 961 auto locs = pattern.getLocation(); 962 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 963 make_range(locs.rbegin(), locs.rend())); 964 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 965 {0}(::mlir::MLIRContext *context) 966 : ::mlir::RewritePattern("{1}", {2}, context, {{)", 967 rewriteName, rootName, pattern.getBenefit()); 968 // Sort result operators by name. 969 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 970 resultOps.end()); 971 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 972 return lhs->getOperationName() < rhs->getOperationName(); 973 }); 974 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 975 os << '"' << op->getOperationName() << '"'; 976 }); 977 os << "}) {}\n"; 978 979 // Emit matchAndRewrite() function. 980 { 981 auto classScope = os.scope(); 982 os.printReindented(R"( 983 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 984 ::mlir::PatternRewriter &rewriter) const override {)") 985 << '\n'; 986 { 987 auto functionScope = os.scope(); 988 989 // Register all symbols bound in the source pattern. 990 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 991 992 LLVM_DEBUG(llvm::dbgs() 993 << "start creating local variables for capturing matches\n"); 994 os << "// Variables for capturing values and attributes used while " 995 "creating ops\n"; 996 // Create local variables for storing the arguments and results bound 997 // to symbols. 
998 for (const auto &symbolInfoPair : symbolInfoMap) { 999 const auto &symbol = symbolInfoPair.first; 1000 const auto &info = symbolInfoPair.second; 1001 1002 os << info.getVarDecl(symbol); 1003 } 1004 // TODO: capture ops with consistent numbering so that it can be 1005 // reused for fused loc. 1006 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n"; 1007 LLVM_DEBUG(llvm::dbgs() 1008 << "done creating local variables for capturing matches\n"); 1009 1010 os << "// Match\n"; 1011 os << "tblgen_ops.push_back(op0);\n"; 1012 emitMatchLogic(sourceTree, "op0"); 1013 1014 os << "\n// Rewrite\n"; 1015 emitRewriteLogic(); 1016 1017 os << "return ::mlir::success();\n"; 1018 } 1019 os << "};\n"; 1020 } 1021 os << "};\n\n"; 1022 } 1023 1024 void PatternEmitter::emitRewriteLogic() { 1025 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 1026 const Operator &rootOp = pattern.getSourceRootOp(); 1027 int numExpectedResults = rootOp.getNumResults(); 1028 int numResultPatterns = pattern.getNumResultPatterns(); 1029 1030 // First register all symbols bound to ops generated in result patterns. 1031 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 1032 1033 // Only the last N static values generated are used to replace the matched 1034 // root N-result op. We need to calculate the starting index (of the results 1035 // of the matched op) each result pattern is to replace. 1036 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 1037 // If we don't need to replace any value at all, set the replacement starting 1038 // index as the number of result patterns so we skip all of them when trying 1039 // to replace the matched op's results. 1040 int replStartIndex = numExpectedResults == 0 ? 
numResultPatterns : -1; 1041 for (int i = numResultPatterns - 1; i >= 0; --i) { 1042 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 1043 offsets[i] = offsets[i + 1] - numValues; 1044 if (offsets[i] == 0) { 1045 if (replStartIndex == -1) 1046 replStartIndex = i; 1047 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 1048 auto error = formatv( 1049 "cannot use the same multi-result op '{0}' to generate both " 1050 "auxiliary values and values to be used for replacing the matched op", 1051 pattern.getResultPattern(i).getSymbol()); 1052 PrintFatalError(loc, error); 1053 } 1054 } 1055 1056 if (offsets.front() > 0) { 1057 const char error[] = 1058 "not enough values generated to replace the matched op"; 1059 PrintFatalError(loc, error); 1060 } 1061 1062 os << "auto odsLoc = rewriter.getFusedLoc({"; 1063 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 1064 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 1065 } 1066 os << "}); (void)odsLoc;\n"; 1067 1068 // Process auxiliary result patterns. 1069 for (int i = 0; i < replStartIndex; ++i) { 1070 DagNode resultTree = pattern.getResultPattern(i); 1071 auto val = handleResultPattern(resultTree, offsets[i], 0); 1072 // Normal op creation will be streamed to `os` by the above call; but 1073 // NativeCodeCall will only be materialized to `os` if it is used. Here 1074 // we are handling auxiliary patterns so we want the side effect even if 1075 // NativeCodeCall is not replacing matched root op's results. 1076 if (resultTree.isNativeCodeCall() && 1077 resultTree.getNumReturnsOfNativeCode() == 0) 1078 os << val << ";\n"; 1079 } 1080 1081 if (numExpectedResults == 0) { 1082 assert(replStartIndex >= numResultPatterns && 1083 "invalid auxiliary vs. replacement pattern division!"); 1084 // No result to replace. Just erase the op. 1085 os << "rewriter.eraseOp(op0);\n"; 1086 } else { 1087 // Process replacement result patterns. 
1088 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 1089 for (int i = replStartIndex; i < numResultPatterns; ++i) { 1090 DagNode resultTree = pattern.getResultPattern(i); 1091 auto val = handleResultPattern(resultTree, offsets[i], 0); 1092 os << "\n"; 1093 // Resolve each symbol for all range use so that we can loop over them. 1094 // We need an explicit cast to `SmallVector` to capture the cases where 1095 // `{0}` resolves to an `Operation::result_range` as well as cases that 1096 // are not iterable (e.g. vector that gets wrapped in additional braces by 1097 // RewriterGen). 1098 // TODO: Revisit the need for materializing a vector. 1099 os << symbolInfoMap.getAllRangeUse( 1100 val, 1101 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 1102 " tblgen_repl_values.push_back(v);\n}\n", 1103 "\n"); 1104 } 1105 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 1106 } 1107 1108 // Process supplemtal patterns. 1109 int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); 1110 for (int i = 0, offset = -numSupplementalPatterns; 1111 i < numSupplementalPatterns; ++i) { 1112 DagNode resultTree = pattern.getSupplementalPattern(i); 1113 auto val = handleResultPattern(resultTree, offset++, 0); 1114 if (resultTree.isNativeCodeCall() && 1115 resultTree.getNumReturnsOfNativeCode() == 0) 1116 os << val << ";\n"; 1117 } 1118 1119 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 1120 } 1121 1122 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 1123 return std::string( 1124 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 1125 } 1126 1127 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 1128 int resultIndex, int depth) { 1129 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 1130 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 1131 LLVM_DEBUG(llvm::dbgs() << '\n'); 1132 1133 if (resultTree.isLocationDirective()) { 1134 PrintFatalError(loc, 1135 "location 
directive can only be used with op creation"); 1136 } 1137 1138 if (resultTree.isNativeCodeCall()) 1139 return handleReplaceWithNativeCodeCall(resultTree, depth); 1140 1141 if (resultTree.isReplaceWithValue()) 1142 return handleReplaceWithValue(resultTree).str(); 1143 1144 // Normal op creation. 1145 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 1146 if (resultTree.getSymbol().empty()) { 1147 // This is an op not explicitly bound to a symbol in the rewrite rule. 1148 // Register the auto-generated symbol for it. 1149 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 1150 } 1151 return symbol; 1152 } 1153 1154 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 1155 assert(tree.isReplaceWithValue()); 1156 1157 if (tree.getNumArgs() != 1) { 1158 PrintFatalError( 1159 loc, "replaceWithValue directive must take exactly one argument"); 1160 } 1161 1162 if (!tree.getSymbol().empty()) { 1163 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 1164 } 1165 1166 return tree.getArgName(0); 1167 } 1168 1169 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 1170 assert(tree.isLocationDirective()); 1171 auto lookUpArgLoc = [this, &tree](int idx) { 1172 const auto *const lookupFmt = "{0}.getLoc()"; 1173 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt); 1174 }; 1175 1176 if (tree.getNumArgs() == 0) 1177 llvm::PrintFatalError( 1178 "At least one argument to location directive required"); 1179 1180 if (!tree.getSymbol().empty()) 1181 PrintFatalError(loc, "cannot bind symbol to location"); 1182 1183 if (tree.getNumArgs() == 1) { 1184 DagLeaf leaf = tree.getArgAsLeaf(0); 1185 if (leaf.isStringAttr()) 1186 return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))", 1187 leaf.getStringAttr()) 1188 .str(); 1189 return lookUpArgLoc(0); 1190 } 1191 1192 std::string ret; 1193 llvm::raw_string_ostream os(ret); 1194 std::string strAttr; 1195 os << "rewriter.getFusedLoc({"; 
1196 bool first = true; 1197 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1198 DagLeaf leaf = tree.getArgAsLeaf(i); 1199 // Handle the optional string value. 1200 if (leaf.isStringAttr()) { 1201 if (!strAttr.empty()) 1202 llvm::PrintFatalError("Only one string attribute may be specified"); 1203 strAttr = leaf.getStringAttr(); 1204 continue; 1205 } 1206 os << (first ? "" : ", ") << lookUpArgLoc(i); 1207 first = false; 1208 } 1209 os << "}"; 1210 if (!strAttr.empty()) { 1211 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 1212 } 1213 os << ")"; 1214 return os.str(); 1215 } 1216 1217 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, 1218 int depth) { 1219 // Nested NativeCodeCall. 1220 if (auto dagNode = returnType.getArgAsNestedDag(i)) { 1221 if (!dagNode.isNativeCodeCall()) 1222 PrintFatalError(loc, "nested DAG in `returnType` must be a native code " 1223 "call"); 1224 return handleReplaceWithNativeCodeCall(dagNode, depth); 1225 } 1226 // String literal. 1227 auto dagLeaf = returnType.getArgAsLeaf(i); 1228 if (dagLeaf.isStringAttr()) 1229 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx); 1230 return tgfmt( 1231 "$0.getType()", &fmtCtx, 1232 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i))); 1233 } 1234 1235 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 1236 StringRef patArgName) { 1237 if (leaf.isStringAttr()) 1238 PrintFatalError(loc, "raw string not supported as argument"); 1239 if (leaf.isConstantAttr()) { 1240 auto constAttr = leaf.getAsConstantAttr(); 1241 return handleConstantAttr(constAttr.getAttribute(), 1242 constAttr.getConstantValue()); 1243 } 1244 if (leaf.isEnumAttrCase()) { 1245 auto enumCase = leaf.getAsEnumAttrCase(); 1246 // This is an enum case backed by an IntegerAttr. We need to get its value 1247 // to build the constant. 
1248 std::string val = std::to_string(enumCase.getValue()); 1249 return handleConstantAttr(enumCase, val); 1250 } 1251 1252 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 1253 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 1254 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 1255 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 1256 << "' (via symbol ref)\n"); 1257 return argName; 1258 } 1259 if (leaf.isNativeCodeCall()) { 1260 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 1261 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 1262 << "' (via NativeCodeCall)\n"); 1263 return std::string(repl); 1264 } 1265 PrintFatalError(loc, "unhandled case when rewriting op"); 1266 } 1267 1268 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 1269 int depth) { 1270 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 1271 LLVM_DEBUG(tree.print(llvm::dbgs())); 1272 LLVM_DEBUG(llvm::dbgs() << '\n'); 1273 1274 auto fmt = tree.getNativeCodeTemplate(); 1275 1276 SmallVector<std::string, 16> attrs; 1277 1278 auto tail = getTrailingDirectives(tree); 1279 if (tail.returnType) 1280 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 1281 auto locToUse = getLocation(tail); 1282 1283 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 1284 if (tree.isNestedDagArg(i)) { 1285 attrs.push_back( 1286 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); 1287 } else { 1288 attrs.push_back( 1289 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))); 1290 } 1291 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 1292 << " replacement: " << attrs[i] << "\n"); 1293 } 1294 1295 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), 1296 static_cast<ArrayRef<std::string>>(attrs)); 1297 1298 // In general, NativeCodeCall without naming binding don't need this. 
To 1299 // ensure void helper function has been correctly labeled, i.e., use 1300 // NativeCodeCallVoid, we cache the result to a local variable so that we will 1301 // get a compilation error in the auto-generated file. 1302 // Example. 1303 // // In the td file 1304 // Pat<(...), (NativeCodeCall<Foo> ...)> 1305 // 1306 // --- 1307 // 1308 // // In the auto-generated .cpp 1309 // ... 1310 // // Causes compilation error if Foo() returns void. 1311 // auto nativeVar = Foo(); 1312 // ... 1313 if (tree.getNumReturnsOfNativeCode() != 0) { 1314 // Determine the local variable name for return value. 1315 std::string varName = 1316 SymbolInfoMap::getValuePackName(tree.getSymbol()).str(); 1317 if (varName.empty()) { 1318 varName = formatv("nativeVar_{0}", nextValueId++); 1319 // Register the local variable for later uses. 1320 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode()); 1321 } 1322 1323 // Catch the return value of helper function. 1324 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol); 1325 1326 if (!tree.getSymbol().empty()) 1327 symbol = tree.getSymbol().str(); 1328 else 1329 symbol = varName; 1330 } 1331 1332 return symbol; 1333 } 1334 1335 int PatternEmitter::getNodeValueCount(DagNode node) { 1336 if (node.isOperation()) { 1337 // If the op is bound to a symbol in the rewrite rule, query its result 1338 // count from the symbol info map. 1339 auto symbol = node.getSymbol(); 1340 if (!symbol.empty()) { 1341 return symbolInfoMap.getStaticValueCount(symbol); 1342 } 1343 // Otherwise this is an unbound op; we will use all its results. 1344 return pattern.getDialectOp(node).getNumResults(); 1345 } 1346 1347 if (node.isNativeCodeCall()) 1348 return node.getNumReturnsOfNativeCode(); 1349 1350 return 1; 1351 } 1352 1353 PatternEmitter::TrailingDirectives 1354 PatternEmitter::getTrailingDirectives(DagNode tree) { 1355 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0}; 1356 1357 // Look backwards through the arguments. 
1358 auto numPatArgs = tree.getNumArgs(); 1359 for (int i = numPatArgs - 1; i >= 0; --i) { 1360 auto dagArg = tree.getArgAsNestedDag(i); 1361 // A leaf is not a directive. Stop looking. 1362 if (!dagArg) 1363 break; 1364 1365 auto isLocation = dagArg.isLocationDirective(); 1366 auto isReturnType = dagArg.isReturnTypeDirective(); 1367 // If encountered a DAG node that isn't a trailing directive, stop looking. 1368 if (!(isLocation || isReturnType)) 1369 break; 1370 // Save the directive, but error if one of the same type was already 1371 // found. 1372 ++tail.numDirectives; 1373 if (isLocation) { 1374 if (tail.location) 1375 PrintFatalError(loc, "`location` directive can only be specified " 1376 "once"); 1377 tail.location = dagArg; 1378 } else if (isReturnType) { 1379 if (tail.returnType) 1380 PrintFatalError(loc, "`returnType` directive can only be specified " 1381 "once"); 1382 tail.returnType = dagArg; 1383 } 1384 } 1385 1386 return tail; 1387 } 1388 1389 std::string 1390 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { 1391 if (tail.location) 1392 return handleLocationDirective(tail.location); 1393 1394 // If no explicit location is given, use the default, all fused, location. 1395 return "odsLoc"; 1396 } 1397 1398 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 1399 int depth) { 1400 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 1401 LLVM_DEBUG(tree.print(llvm::dbgs())); 1402 LLVM_DEBUG(llvm::dbgs() << '\n'); 1403 1404 Operator &resultOp = tree.getDialectOp(opMap); 1405 auto numOpArgs = resultOp.getNumArgs(); 1406 auto numPatArgs = tree.getNumArgs(); 1407 1408 auto tail = getTrailingDirectives(tree); 1409 auto locToUse = getLocation(tail); 1410 1411 auto inPattern = numPatArgs - tail.numDirectives; 1412 if (numOpArgs != inPattern) { 1413 PrintFatalError(loc, 1414 formatv("resultant op '{0}' argument number mismatch: " 1415 "{1} in pattern vs. 
{2} in definition", 1416 resultOp.getOperationName(), inPattern, numOpArgs)); 1417 } 1418 1419 // A map to collect all nested DAG child nodes' names, with operand index as 1420 // the key. This includes both bound and unbound child nodes. 1421 ChildNodeIndexNameMap childNodeNames; 1422 1423 // First go through all the child nodes who are nested DAG constructs to 1424 // create ops for them and remember the symbol names for them, so that we can 1425 // use the results in the current node. This happens in a recursive manner. 1426 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 1427 if (auto child = tree.getArgAsNestedDag(i)) 1428 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1429 } 1430 1431 // The name of the local variable holding this op. 1432 std::string valuePackName; 1433 // The symbol for holding the result of this pattern. Note that the result of 1434 // this pattern is not necessarily the same as the variable created by this 1435 // pattern because we can use `__N` suffix to refer only a specific result if 1436 // the generated op is a multi-result op. 1437 std::string resultValue; 1438 if (tree.getSymbol().empty()) { 1439 // No symbol is explicitly bound to this op in the pattern. Generate a 1440 // unique name. 1441 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1442 } else { 1443 resultValue = std::string(tree.getSymbol()); 1444 // Strip the index to get the name for the value pack and use it to name the 1445 // local variable for the op. 1446 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1447 } 1448 1449 // Create the local variable for this op. 1450 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1451 valuePackName); 1452 1453 // Right now ODS don't have general type inference support. Except a few 1454 // special cases listed below, DRR needs to supply types for all results 1455 // when building an op. 
1456 bool isSameOperandsAndResultType = 1457 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1458 bool useFirstAttr = 1459 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1460 1461 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) { 1462 // We know how to deduce the result type for ops with these traits and we've 1463 // generated builders taking aggregate parameters. Use those builders to 1464 // create the ops. 1465 1466 // First prepare local variables for op arguments used in builder call. 1467 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1468 1469 // Then create the op. 1470 os.scope("", "\n}\n").os << formatv( 1471 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1472 valuePackName, resultOp.getQualCppClassName(), locToUse); 1473 return resultValue; 1474 } 1475 1476 bool usePartialResults = valuePackName != resultValue; 1477 1478 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) { 1479 // For these cases (broadcastable ops, op results used both as auxiliary 1480 // values and replacement values, ops in nested patterns, auxiliary ops), we 1481 // still need to supply the result types when building the op. But because 1482 // we don't generate a builder automatically with ODS for them, it's the 1483 // developer's responsibility to make sure such a builder (with result type 1484 // deduction ability) exists. We go through the separate-parameter builder 1485 // here given that it's easier for developers to write compared to 1486 // aggregate-parameter builders. 1487 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1488 1489 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1490 resultOp.getQualCppClassName(), locToUse); 1491 supplyValuesForOpArgs(tree, childNodeNames, depth); 1492 os << "\n );\n}\n"; 1493 return resultValue; 1494 } 1495 1496 // If we are provided explicit return types, use them to build the op. 
1497 // However, if depth == 0 and resultIndex >= 0, it means we are replacing 1498 // the values generated from the source pattern root op. Then we must use the 1499 // source pattern's value types to determine the value type of the generated 1500 // op here. 1501 if (depth == 0 && resultIndex >= 0 && tail.returnType) 1502 PrintFatalError(loc, "Cannot specify explicit return types in an op whose " 1503 "return values replace the source pattern's root op"); 1504 1505 // First prepare local variables for op arguments used in builder call. 1506 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1507 1508 // Then prepare the result types. We need to specify the types for all 1509 // results. 1510 os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; " 1511 "(void)tblgen_types;\n"); 1512 int numResults = resultOp.getNumResults(); 1513 if (tail.returnType) { 1514 auto numRetTys = tail.returnType.getNumArgs(); 1515 for (int i = 0; i < numRetTys; ++i) { 1516 auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1); 1517 os << "tblgen_types.push_back(" << varName << ");\n"; 1518 } 1519 } else { 1520 if (numResults != 0) { 1521 // Copy the result types from the source pattern. 1522 for (int i = 0; i < numResults; ++i) 1523 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1524 " tblgen_types.push_back(v.getType());\n}\n", 1525 resultIndex + i); 1526 } 1527 } 1528 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1529 "tblgen_values, tblgen_attrs);\n", 1530 valuePackName, resultOp.getQualCppClassName(), locToUse); 1531 os.unindent() << "}\n"; 1532 return resultValue; 1533 } 1534 1535 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1536 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1537 Operator &resultOp = node.getDialectOp(opMap); 1538 1539 // Now prepare operands used for building this op: 1540 // * If the operand is non-variadic, we create a `Value` local variable. 
1541 // * If the operand is variadic, we create a `SmallVector<Value>` local 1542 // variable. 1543 1544 int valueIndex = 0; // An index for uniquing local variable names. 1545 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1546 const auto *operand = 1547 llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex)); 1548 // We do not need special handling for attributes. 1549 if (!operand) 1550 continue; 1551 1552 raw_indented_ostream::DelimitedScope scope(os); 1553 std::string varName; 1554 if (operand->isVariadic()) { 1555 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1556 os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName); 1557 std::string range; 1558 if (node.isNestedDagArg(argIndex)) { 1559 range = childNodeNames[argIndex]; 1560 } else { 1561 range = std::string(node.getArgName(argIndex)); 1562 } 1563 // Resolve the symbol for all range use so that we have a uniform way of 1564 // capturing the values. 1565 range = symbolInfoMap.getValueAndRangeUse(range); 1566 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1567 varName); 1568 } else { 1569 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1570 os << formatv("::mlir::Value {0} = ", varName); 1571 if (node.isNestedDagArg(argIndex)) { 1572 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1573 } else { 1574 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1575 auto symbol = 1576 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1577 if (leaf.isNativeCodeCall()) { 1578 os << std::string( 1579 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1580 } else { 1581 os << symbol; 1582 } 1583 } 1584 os << ";\n"; 1585 } 1586 1587 // Update to use the newly created local variable for building the op later. 
1588 childNodeNames[argIndex] = varName; 1589 } 1590 } 1591 1592 void PatternEmitter::supplyValuesForOpArgs( 1593 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1594 Operator &resultOp = node.getDialectOp(opMap); 1595 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1596 argIndex != numOpArgs; ++argIndex) { 1597 // Start each argument on its own line. 1598 os << ",\n "; 1599 1600 Argument opArg = resultOp.getArg(argIndex); 1601 // Handle the case of operand first. 1602 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 1603 if (!operand->name.empty()) 1604 os << "/*" << operand->name << "=*/"; 1605 os << childNodeNames.lookup(argIndex); 1606 continue; 1607 } 1608 1609 // The argument in the op definition. 1610 auto opArgName = resultOp.getArgName(argIndex); 1611 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1612 if (!subTree.isNativeCodeCall()) 1613 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1614 "for creating attribute"); 1615 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex)); 1616 } else { 1617 auto leaf = node.getArgAsLeaf(argIndex); 1618 // The argument in the result DAG pattern. 1619 auto patArgName = node.getArgName(argIndex); 1620 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1621 // TODO: Refactor out into map to avoid recomputing these. 
// The op argument in this branch must be an attribute; the `/*name=*/`
// annotation is emitted only when the pattern gave the argument a name.
        if (!opArg.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
}

// Emits, into the generated rewrite code, two local aggregates --
// `tblgen_values` (operand values for the result op) and `tblgen_attrs` (its
// named attributes) -- followed by the statements that fill them, walking the
// result op's arguments in definition order:
//
// - Attribute arguments: the emitted code appends the attribute under the op
//   definition's argument name, but only if the supplied value is non-null.
//   A nested DAG supplying an attribute must be a NativeCodeCall; anything
//   else is a fatal error.
// - Variadic operands: a loop is emitted that pushes every value of the range.
// - Other operands: a single `push_back` of the resolved value, or of the
//   expanded native-code template when the leaf is a NativeCodeCall. A
//   constant attribute in an operand position is a fatal error.
//
// `childNodeNames` maps argument indices to the names already assigned to the
// results of nested DAG arguments. NOTE(review): `depth` is not read here.
void PatternEmitter::createAggregateLocalVarsForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);

  // Scope object on the indented output stream; presumably indents the
  // statements emitted below until it is destroyed (see IndentedOstream.h).
  auto scope = os.scope();
  // The emitted `(void)` casts presumably keep the generated code free of
  // unused-variable warnings when an aggregate ends up empty.
  os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
                "tblgen_values; (void)tblgen_values;\n");
  os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
                "tblgen_attrs; (void)tblgen_attrs;\n");

  // Template for appending one attribute: {0} is the attribute's name in the
  // op definition, {1} the C++ expression producing its value. The attribute
  // is only recorded when that expression yields a non-null value.
  const char *addAttrCmd =
      "if (auto tmpAttr = {1}) {\n"
      " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
      "tmpAttr);\n}\n";
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
      // The argument in the op definition.
      auto opArgName = resultOp.getArgName(argIndex);
      if (auto subTree = node.getArgAsNestedDag(argIndex)) {
        if (!subTree.isNativeCodeCall())
          PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                               "for creating attribute");
        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
      } else {
        auto leaf = node.getArgAsLeaf(argIndex);
        // The argument in the result DAG pattern.
        auto patArgName = node.getArgName(argIndex);
        os << formatv(addAttrCmd, opArgName,
                      handleOpArgument(leaf, patArgName));
      }
      continue;
    }

    // Every non-attribute argument is an operand.
    const auto *operand =
        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
    // NOTE(review): `varName` is declared but never used in this function.
    std::string varName;
    if (operand->isVariadic()) {
      std::string range;
      if (node.isNestedDagArg(argIndex)) {
        range = childNodeNames.lookup(argIndex);
      } else {
        range = std::string(node.getArgName(argIndex));
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
                    range);
    } else {
      os << formatv("tblgen_values.push_back(");
      if (node.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(
            childNodeNames.lookup(argIndex));
      } else {
        DagLeaf leaf = node.getArgAsLeaf(argIndex);
        if (leaf.isConstantAttr())
          // TODO: Use better location
          PrintFatalError(
              loc,
              "attribute found where value was expected, if attempting to use "
              "constant value, construct a constant op with given attribute "
              "instead");

        auto symbol =
            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
        // A NativeCodeCall leaf wraps the bound value: expand its template
        // with `$_self` bound to the resolved symbol.
        if (leaf.isNativeCodeCall()) {
          os << std::string(
              tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
        } else {
          os << symbol;
        }
      }
      os << ");\n";
    }
  }
}

// Keeps a reference to the operator map shared across patterns and constructs
// `staticVerifierEmitter` with the same output stream and record keeper.
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
                                         const RecordKeeper &recordKeeper,
                                         RecordOperatorMap &mapper)
    : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}

// Emits one static matcher function per DagNode selected by
// useStaticMatcher(), naming them `static_dag_matcher_<N>` with a running
// counter, and records each node's function name in `matcherNames`.
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
  // PatternEmitter will use the static matcher if there's one generated. To
  // ensure that all the dependent static matchers are generated before emitting
  // the matching logic of the DagNode, we use topological order to achieve it.
  for (auto &dagInfo : topologicalOrder) {
    DagNode node = dagInfo.first;
    if (!useStaticMatcher(node))
      continue;

    std::string funcName =
        formatv("static_dag_matcher_{0}", staticMatcherCounter++);
    // Each node appears once in `topologicalOrder`, so it is never named twice.
    assert(!matcherNames.contains(node));
    // `dagInfo.second` is the pattern record the node was first seen in.
    PatternEmitter(dagInfo.second, &opMap, os, *this)
        .emitStaticMatcher(node, funcName);
    matcherNames[node] = funcName;
  }
}

// Emits the uniqued constraint functions for every leaf collected in
// `constraints` by addPattern(). NOTE(review): `os` is not used here;
// `staticVerifierEmitter` received its output stream at construction.
void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
  staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
}

// Registers the source DAG of the pattern `record`. For each DagNode reached,
// bumps its reference count in `refStats`. On the first visit only, it also
// recurses into nested DAG arguments and collects every specified leaf into
// `constraints`, then appends (node, record) to `topologicalOrder`.
void StaticMatcherHelper::addPattern(Record *record) {
  Pattern pat(record, &opMap);

  // While generating the function body of the DAG matcher, it may depend on
  // other DAG matchers. To ensure the dependent matchers are ready, we compute
  // the topological order for all the DAGs and emit the DAG matchers in this
  // order.
  llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
    ++refStats[node];

    // Only the first visit expands the node; later visits just count the
    // extra reference.
    if (refStats[node] != 1)
      return;

    for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
      if (DagNode sibling = node.getArgAsNestedDag(i))
        dfs(sibling);
      else {
        DagLeaf leaf = node.getArgAsLeaf(i);
        if (!leaf.isUnspecified())
          constraints.insert(leaf);
      }

    // Appended after its nested DAGs (post-order), so every node follows the
    // nodes it depends on.
    topologicalOrder.push_back(std::make_pair(node, record));
  };

  dfs(pat.getSourcePattern());
}

// Returns the name of the generated static constraint function for `leaf`.
// Attribute matchers map to their attribute-constraint function, which must
// already have been uniqued. Every other leaf must be an operand matcher and
// maps to its type-constraint function.
StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
  if (leaf.isAttrMatcher()) {
    std::optional<StringRef> constraint =
        staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
    assert(constraint && "attribute constraint was not uniqued");
    return *constraint;
  }
  assert(leaf.isOperandMatcher());
  return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}

// Backend entry point. Writes, in order:
// 1. the file header;
// 2. the static constraint functions and static DAG matchers, computed over
//    all `Pattern` records;
// 3. one rewrite pattern per record;
// 4. a `populateWithGenerated` function that adds every emitted pattern to a
//    RewritePatternSet.
// The shared code is presumably emitted first so the per-pattern code can
// refer to it.
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Rewriters", os);

  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");

  // We put the map here because it can be shared among multiple patterns.
  RecordOperatorMap recordOpMap;

  // Examine all the patterns and generate static matchers for the duplicated
  // DagNodes.
  StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
  for (Record *p : patterns)
    staticMatcher.addPattern(p);
  staticMatcher.populateStaticConstraintFunctions(os);
  staticMatcher.populateStaticMatchers(os);

  std::vector<std::string> rewriterNames;
  rewriterNames.reserve(patterns.size());

  std::string baseRewriterName = "GeneratedConvert";
  int rewriterIndex = 0;

  for (Record *p : patterns) {
    std::string name;
    if (p->isAnonymous()) {
      // If no name is provided, ensure unique rewriter names simply by
      // appending unique suffix.
      name = baseRewriterName + llvm::utostr(rewriterIndex++);
    } else {
      name = std::string(p->getName());
    }
    LLVM_DEBUG(llvm::dbgs()
               << "=== start generating pattern '" << name << "' ===\n");
    PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
    LLVM_DEBUG(llvm::dbgs()
               << "=== done generating pattern '" << name << "' ===\n");
    rewriterNames.push_back(std::move(name));
  }

  // Emit function to add the generated matchers to the pattern list.
  os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
        "::mlir::RewritePatternSet &patterns) {\n";
  for (const auto &name : rewriterNames) {
    os << " patterns.add<" << name << ">(patterns.getContext());\n";
  }
  os << "}\n";
}

// Registers this backend under the `gen-rewriters` generator name. The
// callback always returns false; presumably false signals success to the
// TableGen driver -- confirm against GenInfo.h.
static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &records, raw_ostream &os) {
                   emitRewriters(records, os);
                   return false;
                 });