1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/CodeGenHelpers.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/GenInfo.h" 18 #include "mlir/TableGen/Operator.h" 19 #include "mlir/TableGen/Pattern.h" 20 #include "mlir/TableGen/Predicate.h" 21 #include "mlir/TableGen/Type.h" 22 #include "llvm/ADT/FunctionExtras.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringSet.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/FormatAdapters.h" 29 #include "llvm/Support/PrettyStackTrace.h" 30 #include "llvm/Support/Signals.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Main.h" 33 #include "llvm/TableGen/Record.h" 34 #include "llvm/TableGen/TableGenBackend.h" 35 36 using namespace mlir; 37 using namespace mlir::tblgen; 38 39 using llvm::formatv; 40 using llvm::Record; 41 using llvm::RecordKeeper; 42 43 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 44 45 namespace llvm { 46 template <> 47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 48 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 49 raw_ostream &os, StringRef style) { 50 os << v.first << ":" << v.second; 51 } 52 }; 53 } // namespace llvm 54 55 //===----------------------------------------------------------------------===// 56 // PatternEmitter 57 //===----------------------------------------------------------------------===// 58 59 namespace { 60 61 class StaticMatcherHelper; 62 63 class PatternEmitter { 64 public: 65 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os, 66 StaticMatcherHelper &helper); 67 68 // Emits the mlir::RewritePattern struct named `rewriteName`. 69 void emit(StringRef rewriteName); 70 71 // Emits the static function of DAG matcher. 72 void emitStaticMatcher(DagNode tree, std::string funcName); 73 74 private: 75 // Emits the code for matching ops. 76 void emitMatchLogic(DagNode tree, StringRef opName); 77 78 // Emits the code for rewriting ops. 79 void emitRewriteLogic(); 80 81 //===--------------------------------------------------------------------===// 82 // Match utilities 83 //===--------------------------------------------------------------------===// 84 85 // Emits C++ statements for matching the DAG structure. 86 void emitMatch(DagNode tree, StringRef name, int depth); 87 88 // Emit C++ function call to static DAG matcher. 89 void emitStaticMatchCall(DagNode tree, StringRef name); 90 91 // Emit C++ function call to static type/attribute constraint function. 92 void emitStaticVerifierCall(StringRef funcName, StringRef opName, 93 StringRef arg, StringRef failureStr); 94 95 // Emits C++ statements for matching using a native code call. 96 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 97 98 // Emits C++ statements for matching the op constrained by the given DAG 99 // `tree` returning the op's variable name. 100 void emitOpMatch(DagNode tree, StringRef opName, int depth); 101 102 // Emits C++ statements for matching the `argIndex`-th argument of the given 103 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the 104 // bound name and the constraint of the operand respectively. 105 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, 106 int operandIndex, DagLeaf operandMatcher, 107 StringRef argName, int argIndex, 108 std::optional<int> variadicSubIndex); 109 110 // Emits C++ statements for matching the operands which can be matched in 111 // either order. 112 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 113 StringRef opName, int argIndex, int &operandIndex, 114 int depth); 115 116 // Emits C++ statements for matching a variadic operand. 117 void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, 118 StringRef opName, int argIndex, 119 int &operandIndex, int depth); 120 121 // Emits C++ statements for matching the `argIndex`-th argument of the given 122 // DAG `tree` as an attribute. 123 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 124 int depth); 125 126 // Emits C++ for checking a match with a corresponding match failure 127 // diagnostic. 128 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 129 const llvm::formatv_object_base &failureFmt); 130 131 // Emits C++ for checking a match with a corresponding match failure 132 // diagnostics. 133 void emitMatchCheck(StringRef opName, const std::string &matchStr, 134 const std::string &failureStr); 135 136 //===--------------------------------------------------------------------===// 137 // Rewrite utilities 138 //===--------------------------------------------------------------------===// 139 140 // The entry point for handling a result pattern rooted at `resultTree`. This 141 // method dispatches to concrete handlers according to `resultTree`'s kind and 142 // returns a symbol representing the whole value pack. Callers are expected to 143 // further resolve the symbol according to the specific use case. 144 // 145 // `depth` is the nesting level of `resultTree`; 0 means top-level result 146 // pattern. For top-level result pattern, `resultIndex` indicates which result 147 // of the matched root op this pattern is intended to replace, which can be 148 // used to deduce the result type of the op generated from this result 149 // pattern. 150 std::string handleResultPattern(DagNode resultTree, int resultIndex, 151 int depth); 152 153 // Emits the C++ statement to replace the matched DAG with a value built via 154 // calling native C++ code. 155 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 156 157 // Returns the symbol of the old value serving as the replacement. 158 StringRef handleReplaceWithValue(DagNode tree); 159 160 // Trailing directives are used at the end of DAG node argument lists to 161 // specify additional behaviour for op matchers and creators, etc. 162 struct TrailingDirectives { 163 // DAG node containing the `location` directive. Null if there is none. 164 DagNode location; 165 166 // DAG node containing the `returnType` directive. Null if there is none. 167 DagNode returnType; 168 169 // Number of found trailing directives. 170 int numDirectives; 171 }; 172 173 // Collect any trailing directives. 174 TrailingDirectives getTrailingDirectives(DagNode tree); 175 176 // Returns the location value to use. 177 std::string getLocation(TrailingDirectives &tail); 178 179 // Returns the location value to use. 180 std::string handleLocationDirective(DagNode tree); 181 182 // Emit return type argument. 183 std::string handleReturnTypeArg(DagNode returnType, int i, int depth); 184 185 // Emits the C++ statement to build a new op out of the given DAG `tree` and 186 // returns the variable name that this op is assigned to. If the root op in 187 // DAG `tree` has a specified name, the created op will be assigned to a 188 // variable of the given name. Otherwise, a unique name will be used as the 189 // result value name. 190 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 191 192 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 193 194 // Emits a local variable for each value and attribute to be used for creating 195 // an op. 196 void createSeparateLocalVarsForOpArgs(DagNode node, 197 ChildNodeIndexNameMap &childNodeNames); 198 199 // Emits the concrete arguments used to call an op's builder. 200 void supplyValuesForOpArgs(DagNode node, 201 const ChildNodeIndexNameMap &childNodeNames, 202 int depth); 203 204 // Emits the local variables for holding all values as a whole and all named 205 // attributes as a whole to be used for creating an op. 206 void createAggregateLocalVarsForOpArgs( 207 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 208 209 // Returns the C++ expression to construct a constant attribute of the given 210 // `value` for the given attribute kind `attr`. 211 std::string handleConstantAttr(Attribute attr, const Twine &value); 212 213 // Returns the C++ expression to build an argument from the given DAG `leaf`. 214 // `patArgName` is used to bound the argument to the source pattern. 215 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 216 217 //===--------------------------------------------------------------------===// 218 // General utilities 219 //===--------------------------------------------------------------------===// 220 221 // Collects all of the operations within the given dag tree. 222 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 223 224 // Returns a unique symbol for a local variable of the given `op`. 225 std::string getUniqueSymbol(const Operator *op); 226 227 //===--------------------------------------------------------------------===// 228 // Symbol utilities 229 //===--------------------------------------------------------------------===// 230 231 // Returns how many static values the given DAG `node` correspond to. 232 int getNodeValueCount(DagNode node); 233 234 private: 235 // Pattern instantiation location followed by the location of multiclass 236 // prototypes used. This is intended to be used as a whole to 237 // PrintFatalError() on errors. 238 ArrayRef<SMLoc> loc; 239 240 // Op's TableGen Record to wrapper object. 241 RecordOperatorMap *opMap; 242 243 // Handy wrapper for pattern being emitted. 244 Pattern pattern; 245 246 // Map for all bound symbols' info. 247 SymbolInfoMap symbolInfoMap; 248 249 StaticMatcherHelper &staticMatcherHelper; 250 251 // The next unused ID for newly created values. 252 unsigned nextValueId = 0; 253 254 raw_indented_ostream os; 255 256 // Format contexts containing placeholder substitutions. 257 FmtContext fmtCtx; 258 }; 259 260 // Tracks DagNode's reference multiple times across patterns. Enables generating 261 // static matcher functions for DagNode's referenced multiple times rather than 262 // inlining them. 263 class StaticMatcherHelper { 264 public: 265 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper, 266 RecordOperatorMap &mapper); 267 268 // Determine if we should inline the match logic or delegate to a static 269 // function. 270 bool useStaticMatcher(DagNode node) { 271 // either/variadic node must be associated to the parentOp, thus we can't 272 // emit a static matcher rooted at them. 273 if (node.isEither() || node.isVariadic()) 274 return false; 275 276 return refStats[node] > kStaticMatcherThreshold; 277 } 278 279 // Get the name of the static DAG matcher function corresponding to the node. 280 std::string getMatcherName(DagNode node) { 281 assert(useStaticMatcher(node)); 282 return matcherNames[node]; 283 } 284 285 // Get the name of static type/attribute verification function. 286 StringRef getVerifierName(DagLeaf leaf); 287 288 // Collect the `Record`s, i.e., the DRR, so that we can get the information of 289 // the duplicated DAGs. 290 void addPattern(Record *record); 291 292 // Emit all static functions of DAG Matcher. 293 void populateStaticMatchers(raw_ostream &os); 294 295 // Emit all static functions for Constraints. 296 void populateStaticConstraintFunctions(raw_ostream &os); 297 298 private: 299 static constexpr unsigned kStaticMatcherThreshold = 1; 300 301 // Consider two patterns as down below, 302 // DagNode_Root_A DagNode_Root_B 303 // \ \ 304 // DagNode_C DagNode_C 305 // \ \ 306 // DagNode_D DagNode_D 307 // 308 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of 309 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced 310 // multiple times so we'll have static matchers for both of them. When we're 311 // emitting the match logic for DagNode_C, we will check if DagNode_D has the 312 // static matcher generated. If so, then we'll generate a call to the 313 // function, inline otherwise. In this case, inlining is not what we want. As 314 // a result, generate the static matcher in topological order to ensure all 315 // the dependent static matchers are generated and we can avoid accidentally 316 // inlining. 317 // 318 // The topological order of all the DagNodes among all patterns. 319 SmallVector<std::pair<DagNode, Record *>> topologicalOrder; 320 321 RecordOperatorMap &opMap; 322 323 // Records of the static function name of each DagNode 324 DenseMap<DagNode, std::string> matcherNames; 325 326 // After collecting all the DagNode in each pattern, `refStats` records the 327 // number of users for each DagNode. We will generate the static matcher for a 328 // DagNode while the number of users exceeds a certain threshold. 329 DenseMap<DagNode, unsigned> refStats; 330 331 // Number of static matcher generated. This is used to generate a unique name 332 // for each DagNode. 333 int staticMatcherCounter = 0; 334 335 // The DagLeaf which contains type or attr constraint. 336 SetVector<DagLeaf> constraints; 337 338 // Static type/attribute verification function emitter. 339 StaticVerifierFunctionEmitter staticVerifierEmitter; 340 }; 341 342 } // namespace 343 344 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 345 raw_ostream &os, StaticMatcherHelper &helper) 346 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 347 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { 348 fmtCtx.withBuilder("rewriter"); 349 } 350 351 std::string PatternEmitter::handleConstantAttr(Attribute attr, 352 const Twine &value) { 353 if (!attr.isConstBuildable()) 354 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 355 " does not have the 'constBuilderCall' field"); 356 357 // TODO: Verify the constants here 358 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 359 } 360 361 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { 362 os << formatv( 363 "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " 364 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " 365 "*, 4> &tblgen_ops", 366 funcName); 367 368 // We pass the reference of the variables that need to be captured. Hence we 369 // need to collect all the symbols in the tree first. 370 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true); 371 symbolInfoMap.assignUniqueAlternativeNames(); 372 for (const auto &info : symbolInfoMap) 373 os << formatv(", {0}", info.second.getArgDecl(info.first)); 374 375 os << ") {\n"; 376 os.indent(); 377 os << "(void)tblgen_ops;\n"; 378 379 // Note that a static matcher is considered at least one step from the match 380 // entry. 381 emitMatch(tree, "op0", /*depth=*/1); 382 383 os << "return ::mlir::success();\n"; 384 os.unindent(); 385 os << "}\n\n"; 386 } 387 388 // Helper function to match patterns. 389 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 390 if (tree.isNativeCodeCall()) { 391 emitNativeCodeMatch(tree, name, depth); 392 return; 393 } 394 395 if (tree.isOperation()) { 396 emitOpMatch(tree, name, depth); 397 return; 398 } 399 400 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 401 } 402 403 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { 404 std::string funcName = staticMatcherHelper.getMatcherName(tree); 405 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName, 406 opName); 407 408 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in 409 // one pass. 410 411 // In general, bound symbol should have the unique name in the pattern but 412 // for the operand, binding same symbol to multiple operands imply a 413 // constraint at the same time. In this case, we will rename those operands 414 // with different names. As a result, we need to collect all the symbolInfos 415 // from the DagNode then get the updated name of the local variables from the 416 // global symbolInfoMap. 417 418 // Collect all the bound symbols in the Dag 419 SymbolInfoMap localSymbolMap(loc); 420 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true); 421 422 for (const auto &info : localSymbolMap) { 423 auto name = info.first; 424 auto symboInfo = info.second; 425 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo); 426 os << formatv(", {0}", ret->second.getVarName(name)); 427 } 428 429 os << "))) {\n"; 430 os.scope().os << "return ::mlir::failure();\n"; 431 os << "}\n"; 432 } 433 434 void PatternEmitter::emitStaticVerifierCall(StringRef funcName, 435 StringRef opName, StringRef arg, 436 StringRef failureStr) { 437 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n", 438 funcName, opName, arg, failureStr); 439 os.scope().os << "return ::mlir::failure();\n"; 440 os << "}\n"; 441 } 442 443 // Helper function to match patterns. 444 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 445 int depth) { 446 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 447 LLVM_DEBUG(tree.print(llvm::dbgs())); 448 LLVM_DEBUG(llvm::dbgs() << '\n'); 449 450 // The order of generating static matcher follows the topological order so 451 // that for every dependent DagNode already have their static matcher 452 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 453 // is when we are generating the static matcher for a DagNode itself. In this 454 // case, we need to emit the function body rather than a function call. 455 if (staticMatcherHelper.useStaticMatcher(tree) && 456 !staticMatcherHelper.getMatcherName(tree).empty()) { 457 emitStaticMatchCall(tree, opName); 458 459 // NativeCodeCall will never be at depth 0 so that we don't need to catch 460 // the root operation as emitOpMatch(); 461 462 return; 463 } 464 465 // TODO(suderman): iterate through arguments, determine their types, output 466 // names. 467 SmallVector<std::string, 8> capture; 468 469 raw_indented_ostream::DelimitedScope scope(os); 470 471 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 472 std::string argName = formatv("arg{0}_{1}", depth, i); 473 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 474 if (argTree.isEither()) 475 PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); 476 if (argTree.isVariadic()) 477 PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands"); 478 479 os << "::mlir::Value " << argName << ";\n"; 480 } else { 481 auto leaf = tree.getArgAsLeaf(i); 482 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 483 os << "::mlir::Attribute " << argName << ";\n"; 484 } else { 485 os << "::mlir::Value " << argName << ";\n"; 486 } 487 } 488 489 capture.push_back(std::move(argName)); 490 } 491 492 auto tail = getTrailingDirectives(tree); 493 if (tail.returnType) 494 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 495 auto locToUse = getLocation(tail); 496 497 auto fmt = tree.getNativeCodeTemplate(); 498 if (fmt.count("$_self") != 1) 499 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 500 "passing the defining Operation"); 501 502 auto nativeCodeCall = std::string( 503 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), 504 static_cast<ArrayRef<std::string>>(capture))); 505 506 emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall), 507 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall)); 508 509 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 510 auto name = tree.getArgName(i); 511 if (!name.empty() && name != "_") { 512 os << formatv("{0} = {1};\n", name, capture[i]); 513 } 514 } 515 516 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 517 std::string argName = capture[i]; 518 519 // Handle nested DAG construct first 520 if (tree.getArgAsNestedDag(i)) { 521 PrintFatalError( 522 loc, formatv("Matching nested tree in NativeCodecall not support for " 523 "{0} as arg {1}", 524 argName, i)); 525 } 526 527 DagLeaf leaf = tree.getArgAsLeaf(i); 528 529 // The parameter for native function doesn't bind any constraints. 530 if (leaf.isUnspecified()) 531 continue; 532 533 auto constraint = leaf.getAsConstraint(); 534 535 std::string self; 536 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 537 self = argName; 538 else 539 self = formatv("{0}.getType()", argName); 540 StringRef verifier = staticMatcherHelper.getVerifierName(leaf); 541 emitStaticVerifierCall( 542 verifier, opName, self, 543 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 544 "constraint: " 545 "'{2}'\"", 546 i, tree.getNativeCodeTemplate(), 547 escapeString(constraint.getSummary())) 548 .str()); 549 } 550 551 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 552 } 553 554 // Helper function to match patterns. 555 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 556 Operator &op = tree.getDialectOp(opMap); 557 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 558 << op.getOperationName() << "' at depth " << depth 559 << '\n'); 560 561 auto getCastedName = [depth]() -> std::string { 562 return formatv("castedOp{0}", depth); 563 }; 564 565 // The order of generating static matcher follows the topological order so 566 // that for every dependent DagNode already have their static matcher 567 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 568 // is when we are generating the static matcher for a DagNode itself. In this 569 // case, we need to emit the function body rather than a function call. 570 if (staticMatcherHelper.useStaticMatcher(tree) && 571 !staticMatcherHelper.getMatcherName(tree).empty()) { 572 emitStaticMatchCall(tree, opName); 573 // In the codegen of rewriter, we suppose that castedOp0 will capture the 574 // root operation. Manually add it if the root DagNode is a static matcher. 575 if (depth == 0) 576 os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " 577 "(void){2};\n", 578 opName, op.getQualCppClassName(), getCastedName()); 579 return; 580 } 581 582 std::string castedName = getCastedName(); 583 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " 584 "(void){0};\n", 585 castedName, opName, op.getQualCppClassName()); 586 587 // Skip the operand matching at depth 0 as the pattern rewriter already does. 588 if (depth != 0) 589 emitMatchCheck(opName, /*matchStr=*/castedName, 590 formatv("\"{0} is not {1} type\"", castedName, 591 op.getQualCppClassName())); 592 593 // If the operand's name is set, set to that variable. 594 auto name = tree.getSymbol(); 595 if (!name.empty()) 596 os << formatv("{0} = {1};\n", name, castedName); 597 598 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; 599 ++i, ++opArgIdx) { 600 auto opArg = op.getArg(opArgIdx); 601 std::string argName = formatv("op{0}", depth + 1); 602 603 // Handle nested DAG construct first 604 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 605 if (argTree.isEither()) { 606 emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand, 607 depth); 608 ++opArgIdx; 609 continue; 610 } 611 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 612 if (argTree.isVariadic()) { 613 if (!operand->isVariadic()) { 614 auto error = formatv("variadic DAG construct can't match op {0}'s " 615 "non-variadic operand #{1}", 616 op.getOperationName(), opArgIdx); 617 PrintFatalError(loc, error); 618 } 619 emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx, 620 nextOperand, depth); 621 ++nextOperand; 622 continue; 623 } 624 if (operand->isVariableLength()) { 625 auto error = formatv("use nested DAG construct to match op {0}'s " 626 "variadic operand #{1} unsupported now", 627 op.getOperationName(), opArgIdx); 628 PrintFatalError(loc, error); 629 } 630 } 631 632 os << "{\n"; 633 634 // Attributes don't count for getODSOperands. 635 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 636 os.indent() << formatv( 637 "auto *{0} = " 638 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 639 argName, castedName, nextOperand); 640 // Null check of operand's definingOp 641 emitMatchCheck( 642 castedName, /*matchStr=*/argName, 643 formatv("\"There's no operation that defines operand {0} of {1}\"", 644 nextOperand++, castedName)); 645 emitMatch(argTree, argName, depth + 1); 646 os << formatv("tblgen_ops.push_back({0});\n", argName); 647 os.unindent() << "}\n"; 648 continue; 649 } 650 651 // Next handle DAG leaf: operand or attribute 652 if (opArg.is<NamedTypeConstraint *>()) { 653 auto operandName = 654 formatv("{0}.getODSOperands({1})", castedName, nextOperand); 655 emitOperandMatch(tree, castedName, operandName.str(), opArgIdx, 656 /*operandMatcher=*/tree.getArgAsLeaf(i), 657 /*argName=*/tree.getArgName(i), opArgIdx, 658 /*variadicSubIndex=*/std::nullopt); 659 ++nextOperand; 660 } else if (opArg.is<NamedAttribute *>()) { 661 emitAttributeMatch(tree, opName, opArgIdx, depth); 662 } else { 663 PrintFatalError(loc, "unhandled case when matching op"); 664 } 665 } 666 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 667 << op.getOperationName() << "' at depth " << depth 668 << '\n'); 669 } 670 671 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 672 StringRef operandName, int operandIndex, 673 DagLeaf operandMatcher, StringRef argName, 674 int argIndex, 675 std::optional<int> variadicSubIndex) { 676 Operator &op = tree.getDialectOp(opMap); 677 auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>(); 678 679 // If a constraint is specified, we need to generate C++ statements to 680 // check the constraint. 681 if (!operandMatcher.isUnspecified()) { 682 if (!operandMatcher.isOperandMatcher()) 683 PrintFatalError( 684 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 685 op.getOperationName(), argIndex + 1)); 686 687 // Only need to verify if the matcher's type is different from the one 688 // of op definition. 689 Constraint constraint = operandMatcher.getAsConstraint(); 690 if (operand->constraint != constraint) { 691 if (operand->isVariableLength()) { 692 auto error = formatv( 693 "further constrain op {0}'s variadic operand #{1} unsupported now", 694 op.getOperationName(), argIndex); 695 PrintFatalError(loc, error); 696 } 697 auto self = formatv("(*{0}.begin()).getType()", operandName); 698 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher); 699 emitStaticVerifierCall( 700 verifier, opName, self.str(), 701 formatv( 702 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"", 703 operand - op.operand_begin(), op.getOperationName(), 704 escapeString(constraint.getSummary())) 705 .str()); 706 } 707 } 708 709 // Capture the value 710 // `$_` is a special symbol to ignore op argument matching. 711 if (!argName.empty() && argName != "_") { 712 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex, 713 variadicSubIndex); 714 if (res == symbolInfoMap.end()) 715 PrintFatalError(loc, formatv("symbol not found: {0}", argName)); 716 717 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); 718 } 719 } 720 721 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 722 StringRef opName, int argIndex, 723 int &operandIndex, int depth) { 724 constexpr int numEitherArgs = 2; 725 if (eitherArgTree.getNumArgs() != numEitherArgs) 726 PrintFatalError(loc, "`either` only supports grouping two operands"); 727 728 Operator &op = tree.getDialectOp(opMap); 729 730 std::string codeBuffer; 731 llvm::raw_string_ostream tblgenOps(codeBuffer); 732 733 std::string lambda = formatv("eitherLambda{0}", depth); 734 os << formatv( 735 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n", 736 lambda); 737 738 os.indent(); 739 740 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { 741 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) { 742 if (argTree.isEither()) 743 PrintFatalError(loc, "either cannot be nested"); 744 745 std::string argName = formatv("local_op_{0}", i).str(); 746 747 os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName, 748 i); 749 750 // Indent emitMatchCheck and emitMatch because they declare local 751 // variables. 752 os << "{\n"; 753 os.indent(); 754 755 emitMatchCheck( 756 opName, /*matchStr=*/argName, 757 formatv("\"There's no operation that defines operand {0} of {1}\"", 758 operandIndex++, opName)); 759 emitMatch(argTree, argName, depth + 1); 760 761 os.unindent() << "}\n"; 762 763 // `tblgen_ops` is used to collect the matched operations. In either, we 764 // need to queue the operation only if the matching success. Thus we emit 765 // the code at the end. 766 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); 767 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) { 768 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), 769 operandIndex, 770 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), 771 /*argName=*/eitherArgTree.getArgName(i), argIndex, 772 /*variadicSubIndex=*/std::nullopt); 773 ++operandIndex; 774 } else { 775 PrintFatalError(loc, "either can only be applied on operand"); 776 } 777 } 778 779 os << tblgenOps.str(); 780 os << "return ::mlir::success();\n"; 781 os.unindent() << "};\n"; 782 783 os << "{\n"; 784 os.indent(); 785 786 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName, 787 operandIndex - 2); 788 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName, 789 operandIndex - 1); 790 791 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " 792 "::mlir::failed({0}(eitherOperand1, " 793 "eitherOperand0)))\n", 794 lambda); 795 os.indent() << "return ::mlir::failure();\n"; 796 797 os.unindent().unindent() << "}\n"; 798 } 799 800 void PatternEmitter::emitVariadicOperandMatch(DagNode tree, 801 DagNode variadicArgTree, 802 StringRef opName, int argIndex, 803 int &operandIndex, int depth) { 804 Operator &op = tree.getDialectOp(opMap); 805 806 os << "{\n"; 807 os.indent(); 808 809 os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n", 810 opName, operandIndex); 811 os << formatv("if (variadic_operand_range.size() != {0}) " 812 "return ::mlir::failure();\n", 813 variadicArgTree.getNumArgs()); 814 815 StringRef variadicTreeName = variadicArgTree.getSymbol(); 816 if (!variadicTreeName.empty()) { 817 auto res = 818 symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex, 819 /*variadicSubIndex=*/std::nullopt); 820 if (res == symbolInfoMap.end()) 821 PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName)); 822 823 os << formatv("{0} = variadic_operand_range;\n", 824 res->second.getVarName(variadicTreeName)); 825 } 826 827 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { 828 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) { 829 if (!argTree.isOperation()) 830 PrintFatalError(loc, "variadic only accepts operation sub-dags"); 831 832 os << "{\n"; 833 os.indent(); 834 835 std::string argName = formatv("local_op_{0}", i).str(); 836 os << formatv("auto *{0} = " 837 "variadic_operand_range[{1}].getDefiningOp();\n", 838 argName, i); 839 emitMatchCheck( 840 opName, /*matchStr=*/argName, 841 formatv("\"There's no operation that defines variadic operand " 842 "{0} (variadic sub-opearnd #{1}) of {2}\"", 843 operandIndex, i, opName)); 844 emitMatch(argTree, argName, depth + 1); 845 os << formatv("tblgen_ops.push_back({0});\n", argName); 846 847 os.unindent() << "}\n"; 848 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) { 849 auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i); 850 emitOperandMatch(tree, opName, operandName.str(), operandIndex, 851 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i), 852 /*argName=*/variadicArgTree.getArgName(i), argIndex, i); 853 } else { 854 PrintFatalError(loc, "variadic can only be applied on operand"); 855 } 856 } 857 858 os.unindent() << "}\n"; 859 } 860 861 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 862 int argIndex, int depth) { 863 Operator &op = tree.getDialectOp(opMap); 864 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 865 const auto &attr = namedAttr->attr; 866 867 os << "{\n"; 868 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 869 "(void)tblgen_attr;\n", 870 opName, attr.getStorageType(), namedAttr->name); 871 872 // TODO: This should use getter method to avoid duplication. 873 if (attr.hasDefaultValue()) { 874 os << "if (!tblgen_attr) tblgen_attr = " 875 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 876 attr.getDefaultValue())) 877 << ";\n"; 878 } else if (attr.isOptional()) { 879 // For a missing attribute that is optional according to definition, we 880 // should just capture a mlir::Attribute() to signal the missing state. 881 // That is precisely what getDiscardableAttr() returns on missing 882 // attributes. 883 } else { 884 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 885 formatv("\"expected op '{0}' to have attribute '{1}' " 886 "of type '{2}'\"", 887 op.getOperationName(), namedAttr->name, 888 attr.getStorageType())); 889 } 890 891 auto matcher = tree.getArgAsLeaf(argIndex); 892 if (!matcher.isUnspecified()) { 893 if (!matcher.isAttrMatcher()) { 894 PrintFatalError( 895 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 896 op.getOperationName(), argIndex + 1)); 897 } 898 899 // If a constraint is specified, we need to generate function call to its 900 // static verifier. 901 StringRef verifier = staticMatcherHelper.getVerifierName(matcher); 902 if (attr.isOptional()) { 903 // Avoid dereferencing null attribute. This is using a simple heuristic to 904 // avoid common cases of attempting to dereference null attribute. This 905 // will return where there is no check if attribute is null unless the 906 // attribute's value is not used. 907 // FIXME: This could be improved as some null dereferences could slip 908 // through. 909 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") && 910 StringRef(matcher.getConditionTemplate()).contains("$_self")) { 911 os << "if (!tblgen_attr) return ::mlir::failure();\n"; 912 } 913 } 914 emitStaticVerifierCall( 915 verifier, opName, "tblgen_attr", 916 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 917 "'{2}'\"", 918 op.getOperationName(), namedAttr->name, 919 escapeString(matcher.getAsConstraint().getSummary())) 920 .str()); 921 } 922 923 // Capture the value 924 auto name = tree.getArgName(argIndex); 925 // `$_` is a special symbol to ignore op argument matching. 926 if (!name.empty() && name != "_") { 927 os << formatv("{0} = tblgen_attr;\n", name); 928 } 929 930 os.unindent() << "}\n"; 931 } 932 933 void PatternEmitter::emitMatchCheck( 934 StringRef opName, const FmtObjectBase &matchFmt, 935 const llvm::formatv_object_base &failureFmt) { 936 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 937 } 938 939 void PatternEmitter::emitMatchCheck(StringRef opName, 940 const std::string &matchStr, 941 const std::string &failureStr) { 942 943 os << "if (!(" << matchStr << "))"; 944 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 945 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 946 << failureStr << ";\n});"; 947 } 948 949 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 950 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 951 int depth = 0; 952 emitMatch(tree, opName, depth); 953 954 for (auto &appliedConstraint : pattern.getConstraints()) { 955 auto &constraint = appliedConstraint.constraint; 956 auto &entities = appliedConstraint.entities; 957 958 auto condition = constraint.getConditionTemplate(); 959 if (isa<TypeConstraint>(constraint)) { 960 if (entities.size() != 1) 961 PrintFatalError(loc, "type constraint requires exactly one argument"); 962 963 auto self = formatv("({0}.getType())", 964 symbolInfoMap.getValueAndRangeUse(entities.front())); 965 emitMatchCheck( 966 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 967 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"", 968 entities.front(), escapeString(constraint.getSummary()))); 969 970 } else if (isa<AttrConstraint>(constraint)) { 971 PrintFatalError( 972 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 973 } else { 974 // TODO: replace formatv arguments with the exact specified 975 // args. 976 if (entities.size() > 4) { 977 PrintFatalError(loc, "only support up to 4-entity constraints now"); 978 } 979 SmallVector<std::string, 4> names; 980 int i = 0; 981 for (int e = entities.size(); i < e; ++i) 982 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 983 std::string self = appliedConstraint.self; 984 if (!self.empty()) 985 self = symbolInfoMap.getValueAndRangeUse(self); 986 for (; i < 4; ++i) 987 names.push_back("<unused>"); 988 emitMatchCheck(opName, 989 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 990 names[1], names[2], names[3]), 991 formatv("\"entities '{0}' failed to satisfy constraint: " 992 "'{1}'\"", 993 llvm::join(entities, ", "), 994 escapeString(constraint.getSummary()))); 995 } 996 } 997 998 // Some of the operands could be bound to the same symbol name, we need 999 // to enforce equality constraint on those. 1000 // TODO: we should be able to emit equality checks early 1001 // and short circuit unnecessary work if vars are not equal. 1002 for (auto symbolInfoIt = symbolInfoMap.begin(); 1003 symbolInfoIt != symbolInfoMap.end();) { 1004 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 1005 auto startRange = range.first; 1006 auto endRange = range.second; 1007 1008 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 1009 for (++startRange; startRange != endRange; ++startRange) { 1010 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 1011 emitMatchCheck( 1012 opName, 1013 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 1014 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 1015 secondOperand)); 1016 } 1017 1018 symbolInfoIt = endRange; 1019 } 1020 1021 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 1022 } 1023 1024 void PatternEmitter::collectOps(DagNode tree, 1025 llvm::SmallPtrSetImpl<const Operator *> &ops) { 1026 // Check if this tree is an operation. 1027 if (tree.isOperation()) { 1028 const Operator &op = tree.getDialectOp(opMap); 1029 LLVM_DEBUG(llvm::dbgs() 1030 << "found operation " << op.getOperationName() << '\n'); 1031 ops.insert(&op); 1032 } 1033 1034 // Recurse the arguments of the tree. 1035 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 1036 if (auto child = tree.getArgAsNestedDag(i)) 1037 collectOps(child, ops); 1038 } 1039 1040 void PatternEmitter::emit(StringRef rewriteName) { 1041 // Get the DAG tree for the source pattern. 1042 DagNode sourceTree = pattern.getSourcePattern(); 1043 1044 const Operator &rootOp = pattern.getSourceRootOp(); 1045 auto rootName = rootOp.getOperationName(); 1046 1047 // Collect the set of result operations. 1048 llvm::SmallPtrSet<const Operator *, 4> resultOps; 1049 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 1050 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 1051 collectOps(pattern.getResultPattern(i), resultOps); 1052 } 1053 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 1054 1055 // Emit RewritePattern for Pattern. 1056 auto locs = pattern.getLocation(); 1057 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 1058 make_range(locs.rbegin(), locs.rend())); 1059 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 1060 {0}(::mlir::MLIRContext *context) 1061 : ::mlir::RewritePattern("{1}", {2}, context, {{)", 1062 rewriteName, rootName, pattern.getBenefit()); 1063 // Sort result operators by name. 1064 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 1065 resultOps.end()); 1066 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 1067 return lhs->getOperationName() < rhs->getOperationName(); 1068 }); 1069 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 1070 os << '"' << op->getOperationName() << '"'; 1071 }); 1072 os << "}) {}\n"; 1073 1074 // Emit matchAndRewrite() function. 1075 { 1076 auto classScope = os.scope(); 1077 os.printReindented(R"( 1078 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 1079 ::mlir::PatternRewriter &rewriter) const override {)") 1080 << '\n'; 1081 { 1082 auto functionScope = os.scope(); 1083 1084 // Register all symbols bound in the source pattern. 1085 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 1086 1087 LLVM_DEBUG(llvm::dbgs() 1088 << "start creating local variables for capturing matches\n"); 1089 os << "// Variables for capturing values and attributes used while " 1090 "creating ops\n"; 1091 // Create local variables for storing the arguments and results bound 1092 // to symbols. 1093 for (const auto &symbolInfoPair : symbolInfoMap) { 1094 const auto &symbol = symbolInfoPair.first; 1095 const auto &info = symbolInfoPair.second; 1096 1097 os << info.getVarDecl(symbol); 1098 } 1099 // TODO: capture ops with consistent numbering so that it can be 1100 // reused for fused loc. 1101 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n"; 1102 LLVM_DEBUG(llvm::dbgs() 1103 << "done creating local variables for capturing matches\n"); 1104 1105 os << "// Match\n"; 1106 os << "tblgen_ops.push_back(op0);\n"; 1107 emitMatchLogic(sourceTree, "op0"); 1108 1109 os << "\n// Rewrite\n"; 1110 emitRewriteLogic(); 1111 1112 os << "return ::mlir::success();\n"; 1113 } 1114 os << "};\n"; 1115 } 1116 os << "};\n\n"; 1117 } 1118 1119 void PatternEmitter::emitRewriteLogic() { 1120 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 1121 const Operator &rootOp = pattern.getSourceRootOp(); 1122 int numExpectedResults = rootOp.getNumResults(); 1123 int numResultPatterns = pattern.getNumResultPatterns(); 1124 1125 // First register all symbols bound to ops generated in result patterns. 1126 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 1127 1128 // Only the last N static values generated are used to replace the matched 1129 // root N-result op. We need to calculate the starting index (of the results 1130 // of the matched op) each result pattern is to replace. 1131 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 1132 // If we don't need to replace any value at all, set the replacement starting 1133 // index as the number of result patterns so we skip all of them when trying 1134 // to replace the matched op's results. 1135 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 1136 for (int i = numResultPatterns - 1; i >= 0; --i) { 1137 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 1138 offsets[i] = offsets[i + 1] - numValues; 1139 if (offsets[i] == 0) { 1140 if (replStartIndex == -1) 1141 replStartIndex = i; 1142 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 1143 auto error = formatv( 1144 "cannot use the same multi-result op '{0}' to generate both " 1145 "auxiliary values and values to be used for replacing the matched op", 1146 pattern.getResultPattern(i).getSymbol()); 1147 PrintFatalError(loc, error); 1148 } 1149 } 1150 1151 if (offsets.front() > 0) { 1152 const char error[] = 1153 "not enough values generated to replace the matched op"; 1154 PrintFatalError(loc, error); 1155 } 1156 1157 os << "auto odsLoc = rewriter.getFusedLoc({"; 1158 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 1159 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 1160 } 1161 os << "}); (void)odsLoc;\n"; 1162 1163 // Process auxiliary result patterns. 1164 for (int i = 0; i < replStartIndex; ++i) { 1165 DagNode resultTree = pattern.getResultPattern(i); 1166 auto val = handleResultPattern(resultTree, offsets[i], 0); 1167 // Normal op creation will be streamed to `os` by the above call; but 1168 // NativeCodeCall will only be materialized to `os` if it is used. Here 1169 // we are handling auxiliary patterns so we want the side effect even if 1170 // NativeCodeCall is not replacing matched root op's results. 1171 if (resultTree.isNativeCodeCall() && 1172 resultTree.getNumReturnsOfNativeCode() == 0) 1173 os << val << ";\n"; 1174 } 1175 1176 if (numExpectedResults == 0) { 1177 assert(replStartIndex >= numResultPatterns && 1178 "invalid auxiliary vs. replacement pattern division!"); 1179 // No result to replace. Just erase the op. 1180 os << "rewriter.eraseOp(op0);\n"; 1181 } else { 1182 // Process replacement result patterns. 1183 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 1184 for (int i = replStartIndex; i < numResultPatterns; ++i) { 1185 DagNode resultTree = pattern.getResultPattern(i); 1186 auto val = handleResultPattern(resultTree, offsets[i], 0); 1187 os << "\n"; 1188 // Resolve each symbol for all range use so that we can loop over them. 1189 // We need an explicit cast to `SmallVector` to capture the cases where 1190 // `{0}` resolves to an `Operation::result_range` as well as cases that 1191 // are not iterable (e.g. vector that gets wrapped in additional braces by 1192 // RewriterGen). 1193 // TODO: Revisit the need for materializing a vector. 1194 os << symbolInfoMap.getAllRangeUse( 1195 val, 1196 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 1197 " tblgen_repl_values.push_back(v);\n}\n", 1198 "\n"); 1199 } 1200 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 1201 } 1202 1203 // Process supplemtal patterns. 1204 int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); 1205 for (int i = 0, offset = -numSupplementalPatterns; 1206 i < numSupplementalPatterns; ++i) { 1207 DagNode resultTree = pattern.getSupplementalPattern(i); 1208 auto val = handleResultPattern(resultTree, offset++, 0); 1209 if (resultTree.isNativeCodeCall() && 1210 resultTree.getNumReturnsOfNativeCode() == 0) 1211 os << val << ";\n"; 1212 } 1213 1214 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 1215 } 1216 1217 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 1218 return std::string( 1219 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 1220 } 1221 1222 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 1223 int resultIndex, int depth) { 1224 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 1225 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 1226 LLVM_DEBUG(llvm::dbgs() << '\n'); 1227 1228 if (resultTree.isLocationDirective()) { 1229 PrintFatalError(loc, 1230 "location directive can only be used with op creation"); 1231 } 1232 1233 if (resultTree.isNativeCodeCall()) 1234 return handleReplaceWithNativeCodeCall(resultTree, depth); 1235 1236 if (resultTree.isReplaceWithValue()) 1237 return handleReplaceWithValue(resultTree).str(); 1238 1239 // Normal op creation. 1240 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 1241 if (resultTree.getSymbol().empty()) { 1242 // This is an op not explicitly bound to a symbol in the rewrite rule. 1243 // Register the auto-generated symbol for it. 1244 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 1245 } 1246 return symbol; 1247 } 1248 1249 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 1250 assert(tree.isReplaceWithValue()); 1251 1252 if (tree.getNumArgs() != 1) { 1253 PrintFatalError( 1254 loc, "replaceWithValue directive must take exactly one argument"); 1255 } 1256 1257 if (!tree.getSymbol().empty()) { 1258 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 1259 } 1260 1261 return tree.getArgName(0); 1262 } 1263 1264 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 1265 assert(tree.isLocationDirective()); 1266 auto lookUpArgLoc = [this, &tree](int idx) { 1267 const auto *const lookupFmt = "{0}.getLoc()"; 1268 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt); 1269 }; 1270 1271 if (tree.getNumArgs() == 0) 1272 llvm::PrintFatalError( 1273 "At least one argument to location directive required"); 1274 1275 if (!tree.getSymbol().empty()) 1276 PrintFatalError(loc, "cannot bind symbol to location"); 1277 1278 if (tree.getNumArgs() == 1) { 1279 DagLeaf leaf = tree.getArgAsLeaf(0); 1280 if (leaf.isStringAttr()) 1281 return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))", 1282 leaf.getStringAttr()) 1283 .str(); 1284 return lookUpArgLoc(0); 1285 } 1286 1287 std::string ret; 1288 llvm::raw_string_ostream os(ret); 1289 std::string strAttr; 1290 os << "rewriter.getFusedLoc({"; 1291 bool first = true; 1292 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1293 DagLeaf leaf = tree.getArgAsLeaf(i); 1294 // Handle the optional string value. 1295 if (leaf.isStringAttr()) { 1296 if (!strAttr.empty()) 1297 llvm::PrintFatalError("Only one string attribute may be specified"); 1298 strAttr = leaf.getStringAttr(); 1299 continue; 1300 } 1301 os << (first ? "" : ", ") << lookUpArgLoc(i); 1302 first = false; 1303 } 1304 os << "}"; 1305 if (!strAttr.empty()) { 1306 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 1307 } 1308 os << ")"; 1309 return os.str(); 1310 } 1311 1312 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, 1313 int depth) { 1314 // Nested NativeCodeCall. 1315 if (auto dagNode = returnType.getArgAsNestedDag(i)) { 1316 if (!dagNode.isNativeCodeCall()) 1317 PrintFatalError(loc, "nested DAG in `returnType` must be a native code " 1318 "call"); 1319 return handleReplaceWithNativeCodeCall(dagNode, depth); 1320 } 1321 // String literal. 1322 auto dagLeaf = returnType.getArgAsLeaf(i); 1323 if (dagLeaf.isStringAttr()) 1324 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx); 1325 return tgfmt( 1326 "$0.getType()", &fmtCtx, 1327 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i))); 1328 } 1329 1330 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 1331 StringRef patArgName) { 1332 if (leaf.isStringAttr()) 1333 PrintFatalError(loc, "raw string not supported as argument"); 1334 if (leaf.isConstantAttr()) { 1335 auto constAttr = leaf.getAsConstantAttr(); 1336 return handleConstantAttr(constAttr.getAttribute(), 1337 constAttr.getConstantValue()); 1338 } 1339 if (leaf.isEnumAttrCase()) { 1340 auto enumCase = leaf.getAsEnumAttrCase(); 1341 // This is an enum case backed by an IntegerAttr. We need to get its value 1342 // to build the constant. 1343 std::string val = std::to_string(enumCase.getValue()); 1344 return handleConstantAttr(enumCase, val); 1345 } 1346 1347 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 1348 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 1349 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 1350 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 1351 << "' (via symbol ref)\n"); 1352 return argName; 1353 } 1354 if (leaf.isNativeCodeCall()) { 1355 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 1356 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 1357 << "' (via NativeCodeCall)\n"); 1358 return std::string(repl); 1359 } 1360 PrintFatalError(loc, "unhandled case when rewriting op"); 1361 } 1362 1363 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 1364 int depth) { 1365 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 1366 LLVM_DEBUG(tree.print(llvm::dbgs())); 1367 LLVM_DEBUG(llvm::dbgs() << '\n'); 1368 1369 auto fmt = tree.getNativeCodeTemplate(); 1370 1371 SmallVector<std::string, 16> attrs; 1372 1373 auto tail = getTrailingDirectives(tree); 1374 if (tail.returnType) 1375 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 1376 auto locToUse = getLocation(tail); 1377 1378 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 1379 if (tree.isNestedDagArg(i)) { 1380 attrs.push_back( 1381 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); 1382 } else { 1383 attrs.push_back( 1384 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))); 1385 } 1386 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 1387 << " replacement: " << attrs[i] << "\n"); 1388 } 1389 1390 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), 1391 static_cast<ArrayRef<std::string>>(attrs)); 1392 1393 // In general, NativeCodeCall without naming binding don't need this. To 1394 // ensure void helper function has been correctly labeled, i.e., use 1395 // NativeCodeCallVoid, we cache the result to a local variable so that we will 1396 // get a compilation error in the auto-generated file. 1397 // Example. 1398 // // In the td file 1399 // Pat<(...), (NativeCodeCall<Foo> ...)> 1400 // 1401 // --- 1402 // 1403 // // In the auto-generated .cpp 1404 // ... 1405 // // Causes compilation error if Foo() returns void. 1406 // auto nativeVar = Foo(); 1407 // ... 1408 if (tree.getNumReturnsOfNativeCode() != 0) { 1409 // Determine the local variable name for return value. 1410 std::string varName = 1411 SymbolInfoMap::getValuePackName(tree.getSymbol()).str(); 1412 if (varName.empty()) { 1413 varName = formatv("nativeVar_{0}", nextValueId++); 1414 // Register the local variable for later uses. 1415 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode()); 1416 } 1417 1418 // Catch the return value of helper function. 1419 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol); 1420 1421 if (!tree.getSymbol().empty()) 1422 symbol = tree.getSymbol().str(); 1423 else 1424 symbol = varName; 1425 } 1426 1427 return symbol; 1428 } 1429 1430 int PatternEmitter::getNodeValueCount(DagNode node) { 1431 if (node.isOperation()) { 1432 // If the op is bound to a symbol in the rewrite rule, query its result 1433 // count from the symbol info map. 1434 auto symbol = node.getSymbol(); 1435 if (!symbol.empty()) { 1436 return symbolInfoMap.getStaticValueCount(symbol); 1437 } 1438 // Otherwise this is an unbound op; we will use all its results. 1439 return pattern.getDialectOp(node).getNumResults(); 1440 } 1441 1442 if (node.isNativeCodeCall()) 1443 return node.getNumReturnsOfNativeCode(); 1444 1445 return 1; 1446 } 1447 1448 PatternEmitter::TrailingDirectives 1449 PatternEmitter::getTrailingDirectives(DagNode tree) { 1450 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0}; 1451 1452 // Look backwards through the arguments. 1453 auto numPatArgs = tree.getNumArgs(); 1454 for (int i = numPatArgs - 1; i >= 0; --i) { 1455 auto dagArg = tree.getArgAsNestedDag(i); 1456 // A leaf is not a directive. Stop looking. 1457 if (!dagArg) 1458 break; 1459 1460 auto isLocation = dagArg.isLocationDirective(); 1461 auto isReturnType = dagArg.isReturnTypeDirective(); 1462 // If encountered a DAG node that isn't a trailing directive, stop looking. 1463 if (!(isLocation || isReturnType)) 1464 break; 1465 // Save the directive, but error if one of the same type was already 1466 // found. 1467 ++tail.numDirectives; 1468 if (isLocation) { 1469 if (tail.location) 1470 PrintFatalError(loc, "`location` directive can only be specified " 1471 "once"); 1472 tail.location = dagArg; 1473 } else if (isReturnType) { 1474 if (tail.returnType) 1475 PrintFatalError(loc, "`returnType` directive can only be specified " 1476 "once"); 1477 tail.returnType = dagArg; 1478 } 1479 } 1480 1481 return tail; 1482 } 1483 1484 std::string 1485 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { 1486 if (tail.location) 1487 return handleLocationDirective(tail.location); 1488 1489 // If no explicit location is given, use the default, all fused, location. 1490 return "odsLoc"; 1491 } 1492 1493 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 1494 int depth) { 1495 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 1496 LLVM_DEBUG(tree.print(llvm::dbgs())); 1497 LLVM_DEBUG(llvm::dbgs() << '\n'); 1498 1499 Operator &resultOp = tree.getDialectOp(opMap); 1500 auto numOpArgs = resultOp.getNumArgs(); 1501 auto numPatArgs = tree.getNumArgs(); 1502 1503 auto tail = getTrailingDirectives(tree); 1504 auto locToUse = getLocation(tail); 1505 1506 auto inPattern = numPatArgs - tail.numDirectives; 1507 if (numOpArgs != inPattern) { 1508 PrintFatalError(loc, 1509 formatv("resultant op '{0}' argument number mismatch: " 1510 "{1} in pattern vs. {2} in definition", 1511 resultOp.getOperationName(), inPattern, numOpArgs)); 1512 } 1513 1514 // A map to collect all nested DAG child nodes' names, with operand index as 1515 // the key. This includes both bound and unbound child nodes. 1516 ChildNodeIndexNameMap childNodeNames; 1517 1518 // First go through all the child nodes who are nested DAG constructs to 1519 // create ops for them and remember the symbol names for them, so that we can 1520 // use the results in the current node. This happens in a recursive manner. 1521 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 1522 if (auto child = tree.getArgAsNestedDag(i)) 1523 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1524 } 1525 1526 // The name of the local variable holding this op. 1527 std::string valuePackName; 1528 // The symbol for holding the result of this pattern. Note that the result of 1529 // this pattern is not necessarily the same as the variable created by this 1530 // pattern because we can use `__N` suffix to refer only a specific result if 1531 // the generated op is a multi-result op. 1532 std::string resultValue; 1533 if (tree.getSymbol().empty()) { 1534 // No symbol is explicitly bound to this op in the pattern. Generate a 1535 // unique name. 1536 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1537 } else { 1538 resultValue = std::string(tree.getSymbol()); 1539 // Strip the index to get the name for the value pack and use it to name the 1540 // local variable for the op. 1541 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1542 } 1543 1544 // Create the local variable for this op. 1545 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1546 valuePackName); 1547 1548 // Right now ODS don't have general type inference support. Except a few 1549 // special cases listed below, DRR needs to supply types for all results 1550 // when building an op. 1551 bool isSameOperandsAndResultType = 1552 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1553 bool useFirstAttr = 1554 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1555 1556 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) { 1557 // We know how to deduce the result type for ops with these traits and we've 1558 // generated builders taking aggregate parameters. Use those builders to 1559 // create the ops. 1560 1561 // First prepare local variables for op arguments used in builder call. 1562 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1563 1564 // Then create the op. 1565 os.scope("", "\n}\n").os << formatv( 1566 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1567 valuePackName, resultOp.getQualCppClassName(), locToUse); 1568 return resultValue; 1569 } 1570 1571 bool usePartialResults = valuePackName != resultValue; 1572 1573 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) { 1574 // For these cases (broadcastable ops, op results used both as auxiliary 1575 // values and replacement values, ops in nested patterns, auxiliary ops), we 1576 // still need to supply the result types when building the op. But because 1577 // we don't generate a builder automatically with ODS for them, it's the 1578 // developer's responsibility to make sure such a builder (with result type 1579 // deduction ability) exists. We go through the separate-parameter builder 1580 // here given that it's easier for developers to write compared to 1581 // aggregate-parameter builders. 1582 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1583 1584 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1585 resultOp.getQualCppClassName(), locToUse); 1586 supplyValuesForOpArgs(tree, childNodeNames, depth); 1587 os << "\n );\n}\n"; 1588 return resultValue; 1589 } 1590 1591 // If we are provided explicit return types, use them to build the op. 1592 // However, if depth == 0 and resultIndex >= 0, it means we are replacing 1593 // the values generated from the source pattern root op. Then we must use the 1594 // source pattern's value types to determine the value type of the generated 1595 // op here. 1596 if (depth == 0 && resultIndex >= 0 && tail.returnType) 1597 PrintFatalError(loc, "Cannot specify explicit return types in an op whose " 1598 "return values replace the source pattern's root op"); 1599 1600 // First prepare local variables for op arguments used in builder call. 1601 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1602 1603 // Then prepare the result types. We need to specify the types for all 1604 // results. 1605 os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; " 1606 "(void)tblgen_types;\n"); 1607 int numResults = resultOp.getNumResults(); 1608 if (tail.returnType) { 1609 auto numRetTys = tail.returnType.getNumArgs(); 1610 for (int i = 0; i < numRetTys; ++i) { 1611 auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1); 1612 os << "tblgen_types.push_back(" << varName << ");\n"; 1613 } 1614 } else { 1615 if (numResults != 0) { 1616 // Copy the result types from the source pattern. 1617 for (int i = 0; i < numResults; ++i) 1618 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1619 " tblgen_types.push_back(v.getType());\n}\n", 1620 resultIndex + i); 1621 } 1622 } 1623 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1624 "tblgen_values, tblgen_attrs);\n", 1625 valuePackName, resultOp.getQualCppClassName(), locToUse); 1626 os.unindent() << "}\n"; 1627 return resultValue; 1628 } 1629 1630 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1631 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1632 Operator &resultOp = node.getDialectOp(opMap); 1633 1634 // Now prepare operands used for building this op: 1635 // * If the operand is non-variadic, we create a `Value` local variable. 1636 // * If the operand is variadic, we create a `SmallVector<Value>` local 1637 // variable. 1638 1639 int valueIndex = 0; // An index for uniquing local variable names. 1640 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1641 const auto *operand = 1642 llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex)); 1643 // We do not need special handling for attributes. 1644 if (!operand) 1645 continue; 1646 1647 raw_indented_ostream::DelimitedScope scope(os); 1648 std::string varName; 1649 if (operand->isVariadic()) { 1650 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1651 os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName); 1652 std::string range; 1653 if (node.isNestedDagArg(argIndex)) { 1654 range = childNodeNames[argIndex]; 1655 } else { 1656 range = std::string(node.getArgName(argIndex)); 1657 } 1658 // Resolve the symbol for all range use so that we have a uniform way of 1659 // capturing the values. 1660 range = symbolInfoMap.getValueAndRangeUse(range); 1661 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1662 varName); 1663 } else { 1664 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1665 os << formatv("::mlir::Value {0} = ", varName); 1666 if (node.isNestedDagArg(argIndex)) { 1667 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1668 } else { 1669 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1670 auto symbol = 1671 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1672 if (leaf.isNativeCodeCall()) { 1673 os << std::string( 1674 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1675 } else { 1676 os << symbol; 1677 } 1678 } 1679 os << ";\n"; 1680 } 1681 1682 // Update to use the newly created local variable for building the op later. 1683 childNodeNames[argIndex] = varName; 1684 } 1685 } 1686 1687 void PatternEmitter::supplyValuesForOpArgs( 1688 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1689 Operator &resultOp = node.getDialectOp(opMap); 1690 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1691 argIndex != numOpArgs; ++argIndex) { 1692 // Start each argument on its own line. 1693 os << ",\n "; 1694 1695 Argument opArg = resultOp.getArg(argIndex); 1696 // Handle the case of operand first. 1697 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 1698 if (!operand->name.empty()) 1699 os << "/*" << operand->name << "=*/"; 1700 os << childNodeNames.lookup(argIndex); 1701 continue; 1702 } 1703 1704 // The argument in the op definition. 1705 auto opArgName = resultOp.getArgName(argIndex); 1706 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1707 if (!subTree.isNativeCodeCall()) 1708 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1709 "for creating attribute"); 1710 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex)); 1711 } else { 1712 auto leaf = node.getArgAsLeaf(argIndex); 1713 // The argument in the result DAG pattern. 1714 auto patArgName = node.getArgName(argIndex); 1715 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1716 // TODO: Refactor out into map to avoid recomputing these. 1717 if (!opArg.is<NamedAttribute *>()) 1718 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1719 if (!patArgName.empty()) 1720 os << "/*" << patArgName << "=*/"; 1721 } else { 1722 os << "/*" << opArgName << "=*/"; 1723 } 1724 os << handleOpArgument(leaf, patArgName); 1725 } 1726 } 1727 } 1728 1729 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1730 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1731 Operator &resultOp = node.getDialectOp(opMap); 1732 1733 auto scope = os.scope(); 1734 os << formatv("::llvm::SmallVector<::mlir::Value, 4> " 1735 "tblgen_values; (void)tblgen_values;\n"); 1736 os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " 1737 "tblgen_attrs; (void)tblgen_attrs;\n"); 1738 1739 const char *addAttrCmd = 1740 "if (auto tmpAttr = {1}) {\n" 1741 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " 1742 "tmpAttr);\n}\n"; 1743 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1744 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1745 // The argument in the op definition. 1746 auto opArgName = resultOp.getArgName(argIndex); 1747 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1748 if (!subTree.isNativeCodeCall()) 1749 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1750 "for creating attribute"); 1751 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); 1752 } else { 1753 auto leaf = node.getArgAsLeaf(argIndex); 1754 // The argument in the result DAG pattern. 1755 auto patArgName = node.getArgName(argIndex); 1756 os << formatv(addAttrCmd, opArgName, 1757 handleOpArgument(leaf, patArgName)); 1758 } 1759 continue; 1760 } 1761 1762 const auto *operand = 1763 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1764 std::string varName; 1765 if (operand->isVariadic()) { 1766 std::string range; 1767 if (node.isNestedDagArg(argIndex)) { 1768 range = childNodeNames.lookup(argIndex); 1769 } else { 1770 range = std::string(node.getArgName(argIndex)); 1771 } 1772 // Resolve the symbol for all range use so that we have a uniform way of 1773 // capturing the values. 1774 range = symbolInfoMap.getValueAndRangeUse(range); 1775 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1776 range); 1777 } else { 1778 os << formatv("tblgen_values.push_back("); 1779 if (node.isNestedDagArg(argIndex)) { 1780 os << symbolInfoMap.getValueAndRangeUse( 1781 childNodeNames.lookup(argIndex)); 1782 } else { 1783 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1784 if (leaf.isConstantAttr()) 1785 // TODO: Use better location 1786 PrintFatalError( 1787 loc, 1788 "attribute found where value was expected, if attempting to use " 1789 "constant value, construct a constant op with given attribute " 1790 "instead"); 1791 1792 auto symbol = 1793 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1794 if (leaf.isNativeCodeCall()) { 1795 os << std::string( 1796 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1797 } else { 1798 os << symbol; 1799 } 1800 } 1801 os << ");\n"; 1802 } 1803 } 1804 } 1805 1806 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, 1807 const RecordKeeper &recordKeeper, 1808 RecordOperatorMap &mapper) 1809 : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {} 1810 1811 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { 1812 // PatternEmitter will use the static matcher if there's one generated. To 1813 // ensure that all the dependent static matchers are generated before emitting 1814 // the matching logic of the DagNode, we use topological order to achieve it. 1815 for (auto &dagInfo : topologicalOrder) { 1816 DagNode node = dagInfo.first; 1817 if (!useStaticMatcher(node)) 1818 continue; 1819 1820 std::string funcName = 1821 formatv("static_dag_matcher_{0}", staticMatcherCounter++); 1822 assert(!matcherNames.contains(node)); 1823 PatternEmitter(dagInfo.second, &opMap, os, *this) 1824 .emitStaticMatcher(node, funcName); 1825 matcherNames[node] = funcName; 1826 } 1827 } 1828 1829 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { 1830 staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef()); 1831 } 1832 1833 void StaticMatcherHelper::addPattern(Record *record) { 1834 Pattern pat(record, &opMap); 1835 1836 // While generating the function body of the DAG matcher, it may depends on 1837 // other DAG matchers. To ensure the dependent matchers are ready, we compute 1838 // the topological order for all the DAGs and emit the DAG matchers in this 1839 // order. 1840 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { 1841 ++refStats[node]; 1842 1843 if (refStats[node] != 1) 1844 return; 1845 1846 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) 1847 if (DagNode sibling = node.getArgAsNestedDag(i)) 1848 dfs(sibling); 1849 else { 1850 DagLeaf leaf = node.getArgAsLeaf(i); 1851 if (!leaf.isUnspecified()) 1852 constraints.insert(leaf); 1853 } 1854 1855 topologicalOrder.push_back(std::make_pair(node, record)); 1856 }; 1857 1858 dfs(pat.getSourcePattern()); 1859 } 1860 1861 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { 1862 if (leaf.isAttrMatcher()) { 1863 std::optional<StringRef> constraint = 1864 staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint()); 1865 assert(constraint && "attribute constraint was not uniqued"); 1866 return *constraint; 1867 } 1868 assert(leaf.isOperandMatcher()); 1869 return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); 1870 } 1871 1872 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1873 emitSourceFileHeader("Rewriters", os); 1874 1875 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1876 1877 // We put the map here because it can be shared among multiple patterns. 1878 RecordOperatorMap recordOpMap; 1879 1880 // Exam all the patterns and generate static matcher for the duplicated 1881 // DagNode. 1882 StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap); 1883 for (Record *p : patterns) 1884 staticMatcher.addPattern(p); 1885 staticMatcher.populateStaticConstraintFunctions(os); 1886 staticMatcher.populateStaticMatchers(os); 1887 1888 std::vector<std::string> rewriterNames; 1889 rewriterNames.reserve(patterns.size()); 1890 1891 std::string baseRewriterName = "GeneratedConvert"; 1892 int rewriterIndex = 0; 1893 1894 for (Record *p : patterns) { 1895 std::string name; 1896 if (p->isAnonymous()) { 1897 // If no name is provided, ensure unique rewriter names simply by 1898 // appending unique suffix. 1899 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1900 } else { 1901 name = std::string(p->getName()); 1902 } 1903 LLVM_DEBUG(llvm::dbgs() 1904 << "=== start generating pattern '" << name << "' ===\n"); 1905 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name); 1906 LLVM_DEBUG(llvm::dbgs() 1907 << "=== done generating pattern '" << name << "' ===\n"); 1908 rewriterNames.push_back(std::move(name)); 1909 } 1910 1911 // Emit function to add the generated matchers to the pattern list. 1912 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" 1913 "::mlir::RewritePatternSet &patterns) {\n"; 1914 for (const auto &name : rewriterNames) { 1915 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 1916 } 1917 os << "}\n"; 1918 } 1919 1920 static mlir::GenRegistration 1921 genRewriters("gen-rewriters", "Generate pattern rewriters", 1922 [](const RecordKeeper &records, raw_ostream &os) { 1923 emitRewriters(records, os); 1924 return false; 1925 }); 1926