1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Argument.h" 15 #include "mlir/TableGen/Attribute.h" 16 #include "mlir/TableGen/CodeGenHelpers.h" 17 #include "mlir/TableGen/Format.h" 18 #include "mlir/TableGen/GenInfo.h" 19 #include "mlir/TableGen/Operator.h" 20 #include "mlir/TableGen/Pattern.h" 21 #include "mlir/TableGen/Predicate.h" 22 #include "mlir/TableGen/Property.h" 23 #include "mlir/TableGen/Type.h" 24 #include "llvm/ADT/FunctionExtras.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringSet.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/FormatAdapters.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace mlir; 39 using namespace mlir::tblgen; 40 41 using llvm::formatv; 42 using llvm::Record; 43 using llvm::RecordKeeper; 44 45 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 46 47 namespace llvm { 48 template <> 49 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 51 raw_ostream &os, StringRef style) { 52 os << v.first << ":" << v.second; 53 } 54 }; 55 } // namespace llvm 56 57 
//===----------------------------------------------------------------------===// 58 // PatternEmitter 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 63 class StaticMatcherHelper; 64 65 class PatternEmitter { 66 public: 67 PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os, 68 StaticMatcherHelper &helper); 69 70 // Emits the mlir::RewritePattern struct named `rewriteName`. 71 void emit(StringRef rewriteName); 72 73 // Emits the static function of DAG matcher. 74 void emitStaticMatcher(DagNode tree, std::string funcName); 75 76 private: 77 // Emits the code for matching ops. 78 void emitMatchLogic(DagNode tree, StringRef opName); 79 80 // Emits the code for rewriting ops. 81 void emitRewriteLogic(); 82 83 //===--------------------------------------------------------------------===// 84 // Match utilities 85 //===--------------------------------------------------------------------===// 86 87 // Emits C++ statements for matching the DAG structure. 88 void emitMatch(DagNode tree, StringRef name, int depth); 89 90 // Emit C++ function call to static DAG matcher. 91 void emitStaticMatchCall(DagNode tree, StringRef name); 92 93 // Emit C++ function call to static type/attribute constraint function. 94 void emitStaticVerifierCall(StringRef funcName, StringRef opName, 95 StringRef arg, StringRef failureStr); 96 97 // Emits C++ statements for matching using a native code call. 98 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 99 100 // Emits C++ statements for matching the op constrained by the given DAG 101 // `tree` returning the op's variable name. 102 void emitOpMatch(DagNode tree, StringRef opName, int depth); 103 104 // Emits C++ statements for matching the `argIndex`-th argument of the given 105 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the 106 // bound name and the constraint of the operand respectively. 
107 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, 108 int operandIndex, DagLeaf operandMatcher, 109 StringRef argName, int argIndex, 110 std::optional<int> variadicSubIndex); 111 112 // Emits C++ statements for matching the operands which can be matched in 113 // either order. 114 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 115 StringRef opName, int argIndex, int &operandIndex, 116 int depth); 117 118 // Emits C++ statements for matching a variadic operand. 119 void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, 120 StringRef opName, int argIndex, 121 int &operandIndex, int depth); 122 123 // Emits C++ statements for matching the `argIndex`-th argument of the given 124 // DAG `tree` as an attribute. 125 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 126 int depth); 127 128 // Emits C++ for checking a match with a corresponding match failure 129 // diagnostic. 130 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 131 const llvm::formatv_object_base &failureFmt); 132 133 // Emits C++ for checking a match with a corresponding match failure 134 // diagnostics. 135 void emitMatchCheck(StringRef opName, const std::string &matchStr, 136 const std::string &failureStr); 137 138 //===--------------------------------------------------------------------===// 139 // Rewrite utilities 140 //===--------------------------------------------------------------------===// 141 142 // The entry point for handling a result pattern rooted at `resultTree`. This 143 // method dispatches to concrete handlers according to `resultTree`'s kind and 144 // returns a symbol representing the whole value pack. Callers are expected to 145 // further resolve the symbol according to the specific use case. 146 // 147 // `depth` is the nesting level of `resultTree`; 0 means top-level result 148 // pattern. 
For top-level result pattern, `resultIndex` indicates which result 149 // of the matched root op this pattern is intended to replace, which can be 150 // used to deduce the result type of the op generated from this result 151 // pattern. 152 std::string handleResultPattern(DagNode resultTree, int resultIndex, 153 int depth); 154 155 // Emits the C++ statement to replace the matched DAG with a value built via 156 // calling native C++ code. 157 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 158 159 // Returns the symbol of the old value serving as the replacement. 160 StringRef handleReplaceWithValue(DagNode tree); 161 162 // Emits the C++ statement to replace the matched DAG with an array of 163 // matched values. 164 std::string handleVariadic(DagNode tree, int depth); 165 166 // Trailing directives are used at the end of DAG node argument lists to 167 // specify additional behaviour for op matchers and creators, etc. 168 struct TrailingDirectives { 169 // DAG node containing the `location` directive. Null if there is none. 170 DagNode location; 171 172 // DAG node containing the `returnType` directive. Null if there is none. 173 DagNode returnType; 174 175 // Number of found trailing directives. 176 int numDirectives; 177 }; 178 179 // Collect any trailing directives. 180 TrailingDirectives getTrailingDirectives(DagNode tree); 181 182 // Returns the location value to use. 183 std::string getLocation(TrailingDirectives &tail); 184 185 // Returns the location value to use. 186 std::string handleLocationDirective(DagNode tree); 187 188 // Emit return type argument. 189 std::string handleReturnTypeArg(DagNode returnType, int i, int depth); 190 191 // Emits the C++ statement to build a new op out of the given DAG `tree` and 192 // returns the variable name that this op is assigned to. If the root op in 193 // DAG `tree` has a specified name, the created op will be assigned to a 194 // variable of the given name. 
Otherwise, a unique name will be used as the 195 // result value name. 196 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 197 198 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 199 200 // Emits a local variable for each value and attribute to be used for creating 201 // an op. 202 void createSeparateLocalVarsForOpArgs(DagNode node, 203 ChildNodeIndexNameMap &childNodeNames); 204 205 // Emits the concrete arguments used to call an op's builder. 206 void supplyValuesForOpArgs(DagNode node, 207 const ChildNodeIndexNameMap &childNodeNames, 208 int depth); 209 210 // Emits the local variables for holding all values as a whole and all named 211 // attributes as a whole to be used for creating an op. 212 void createAggregateLocalVarsForOpArgs( 213 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 214 215 // Returns the C++ expression to construct a constant attribute of the given 216 // `value` for the given attribute kind `attr`. 217 std::string handleConstantAttr(Attribute attr, const Twine &value); 218 219 // Returns the C++ expression to build an argument from the given DAG `leaf`. 220 // `patArgName` is used to bound the argument to the source pattern. 221 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 222 223 //===--------------------------------------------------------------------===// 224 // General utilities 225 //===--------------------------------------------------------------------===// 226 227 // Collects all of the operations within the given dag tree. 228 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 229 230 // Returns a unique symbol for a local variable of the given `op`. 
231 std::string getUniqueSymbol(const Operator *op); 232 233 //===--------------------------------------------------------------------===// 234 // Symbol utilities 235 //===--------------------------------------------------------------------===// 236 237 // Returns how many static values the given DAG `node` correspond to. 238 int getNodeValueCount(DagNode node); 239 240 private: 241 // Pattern instantiation location followed by the location of multiclass 242 // prototypes used. This is intended to be used as a whole to 243 // PrintFatalError() on errors. 244 ArrayRef<SMLoc> loc; 245 246 // Op's TableGen Record to wrapper object. 247 RecordOperatorMap *opMap; 248 249 // Handy wrapper for pattern being emitted. 250 Pattern pattern; 251 252 // Map for all bound symbols' info. 253 SymbolInfoMap symbolInfoMap; 254 255 StaticMatcherHelper &staticMatcherHelper; 256 257 // The next unused ID for newly created values. 258 unsigned nextValueId = 0; 259 260 raw_indented_ostream os; 261 262 // Format contexts containing placeholder substitutions. 263 FmtContext fmtCtx; 264 }; 265 266 // Tracks DagNode's reference multiple times across patterns. Enables generating 267 // static matcher functions for DagNode's referenced multiple times rather than 268 // inlining them. 269 class StaticMatcherHelper { 270 public: 271 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records, 272 RecordOperatorMap &mapper); 273 274 // Determine if we should inline the match logic or delegate to a static 275 // function. 276 bool useStaticMatcher(DagNode node) { 277 // either/variadic node must be associated to the parentOp, thus we can't 278 // emit a static matcher rooted at them. 279 if (node.isEither() || node.isVariadic()) 280 return false; 281 282 return refStats[node] > kStaticMatcherThreshold; 283 } 284 285 // Get the name of the static DAG matcher function corresponding to the node. 
286 std::string getMatcherName(DagNode node) { 287 assert(useStaticMatcher(node)); 288 return matcherNames[node]; 289 } 290 291 // Get the name of static type/attribute verification function. 292 StringRef getVerifierName(DagLeaf leaf); 293 294 // Collect the `Record`s, i.e., the DRR, so that we can get the information of 295 // the duplicated DAGs. 296 void addPattern(const Record *record); 297 298 // Emit all static functions of DAG Matcher. 299 void populateStaticMatchers(raw_ostream &os); 300 301 // Emit all static functions for Constraints. 302 void populateStaticConstraintFunctions(raw_ostream &os); 303 304 private: 305 static constexpr unsigned kStaticMatcherThreshold = 1; 306 307 // Consider two patterns as down below, 308 // DagNode_Root_A DagNode_Root_B 309 // \ \ 310 // DagNode_C DagNode_C 311 // \ \ 312 // DagNode_D DagNode_D 313 // 314 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of 315 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced 316 // multiple times so we'll have static matchers for both of them. When we're 317 // emitting the match logic for DagNode_C, we will check if DagNode_D has the 318 // static matcher generated. If so, then we'll generate a call to the 319 // function, inline otherwise. In this case, inlining is not what we want. As 320 // a result, generate the static matcher in topological order to ensure all 321 // the dependent static matchers are generated and we can avoid accidentally 322 // inlining. 323 // 324 // The topological order of all the DagNodes among all patterns. 325 SmallVector<std::pair<DagNode, const Record *>> topologicalOrder; 326 327 RecordOperatorMap &opMap; 328 329 // Records of the static function name of each DagNode 330 DenseMap<DagNode, std::string> matcherNames; 331 332 // After collecting all the DagNode in each pattern, `refStats` records the 333 // number of users for each DagNode. 
We will generate the static matcher for a 334 // DagNode while the number of users exceeds a certain threshold. 335 DenseMap<DagNode, unsigned> refStats; 336 337 // Number of static matcher generated. This is used to generate a unique name 338 // for each DagNode. 339 int staticMatcherCounter = 0; 340 341 // The DagLeaf which contains type or attr constraint. 342 SetVector<DagLeaf> constraints; 343 344 // Static type/attribute verification function emitter. 345 StaticVerifierFunctionEmitter staticVerifierEmitter; 346 }; 347 348 } // namespace 349 350 PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper, 351 raw_ostream &os, StaticMatcherHelper &helper) 352 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 353 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { 354 fmtCtx.withBuilder("rewriter"); 355 } 356 357 std::string PatternEmitter::handleConstantAttr(Attribute attr, 358 const Twine &value) { 359 if (!attr.isConstBuildable()) 360 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 361 " does not have the 'constBuilderCall' field"); 362 363 // TODO: Verify the constants here 364 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 365 } 366 367 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { 368 os << formatv( 369 "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " 370 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " 371 "*, 4> &tblgen_ops", 372 funcName); 373 374 // We pass the reference of the variables that need to be captured. Hence we 375 // need to collect all the symbols in the tree first. 
376 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true); 377 symbolInfoMap.assignUniqueAlternativeNames(); 378 for (const auto &info : symbolInfoMap) 379 os << formatv(", {0}", info.second.getArgDecl(info.first)); 380 381 os << ") {\n"; 382 os.indent(); 383 os << "(void)tblgen_ops;\n"; 384 385 // Note that a static matcher is considered at least one step from the match 386 // entry. 387 emitMatch(tree, "op0", /*depth=*/1); 388 389 os << "return ::mlir::success();\n"; 390 os.unindent(); 391 os << "}\n\n"; 392 } 393 394 // Helper function to match patterns. 395 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 396 if (tree.isNativeCodeCall()) { 397 emitNativeCodeMatch(tree, name, depth); 398 return; 399 } 400 401 if (tree.isOperation()) { 402 emitOpMatch(tree, name, depth); 403 return; 404 } 405 406 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 407 } 408 409 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { 410 std::string funcName = staticMatcherHelper.getMatcherName(tree); 411 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName, 412 opName); 413 414 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in 415 // one pass. 416 417 // In general, bound symbol should have the unique name in the pattern but 418 // for the operand, binding same symbol to multiple operands imply a 419 // constraint at the same time. In this case, we will rename those operands 420 // with different names. As a result, we need to collect all the symbolInfos 421 // from the DagNode then get the updated name of the local variables from the 422 // global symbolInfoMap. 
423 424 // Collect all the bound symbols in the Dag 425 SymbolInfoMap localSymbolMap(loc); 426 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true); 427 428 for (const auto &info : localSymbolMap) { 429 auto name = info.first; 430 auto symboInfo = info.second; 431 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo); 432 os << formatv(", {0}", ret->second.getVarName(name)); 433 } 434 435 os << "))) {\n"; 436 os.scope().os << "return ::mlir::failure();\n"; 437 os << "}\n"; 438 } 439 440 void PatternEmitter::emitStaticVerifierCall(StringRef funcName, 441 StringRef opName, StringRef arg, 442 StringRef failureStr) { 443 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n", 444 funcName, opName, arg, failureStr); 445 os.scope().os << "return ::mlir::failure();\n"; 446 os << "}\n"; 447 } 448 449 // Helper function to match patterns. 450 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 451 int depth) { 452 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 453 LLVM_DEBUG(tree.print(llvm::dbgs())); 454 LLVM_DEBUG(llvm::dbgs() << '\n'); 455 456 // The order of generating static matcher follows the topological order so 457 // that for every dependent DagNode already have their static matcher 458 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 459 // is when we are generating the static matcher for a DagNode itself. In this 460 // case, we need to emit the function body rather than a function call. 461 if (staticMatcherHelper.useStaticMatcher(tree) && 462 !staticMatcherHelper.getMatcherName(tree).empty()) { 463 emitStaticMatchCall(tree, opName); 464 465 // NativeCodeCall will never be at depth 0 so that we don't need to catch 466 // the root operation as emitOpMatch(); 467 468 return; 469 } 470 471 // TODO(suderman): iterate through arguments, determine their types, output 472 // names. 
473 SmallVector<std::string, 8> capture; 474 475 raw_indented_ostream::DelimitedScope scope(os); 476 477 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 478 std::string argName = formatv("arg{0}_{1}", depth, i); 479 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 480 if (argTree.isEither()) 481 PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); 482 if (argTree.isVariadic()) 483 PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands"); 484 485 os << "::mlir::Value " << argName << ";\n"; 486 } else { 487 auto leaf = tree.getArgAsLeaf(i); 488 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 489 os << "::mlir::Attribute " << argName << ";\n"; 490 } else { 491 os << "::mlir::Value " << argName << ";\n"; 492 } 493 } 494 495 capture.push_back(std::move(argName)); 496 } 497 498 auto tail = getTrailingDirectives(tree); 499 if (tail.returnType) 500 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 501 auto locToUse = getLocation(tail); 502 503 auto fmt = tree.getNativeCodeTemplate(); 504 if (fmt.count("$_self") != 1) 505 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 506 "passing the defining Operation"); 507 508 auto nativeCodeCall = std::string( 509 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), 510 static_cast<ArrayRef<std::string>>(capture))); 511 512 emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall), 513 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall)); 514 515 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 516 auto name = tree.getArgName(i); 517 if (!name.empty() && name != "_") { 518 os << formatv("{0} = {1};\n", name, capture[i]); 519 } 520 } 521 522 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 523 std::string argName = capture[i]; 524 525 // Handle nested DAG construct first 526 if (tree.getArgAsNestedDag(i)) { 527 PrintFatalError( 528 loc, formatv("Matching 
nested tree in NativeCodecall not support for " 529 "{0} as arg {1}", 530 argName, i)); 531 } 532 533 DagLeaf leaf = tree.getArgAsLeaf(i); 534 535 // The parameter for native function doesn't bind any constraints. 536 if (leaf.isUnspecified()) 537 continue; 538 539 auto constraint = leaf.getAsConstraint(); 540 541 std::string self; 542 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 543 self = argName; 544 else 545 self = formatv("{0}.getType()", argName); 546 StringRef verifier = staticMatcherHelper.getVerifierName(leaf); 547 emitStaticVerifierCall( 548 verifier, opName, self, 549 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 550 "constraint: " 551 "'{2}'\"", 552 i, tree.getNativeCodeTemplate(), 553 escapeString(constraint.getSummary())) 554 .str()); 555 } 556 557 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 558 } 559 560 // Helper function to match patterns. 561 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 562 Operator &op = tree.getDialectOp(opMap); 563 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 564 << op.getOperationName() << "' at depth " << depth 565 << '\n'); 566 567 auto getCastedName = [depth]() -> std::string { 568 return formatv("castedOp{0}", depth); 569 }; 570 571 // The order of generating static matcher follows the topological order so 572 // that for every dependent DagNode already have their static matcher 573 // generated if needed. The reason we check if `getMatcherName(tree).empty()` 574 // is when we are generating the static matcher for a DagNode itself. In this 575 // case, we need to emit the function body rather than a function call. 576 if (staticMatcherHelper.useStaticMatcher(tree) && 577 !staticMatcherHelper.getMatcherName(tree).empty()) { 578 emitStaticMatchCall(tree, opName); 579 // In the codegen of rewriter, we suppose that castedOp0 will capture the 580 // root operation. Manually add it if the root DagNode is a static matcher. 
581 if (depth == 0) 582 os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " 583 "(void){2};\n", 584 opName, op.getQualCppClassName(), getCastedName()); 585 return; 586 } 587 588 std::string castedName = getCastedName(); 589 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " 590 "(void){0};\n", 591 castedName, opName, op.getQualCppClassName()); 592 593 // Skip the operand matching at depth 0 as the pattern rewriter already does. 594 if (depth != 0) 595 emitMatchCheck(opName, /*matchStr=*/castedName, 596 formatv("\"{0} is not {1} type\"", castedName, 597 op.getQualCppClassName())); 598 599 // If the operand's name is set, set to that variable. 600 auto name = tree.getSymbol(); 601 if (!name.empty()) 602 os << formatv("{0} = {1};\n", name, castedName); 603 604 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; 605 ++i, ++opArgIdx) { 606 auto opArg = op.getArg(opArgIdx); 607 std::string argName = formatv("op{0}", depth + 1); 608 609 // Handle nested DAG construct first 610 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 611 if (argTree.isEither()) { 612 emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand, 613 depth); 614 ++opArgIdx; 615 continue; 616 } 617 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 618 if (argTree.isVariadic()) { 619 if (!operand->isVariadic()) { 620 auto error = formatv("variadic DAG construct can't match op {0}'s " 621 "non-variadic operand #{1}", 622 op.getOperationName(), opArgIdx); 623 PrintFatalError(loc, error); 624 } 625 emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx, 626 nextOperand, depth); 627 ++nextOperand; 628 continue; 629 } 630 if (operand->isVariableLength()) { 631 auto error = formatv("use nested DAG construct to match op {0}'s " 632 "variadic operand #{1} unsupported now", 633 op.getOperationName(), opArgIdx); 634 PrintFatalError(loc, error); 635 } 636 } 637 638 os << "{\n"; 639 640 // Attributes don't count for 
getODSOperands. 641 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 642 os.indent() << formatv( 643 "auto *{0} = " 644 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 645 argName, castedName, nextOperand); 646 // Null check of operand's definingOp 647 emitMatchCheck( 648 castedName, /*matchStr=*/argName, 649 formatv("\"There's no operation that defines operand {0} of {1}\"", 650 nextOperand++, castedName)); 651 emitMatch(argTree, argName, depth + 1); 652 os << formatv("tblgen_ops.push_back({0});\n", argName); 653 os.unindent() << "}\n"; 654 continue; 655 } 656 657 // Next handle DAG leaf: operand or attribute 658 if (isa<NamedTypeConstraint *>(opArg)) { 659 auto operandName = 660 formatv("{0}.getODSOperands({1})", castedName, nextOperand); 661 emitOperandMatch(tree, castedName, operandName.str(), opArgIdx, 662 /*operandMatcher=*/tree.getArgAsLeaf(i), 663 /*argName=*/tree.getArgName(i), opArgIdx, 664 /*variadicSubIndex=*/std::nullopt); 665 ++nextOperand; 666 } else if (isa<NamedAttribute *>(opArg)) { 667 emitAttributeMatch(tree, opName, opArgIdx, depth); 668 } else { 669 PrintFatalError(loc, "unhandled case when matching op"); 670 } 671 } 672 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 673 << op.getOperationName() << "' at depth " << depth 674 << '\n'); 675 } 676 677 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 678 StringRef operandName, int operandIndex, 679 DagLeaf operandMatcher, StringRef argName, 680 int argIndex, 681 std::optional<int> variadicSubIndex) { 682 Operator &op = tree.getDialectOp(opMap); 683 auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex)); 684 685 // If a constraint is specified, we need to generate C++ statements to 686 // check the constraint. 
687 if (!operandMatcher.isUnspecified()) { 688 if (!operandMatcher.isOperandMatcher()) 689 PrintFatalError( 690 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 691 op.getOperationName(), argIndex + 1)); 692 693 // Only need to verify if the matcher's type is different from the one 694 // of op definition. 695 Constraint constraint = operandMatcher.getAsConstraint(); 696 if (operand->constraint != constraint) { 697 if (operand->isVariableLength()) { 698 auto error = formatv( 699 "further constrain op {0}'s variadic operand #{1} unsupported now", 700 op.getOperationName(), argIndex); 701 PrintFatalError(loc, error); 702 } 703 auto self = formatv("(*{0}.begin()).getType()", operandName); 704 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher); 705 emitStaticVerifierCall( 706 verifier, opName, self.str(), 707 formatv( 708 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"", 709 operand - op.operand_begin(), op.getOperationName(), 710 escapeString(constraint.getSummary())) 711 .str()); 712 } 713 } 714 715 // Capture the value 716 // `$_` is a special symbol to ignore op argument matching. 
717 if (!argName.empty() && argName != "_") { 718 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex, 719 variadicSubIndex); 720 if (res == symbolInfoMap.end()) 721 PrintFatalError(loc, formatv("symbol not found: {0}", argName)); 722 723 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); 724 } 725 } 726 727 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 728 StringRef opName, int argIndex, 729 int &operandIndex, int depth) { 730 constexpr int numEitherArgs = 2; 731 if (eitherArgTree.getNumArgs() != numEitherArgs) 732 PrintFatalError(loc, "`either` only supports grouping two operands"); 733 734 Operator &op = tree.getDialectOp(opMap); 735 736 std::string codeBuffer; 737 llvm::raw_string_ostream tblgenOps(codeBuffer); 738 739 std::string lambda = formatv("eitherLambda{0}", depth); 740 os << formatv( 741 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n", 742 lambda); 743 744 os.indent(); 745 746 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { 747 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) { 748 if (argTree.isEither()) 749 PrintFatalError(loc, "either cannot be nested"); 750 751 std::string argName = formatv("local_op_{0}", i).str(); 752 753 os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName, 754 i); 755 756 // Indent emitMatchCheck and emitMatch because they declare local 757 // variables. 758 os << "{\n"; 759 os.indent(); 760 761 emitMatchCheck( 762 opName, /*matchStr=*/argName, 763 formatv("\"There's no operation that defines operand {0} of {1}\"", 764 operandIndex++, opName)); 765 emitMatch(argTree, argName, depth + 1); 766 767 os.unindent() << "}\n"; 768 769 // `tblgen_ops` is used to collect the matched operations. In either, we 770 // need to queue the operation only if the matching success. Thus we emit 771 // the code at the end. 
772 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); 773 } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { 774 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), 775 operandIndex, 776 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), 777 /*argName=*/eitherArgTree.getArgName(i), argIndex, 778 /*variadicSubIndex=*/std::nullopt); 779 ++operandIndex; 780 } else { 781 PrintFatalError(loc, "either can only be applied on operand"); 782 } 783 } 784 785 os << tblgenOps.str(); 786 os << "return ::mlir::success();\n"; 787 os.unindent() << "};\n"; 788 789 os << "{\n"; 790 os.indent(); 791 792 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName, 793 operandIndex - 2); 794 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName, 795 operandIndex - 1); 796 797 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " 798 "::mlir::failed({0}(eitherOperand1, " 799 "eitherOperand0)))\n", 800 lambda); 801 os.indent() << "return ::mlir::failure();\n"; 802 803 os.unindent().unindent() << "}\n"; 804 } 805 806 void PatternEmitter::emitVariadicOperandMatch(DagNode tree, 807 DagNode variadicArgTree, 808 StringRef opName, int argIndex, 809 int &operandIndex, int depth) { 810 Operator &op = tree.getDialectOp(opMap); 811 812 os << "{\n"; 813 os.indent(); 814 815 os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n", 816 opName, operandIndex); 817 os << formatv("if (variadic_operand_range.size() != {0}) " 818 "return ::mlir::failure();\n", 819 variadicArgTree.getNumArgs()); 820 821 StringRef variadicTreeName = variadicArgTree.getSymbol(); 822 if (!variadicTreeName.empty()) { 823 auto res = 824 symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex, 825 /*variadicSubIndex=*/std::nullopt); 826 if (res == symbolInfoMap.end()) 827 PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName)); 828 829 os << formatv("{0} = variadic_operand_range;\n", 
830 res->second.getVarName(variadicTreeName)); 831 } 832 833 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { 834 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) { 835 if (!argTree.isOperation()) 836 PrintFatalError(loc, "variadic only accepts operation sub-dags"); 837 838 os << "{\n"; 839 os.indent(); 840 841 std::string argName = formatv("local_op_{0}", i).str(); 842 os << formatv("auto *{0} = " 843 "variadic_operand_range[{1}].getDefiningOp();\n", 844 argName, i); 845 emitMatchCheck( 846 opName, /*matchStr=*/argName, 847 formatv("\"There's no operation that defines variadic operand " 848 "{0} (variadic sub-opearnd #{1}) of {2}\"", 849 operandIndex, i, opName)); 850 emitMatch(argTree, argName, depth + 1); 851 os << formatv("tblgen_ops.push_back({0});\n", argName); 852 853 os.unindent() << "}\n"; 854 } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { 855 auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i); 856 emitOperandMatch(tree, opName, operandName.str(), operandIndex, 857 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i), 858 /*argName=*/variadicArgTree.getArgName(i), argIndex, i); 859 } else { 860 PrintFatalError(loc, "variadic can only be applied on operand"); 861 } 862 } 863 864 os.unindent() << "}\n"; 865 } 866 867 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 868 int argIndex, int depth) { 869 Operator &op = tree.getDialectOp(opMap); 870 auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex)); 871 const auto &attr = namedAttr->attr; 872 873 os << "{\n"; 874 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 875 "(void)tblgen_attr;\n", 876 opName, attr.getStorageType(), namedAttr->name); 877 878 // TODO: This should use getter method to avoid duplication. 
879 if (attr.hasDefaultValue()) { 880 os << "if (!tblgen_attr) tblgen_attr = " 881 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 882 attr.getDefaultValue())) 883 << ";\n"; 884 } else if (attr.isOptional()) { 885 // For a missing attribute that is optional according to definition, we 886 // should just capture a mlir::Attribute() to signal the missing state. 887 // That is precisely what getDiscardableAttr() returns on missing 888 // attributes. 889 } else { 890 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 891 formatv("\"expected op '{0}' to have attribute '{1}' " 892 "of type '{2}'\"", 893 op.getOperationName(), namedAttr->name, 894 attr.getStorageType())); 895 } 896 897 auto matcher = tree.getArgAsLeaf(argIndex); 898 if (!matcher.isUnspecified()) { 899 if (!matcher.isAttrMatcher()) { 900 PrintFatalError( 901 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 902 op.getOperationName(), argIndex + 1)); 903 } 904 905 // If a constraint is specified, we need to generate function call to its 906 // static verifier. 907 StringRef verifier = staticMatcherHelper.getVerifierName(matcher); 908 if (attr.isOptional()) { 909 // Avoid dereferencing null attribute. This is using a simple heuristic to 910 // avoid common cases of attempting to dereference null attribute. This 911 // will return where there is no check if attribute is null unless the 912 // attribute's value is not used. 913 // FIXME: This could be improved as some null dereferences could slip 914 // through. 
915 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") && 916 StringRef(matcher.getConditionTemplate()).contains("$_self")) { 917 os << "if (!tblgen_attr) return ::mlir::failure();\n"; 918 } 919 } 920 emitStaticVerifierCall( 921 verifier, opName, "tblgen_attr", 922 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 923 "'{2}'\"", 924 op.getOperationName(), namedAttr->name, 925 escapeString(matcher.getAsConstraint().getSummary())) 926 .str()); 927 } 928 929 // Capture the value 930 auto name = tree.getArgName(argIndex); 931 // `$_` is a special symbol to ignore op argument matching. 932 if (!name.empty() && name != "_") { 933 os << formatv("{0} = tblgen_attr;\n", name); 934 } 935 936 os.unindent() << "}\n"; 937 } 938 939 void PatternEmitter::emitMatchCheck( 940 StringRef opName, const FmtObjectBase &matchFmt, 941 const llvm::formatv_object_base &failureFmt) { 942 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 943 } 944 945 void PatternEmitter::emitMatchCheck(StringRef opName, 946 const std::string &matchStr, 947 const std::string &failureStr) { 948 949 os << "if (!(" << matchStr << "))"; 950 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 951 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 952 << failureStr << ";\n});"; 953 } 954 955 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 956 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 957 int depth = 0; 958 emitMatch(tree, opName, depth); 959 960 for (auto &appliedConstraint : pattern.getConstraints()) { 961 auto &constraint = appliedConstraint.constraint; 962 auto &entities = appliedConstraint.entities; 963 964 auto condition = constraint.getConditionTemplate(); 965 if (isa<TypeConstraint>(constraint)) { 966 if (entities.size() != 1) 967 PrintFatalError(loc, "type constraint requires exactly one argument"); 968 969 auto self = formatv("({0}.getType())", 970 
symbolInfoMap.getValueAndRangeUse(entities.front())); 971 emitMatchCheck( 972 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 973 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"", 974 entities.front(), escapeString(constraint.getSummary()))); 975 976 } else if (isa<AttrConstraint>(constraint)) { 977 PrintFatalError( 978 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 979 } else { 980 // TODO: replace formatv arguments with the exact specified 981 // args. 982 if (entities.size() > 4) { 983 PrintFatalError(loc, "only support up to 4-entity constraints now"); 984 } 985 SmallVector<std::string, 4> names; 986 int i = 0; 987 for (int e = entities.size(); i < e; ++i) 988 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 989 std::string self = appliedConstraint.self; 990 if (!self.empty()) 991 self = symbolInfoMap.getValueAndRangeUse(self); 992 for (; i < 4; ++i) 993 names.push_back("<unused>"); 994 emitMatchCheck(opName, 995 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 996 names[1], names[2], names[3]), 997 formatv("\"entities '{0}' failed to satisfy constraint: " 998 "'{1}'\"", 999 llvm::join(entities, ", "), 1000 escapeString(constraint.getSummary()))); 1001 } 1002 } 1003 1004 // Some of the operands could be bound to the same symbol name, we need 1005 // to enforce equality constraint on those. 1006 // TODO: we should be able to emit equality checks early 1007 // and short circuit unnecessary work if vars are not equal. 
// Visit each group of entries sharing a symbol once, comparing every later
// entry's first value against the group's first entry.
  for (auto symbolInfoIt = symbolInfoMap.begin();
       symbolInfoIt != symbolInfoMap.end();) {
    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
    auto startRange = range.first;
    auto endRange = range.second;

    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
    for (++startRange; startRange != endRange; ++startRange) {
      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
      emitMatchCheck(
          opName,
          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
                  secondOperand));
    }

    // Jump past the whole group of equal-named entries.
    symbolInfoIt = endRange;
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}

// Recursively collects into `ops` every dialect operator appearing in the DAG
// `tree`, including `tree` itself when it is an operation.
void PatternEmitter::collectOps(DagNode tree,
                                llvm::SmallPtrSetImpl<const Operator *> &ops) {
  // Check if this tree is an operation.
  if (tree.isOperation()) {
    const Operator &op = tree.getDialectOp(opMap);
    LLVM_DEBUG(llvm::dbgs()
               << "found operation " << op.getOperationName() << '\n');
    ops.insert(&op);
  }

  // Recurse the arguments of the tree.
  for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
    if (auto child = tree.getArgAsNestedDag(i))
      collectOps(child, ops);
}

// Emits the complete `::mlir::RewritePattern` subclass named `rewriteName`.
// It consists of:
//   * a "Generated from" comment listing the pattern's locations;
//   * a constructor passing the root op name, the benefit, and the names of
//     the ops appearing in the result patterns (sorted by name);
//   * a `matchAndRewrite()` override holding the capture variables, the match
//     logic and the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name.
  // (The pointer set's iteration order is not name-based; sorting keeps the
  // emitted list stable.)
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
    ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      os << "// Match\n";
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    os << "}\n";
  }
  os << "};\n\n";
}

// Emits the rewrite part of `matchAndRewrite()`. In order, it:
//   * builds the fused location `odsLoc` from all matched ops;
//   * materializes the auxiliary result patterns;
//   * erases the matched root op when it has no results, or otherwise replaces
//     it with the values produced by the trailing (replacement) result
//     patterns.
// Supplemental patterns are processed right before the erase/replace.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  // offsets[i] is that starting index for result pattern i, computed backwards
  // from offsets[numResultPatterns] == numExpectedResults. A negative value
  // means the pattern's values are auxiliary.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // This pattern's values straddle the auxiliary/replacement boundary.
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  if (offsets.front() > 0) {
    const char error[] =
        "not enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  os << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall() &&
        resultTree.getNumReturnsOfNativeCode() == 0)
      os << val << ";\n";
  }

  // Supplemental patterns receive negative result indices
  // (-numSupplementalPatterns .. -1).
  auto processSupplementalPatterns = [&]() {
    int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
    for (int i = 0, offset = -numSupplementalPatterns;
         i < numSupplementalPatterns; ++i) {
      DagNode resultTree = pattern.getSupplementalPattern(i);
      auto val = handleResultPattern(resultTree, offset++, 0);
      if (resultTree.isNativeCodeCall() &&
          resultTree.getNumReturnsOfNativeCode() == 0)
        os << val << ";\n";
    }
  };

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    processSupplementalPatterns();
    // No result to replace. Just erase the op.
    os << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
          " tblgen_repl_values.push_back(v);\n}\n",
          "\n");
    }
    processSupplementalPatterns();
    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}

// Returns a fresh local variable name for an op of class `op`. Uniqueness
// comes from the post-incremented `nextValueId` counter.
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
  return std::string(
      formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
}

// Emits the code that materializes one result-pattern DAG and returns the
// symbol or expression holding its value(s). It dispatches on the node kind:
// NativeCodeCall, replaceWithValue, variadic, or ordinary op creation.
// `resultIndex` is only consumed by op creation (see handleOpCreation());
// `depth` is the nesting level within the result DAG (0 at the top level).
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
                                                int resultIndex, int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
  LLVM_DEBUG(resultTree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  if (resultTree.isLocationDirective()) {
    PrintFatalError(loc,
                    "location directive can only be used with op creation");
  }

  if (resultTree.isNativeCodeCall())
    return handleReplaceWithNativeCodeCall(resultTree, depth);

  if (resultTree.isReplaceWithValue())
    return handleReplaceWithValue(resultTree).str();

  if (resultTree.isVariadic())
    return handleVariadic(resultTree, depth);

  // Normal op creation.
  auto symbol = handleOpCreation(resultTree, resultIndex, depth);
  if (resultTree.getSymbol().empty()) {
    // This is an op not explicitly bound to a symbol in the rewrite rule.
    // Register the auto-generated symbol for it.
1256 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 1257 } 1258 return symbol; 1259 } 1260 1261 std::string PatternEmitter::handleVariadic(DagNode tree, int depth) { 1262 assert(tree.isVariadic()); 1263 1264 std::string output; 1265 llvm::raw_string_ostream oss(output); 1266 auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++)); 1267 symbolInfoMap.bindValue(name); 1268 oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n"; 1269 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1270 if (auto child = tree.getArgAsNestedDag(i)) { 1271 oss << name << ".push_back(" << handleResultPattern(child, i, depth + 1) 1272 << ");\n"; 1273 } else { 1274 oss << name << ".push_back(" 1275 << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)) 1276 << ");\n"; 1277 } 1278 } 1279 1280 os << oss.str(); 1281 return name; 1282 } 1283 1284 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 1285 assert(tree.isReplaceWithValue()); 1286 1287 if (tree.getNumArgs() != 1) { 1288 PrintFatalError( 1289 loc, "replaceWithValue directive must take exactly one argument"); 1290 } 1291 1292 if (!tree.getSymbol().empty()) { 1293 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 1294 } 1295 1296 return tree.getArgName(0); 1297 } 1298 1299 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 1300 assert(tree.isLocationDirective()); 1301 auto lookUpArgLoc = [this, &tree](int idx) { 1302 const auto *const lookupFmt = "{0}.getLoc()"; 1303 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt); 1304 }; 1305 1306 if (tree.getNumArgs() == 0) 1307 llvm::PrintFatalError( 1308 "At least one argument to location directive required"); 1309 1310 if (!tree.getSymbol().empty()) 1311 PrintFatalError(loc, "cannot bind symbol to location"); 1312 1313 if (tree.getNumArgs() == 1) { 1314 DagLeaf leaf = tree.getArgAsLeaf(0); 1315 if (leaf.isStringAttr()) 1316 return 
formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))", 1317 leaf.getStringAttr()) 1318 .str(); 1319 return lookUpArgLoc(0); 1320 } 1321 1322 std::string ret; 1323 llvm::raw_string_ostream os(ret); 1324 std::string strAttr; 1325 os << "rewriter.getFusedLoc({"; 1326 bool first = true; 1327 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1328 DagLeaf leaf = tree.getArgAsLeaf(i); 1329 // Handle the optional string value. 1330 if (leaf.isStringAttr()) { 1331 if (!strAttr.empty()) 1332 llvm::PrintFatalError("Only one string attribute may be specified"); 1333 strAttr = leaf.getStringAttr(); 1334 continue; 1335 } 1336 os << (first ? "" : ", ") << lookUpArgLoc(i); 1337 first = false; 1338 } 1339 os << "}"; 1340 if (!strAttr.empty()) { 1341 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 1342 } 1343 os << ")"; 1344 return os.str(); 1345 } 1346 1347 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, 1348 int depth) { 1349 // Nested NativeCodeCall. 1350 if (auto dagNode = returnType.getArgAsNestedDag(i)) { 1351 if (!dagNode.isNativeCodeCall()) 1352 PrintFatalError(loc, "nested DAG in `returnType` must be a native code " 1353 "call"); 1354 return handleReplaceWithNativeCodeCall(dagNode, depth); 1355 } 1356 // String literal. 
// Leaf entry: a string is expanded as a format template; anything else is
// resolved as an op argument and its `getType()` is used.
  auto dagLeaf = returnType.getArgAsLeaf(i);
  if (dagLeaf.isStringAttr())
    return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
  return tgfmt(
      "$0.getType()", &fmtCtx,
      handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
}

// Returns the C++ expression that supplies `leaf` (named `patArgName` in the
// result pattern) when building a result op:
//   * constant attributes and enum cases are materialized via
//     handleConstantAttr();
//   * unspecified and operand-matcher leaves resolve to the bound symbol;
//   * NativeCodeCall leaves are expanded with `$_self` set to that symbol.
// Raw strings and any other leaf kind are fatal errors.
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                             StringRef patArgName) {
  if (leaf.isStringAttr())
    PrintFatalError(loc, "raw string not supported as argument");
  if (leaf.isConstantAttr()) {
    auto constAttr = leaf.getAsConstantAttr();
    return handleConstantAttr(constAttr.getAttribute(),
                              constAttr.getConstantValue());
  }
  if (leaf.isEnumAttrCase()) {
    auto enumCase = leaf.getAsEnumAttrCase();
    // This is an enum case backed by an IntegerAttr. We need to get its value
    // to build the constant.
    std::string val = std::to_string(enumCase.getValue());
    return handleConstantAttr(enumCase, val);
  }

  LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
  auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
  if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
    LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
                            << "' (via symbol ref)\n");
    return argName;
  }
  if (leaf.isNativeCodeCall()) {
    auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
    LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
                            << "' (via NativeCodeCall)\n");
    return std::string(repl);
  }
  PrintFatalError(loc, "unhandled case when rewriting op");
}

// Expands a NativeCodeCall result DAG. Each non-directive argument is
// materialized (recursively for nested DAGs) and substituted positionally into
// the call template, with `$_loc` bound to the chosen location.
// If the call declares return values, its result is captured in a local
// (`auto <var> = <call>;`) and the node's bound symbol (or that local's name)
// is returned. Otherwise the expanded expression text is returned; this
// function does not write it to `os`.
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
                                                            int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  auto fmt = tree.getNativeCodeTemplate();

  SmallVector<std::string, 16> attrs;

  auto tail = getTrailingDirectives(tree);
  if (tail.returnType)
    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
  auto locToUse = getLocation(tail);

  // Trailing directives are not call arguments; skip them.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    if (tree.isNestedDagArg(i)) {
      attrs.push_back(
          handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
    } else {
      attrs.push_back(
          handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
    }
    LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                            << " replacement: " << attrs[i] << "\n");
  }

  // NOTE(review): `addSubst` is applied to the shared `fmtCtx` member, so the
  // `_loc` substitution stays installed after this call (presumably
  // overwritten by the next call that sets it) — confirm this is intended.
  std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
                             static_cast<ArrayRef<std::string>>(attrs));

  // In general, NativeCodeCall without naming binding don't need this. To
  // ensure void helper function has been correctly labeled, i.e., use
  // NativeCodeCallVoid, we cache the result to a local variable so that we will
  // get a compilation error in the auto-generated file.
  // Example.
  //   // In the td file
  //   Pat<(...), (NativeCodeCall<Foo> ...)>
  //
  //   ---
  //
  //   // In the auto-generated .cpp
  //   ...
  //   // Causes compilation error if Foo() returns void.
  //   auto nativeVar = Foo();
  //   ...
  if (tree.getNumReturnsOfNativeCode() != 0) {
    // Determine the local variable name for return value.
    std::string varName =
        SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
    if (varName.empty()) {
      varName = formatv("nativeVar_{0}", nextValueId++);
      // Register the local variable for later uses.
      symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
    }

    // Catch the return value of helper function.
    os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);

    if (!tree.getSymbol().empty())
      symbol = tree.getSymbol().str();
    else
      symbol = varName;
  }

  return symbol;
}

// Returns how many values the result-pattern node `node` produces:
//   * for an operation, the bound symbol's static value count, or the op's
//     result count when the op is unbound;
//   * for a NativeCodeCall, its declared number of returns;
//   * 1 for anything else.
int PatternEmitter::getNodeValueCount(DagNode node) {
  if (node.isOperation()) {
    // If the op is bound to a symbol in the rewrite rule, query its result
    // count from the symbol info map.
    auto symbol = node.getSymbol();
    if (!symbol.empty()) {
      return symbolInfoMap.getStaticValueCount(symbol);
    }
    // Otherwise this is an unbound op; we will use all its results.
    return pattern.getDialectOp(node).getNumResults();
  }

  if (node.isNativeCodeCall())
    return node.getNumReturnsOfNativeCode();

  return 1;
}

// Scans the arguments of `tree` from last to first for `location` and
// `returnType` directives, stopping at the first argument that is neither.
// Each directive kind may appear at most once. `numDirectives` records how
// many trailing arguments were consumed.
PatternEmitter::TrailingDirectives
PatternEmitter::getTrailingDirectives(DagNode tree) {
  TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};

  // Look backwards through the arguments.
  auto numPatArgs = tree.getNumArgs();
  for (int i = numPatArgs - 1; i >= 0; --i) {
    auto dagArg = tree.getArgAsNestedDag(i);
    // A leaf is not a directive. Stop looking.
    if (!dagArg)
      break;

    auto isLocation = dagArg.isLocationDirective();
    auto isReturnType = dagArg.isReturnTypeDirective();
    // If encountered a DAG node that isn't a trailing directive, stop looking.
    if (!(isLocation || isReturnType))
      break;
    // Save the directive, but error if one of the same type was already
    // found.
// Count the directive, then store it by kind.
    ++tail.numDirectives;
    if (isLocation) {
      if (tail.location)
        PrintFatalError(loc, "`location` directive can only be specified "
                             "once");
      tail.location = dagArg;
    } else if (isReturnType) {
      if (tail.returnType)
        PrintFatalError(loc, "`returnType` directive can only be specified "
                             "once");
      tail.returnType = dagArg;
    }
  }

  return tail;
}

// Returns the location expression for a generated op or call. This is the
// expansion of the explicit `location` directive when one is present, and
// otherwise `odsLoc`, the fused location of all matched ops declared by
// emitRewriteLogic().
std::string
PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
  if (tail.location)
    return handleLocationDirective(tail.location);

  // If no explicit location is given, use the default, all fused, location.
  return "odsLoc";
}

// Emits code creating the op described by result DAG `tree` and returns the
// symbol referring to its result(s). Child DAGs are materialized first.
// The builder form is chosen as follows:
//   * without `returnType`, ops with SameOperandsAndResultType or
//     FirstAttrDerivedResultType use the aggregate builder without types;
//   * without `returnType`, partial-result uses, nested ops (depth > 0) and
//     ops with a negative `resultIndex` use the separate-parameter builder;
//   * otherwise the aggregate builder is called with explicit result types.
//     The types come from `returnType`, or are copied from
//     `castedOp0.getODSResults(resultIndex + i)`.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  auto tail = getTrailingDirectives(tree);
  auto locToUse = getLocation(tail);

  auto inPattern = numPatArgs - tail.numDirectives;
  if (numOpArgs != inPattern) {
    PrintFatalError(loc,
                    formatv("resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition",
                            resultOp.getOperationName(), inPattern, numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // If the argument is a type constraint, then its an operand. Check if the
  // op's argument is variadic that the argument in the pattern is too.
  auto checkIfMatchedVariadic = [&](int i) {
    // FIXME: This does not yet check for variable/leaf case.
    // FIXME: Change so that native code call can be handled.
    const auto *operand =
        llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
    if (!operand || !operand->isVariadic())
      return;

    auto child = tree.getArgAsNestedDag(i);
    if (!child)
      return;

    // Skip over replaceWithValues.
    while (child.isReplaceWithValue()) {
      if (!(child = child.getArgAsNestedDag(0)))
        return;
    }
    if (!child.isNativeCodeCall() && !child.isVariadic())
      PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while "
                                   "provided is non-variadic",
                                   resultOp.getArgName(i)));
  };

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    checkIfMatchedVariadic(i);
    if (auto child = tree.getArgAsNestedDag(i))
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
  }

  // Create the local variable for this op.
  // This also opens a `{` scope in the generated code; each of the three
  // builder paths below emits the matching `}`.
  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
                valuePackName);

  // Right now ODS don't have general type inference support. Except a few
  // special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
  bool useFirstAttr =
      resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");

  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

    // Then create the op.
    os.scope("", "\n}\n").os << formatv(
        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
        valuePackName, resultOp.getQualCppClassName(), locToUse);
    return resultValue;
  }

  bool usePartialResults = valuePackName != resultValue;

  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(tree, childNodeNames);

    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                             resultOp.getQualCppClassName(), locToUse);
    supplyValuesForOpArgs(tree, childNodeNames, depth);
    os << "\n );\n}\n";
    return resultValue;
  }

  // If we are provided explicit return types, use them to build the op.
  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
  // the values generated from the source pattern root op. Then we must use the
  // source pattern's value types to determine the value type of the generated
  // op here.
  if (depth == 0 && resultIndex >= 0 && tail.returnType)
    PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
                         "return values replace the source pattern's root op");

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n");
  int numResults = resultOp.getNumResults();
  if (tail.returnType) {
    auto numRetTys = tail.returnType.getNumArgs();
    for (int i = 0; i < numRetTys; ++i) {
      auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
      os << "tblgen_types.push_back(" << varName << ");\n";
    }
  } else {
    if (numResults != 0) {
      // Copy the result types from the source pattern.
      for (int i = 0; i < numResults; ++i)
        os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
                      " tblgen_types.push_back(v.getType());\n}\n",
                      resultIndex + i);
    }
  }
  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, tblgen_attrs);\n",
                valuePackName, resultOp.getQualCppClassName(), locToUse);
  os.unindent() << "}\n";
  return resultValue;
}

// Emits one local per operand of the result op built from `node`: a
// `SmallVector<Value>` for variadic operands, a `Value` otherwise. Each
// `childNodeNames[argIndex]` is then overwritten with that local's name so the
// separate-parameter builder call can refer to it. Attribute arguments are
// skipped.
void PatternEmitter::createSeparateLocalVarsForOpArgs(
    DagNode node, ChildNodeIndexNameMap &childNodeNames) {
  Operator &resultOp = node.getDialectOp(opMap);

  // Now prepare operands used for building this op:
  // * If the operand is non-variadic, we create a `Value` local variable.
  // * If the operand is variadic, we create a `SmallVector<Value>` local
  //   variable.

  int valueIndex = 0; // An index for uniquing local variable names.
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    const auto *operand =
        llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
    // We do not need special handling for attributes.
    if (!operand)
      continue;

    raw_indented_ostream::DelimitedScope scope(os);
    std::string varName;
    if (operand->isVariadic()) {
      varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
      os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
      std::string range;
      if (node.isNestedDagArg(argIndex)) {
        range = childNodeNames[argIndex];
      } else {
        range = std::string(node.getArgName(argIndex));
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
                    varName);
    } else {
      varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
      os << formatv("::mlir::Value {0} = ", varName);
      if (node.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
      } else {
        DagLeaf leaf = node.getArgAsLeaf(argIndex);
        auto symbol =
            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
        if (leaf.isNativeCodeCall()) {
          os << std::string(
              tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
        } else {
          os << symbol;
        }
      }
      os << ";\n";
    }

    // Update to use the newly created local variable for building the op later.
    childNodeNames[argIndex] = varName;
  }
}

// Streams the builder arguments for the result op built from `node`, one per
// line and each prefixed with `, `:
//   * operands come from `childNodeNames`;
//   * attributes come from a nested NativeCodeCall's recorded name or from the
//     leaf via handleOpArgument().
// A `/*name=*/` comment precedes an argument where a name is available.
// `depth` is not used in this function.
void PatternEmitter::supplyValuesForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);
  for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
       argIndex != numOpArgs; ++argIndex) {
    // Start each argument on its own line.
    os << ",\n ";

    Argument opArg = resultOp.getArg(argIndex);
    // Handle the case of operand first.
    if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
      if (!operand->name.empty())
        os << "/*" << operand->name << "=*/";
      os << childNodeNames.lookup(argIndex);
      continue;
    }

    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(argIndex);
    if (auto subTree = node.getArgAsNestedDag(argIndex)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
    } else {
      auto leaf = node.getArgAsLeaf(argIndex);
      // The argument in the result DAG pattern.
      auto patArgName = node.getArgName(argIndex);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO: Refactor out into map to avoid recomputing these.
        if (!isa<NamedAttribute *>(opArg))
          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
}

void PatternEmitter::createAggregateLocalVarsForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);

  auto scope = os.scope();
  os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
                "tblgen_values; (void)tblgen_values;\n");
  os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
                "tblgen_attrs; (void)tblgen_attrs;\n");

  const char *addAttrCmd =
      "if (auto tmpAttr = {1}) {\n"
      " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
      "tmpAttr);\n}\n";
  int numVariadic = 0;
  bool hasOperandSegmentSizes = false;
  std::vector<std::string> sizes;
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
      // The argument in the op definition.
1810 auto opArgName = resultOp.getArgName(argIndex); 1811 hasOperandSegmentSizes = 1812 hasOperandSegmentSizes || opArgName == "operandSegmentSizes"; 1813 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1814 if (!subTree.isNativeCodeCall()) 1815 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1816 "for creating attribute"); 1817 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); 1818 } else { 1819 auto leaf = node.getArgAsLeaf(argIndex); 1820 // The argument in the result DAG pattern. 1821 auto patArgName = node.getArgName(argIndex); 1822 os << formatv(addAttrCmd, opArgName, 1823 handleOpArgument(leaf, patArgName)); 1824 } 1825 continue; 1826 } 1827 1828 const auto *operand = 1829 cast<NamedTypeConstraint *>(resultOp.getArg(argIndex)); 1830 std::string varName; 1831 if (operand->isVariadic()) { 1832 ++numVariadic; 1833 std::string range; 1834 if (node.isNestedDagArg(argIndex)) { 1835 range = childNodeNames.lookup(argIndex); 1836 } else { 1837 range = std::string(node.getArgName(argIndex)); 1838 } 1839 // Resolve the symbol for all range use so that we have a uniform way of 1840 // capturing the values. 
1841 range = symbolInfoMap.getValueAndRangeUse(range); 1842 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1843 range); 1844 sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range)); 1845 } else { 1846 sizes.emplace_back("1"); 1847 os << formatv("tblgen_values.push_back("); 1848 if (node.isNestedDagArg(argIndex)) { 1849 os << symbolInfoMap.getValueAndRangeUse( 1850 childNodeNames.lookup(argIndex)); 1851 } else { 1852 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1853 if (leaf.isConstantAttr()) 1854 // TODO: Use better location 1855 PrintFatalError( 1856 loc, 1857 "attribute found where value was expected, if attempting to use " 1858 "constant value, construct a constant op with given attribute " 1859 "instead"); 1860 1861 auto symbol = 1862 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1863 if (leaf.isNativeCodeCall()) { 1864 os << std::string( 1865 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1866 } else { 1867 os << symbol; 1868 } 1869 } 1870 os << ");\n"; 1871 } 1872 } 1873 1874 if (numVariadic > 1 && !hasOperandSegmentSizes) { 1875 // Only set size if it can't be computed. 1876 const auto *sameVariadicSize = 1877 resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); 1878 if (!sameVariadicSize) { 1879 const char *setSizes = R"( 1880 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), 1881 rewriter.getDenseI32ArrayAttr({{ {0} })); 1882 )"; 1883 os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); 1884 } 1885 } 1886 } 1887 1888 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, 1889 const RecordKeeper &records, 1890 RecordOperatorMap &mapper) 1891 : opMap(mapper), staticVerifierEmitter(os, records) {} 1892 1893 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { 1894 // PatternEmitter will use the static matcher if there's one generated. 
To 1895 // ensure that all the dependent static matchers are generated before emitting 1896 // the matching logic of the DagNode, we use topological order to achieve it. 1897 for (auto &dagInfo : topologicalOrder) { 1898 DagNode node = dagInfo.first; 1899 if (!useStaticMatcher(node)) 1900 continue; 1901 1902 std::string funcName = 1903 formatv("static_dag_matcher_{0}", staticMatcherCounter++); 1904 assert(!matcherNames.contains(node)); 1905 PatternEmitter(dagInfo.second, &opMap, os, *this) 1906 .emitStaticMatcher(node, funcName); 1907 matcherNames[node] = funcName; 1908 } 1909 } 1910 1911 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { 1912 staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef()); 1913 } 1914 1915 void StaticMatcherHelper::addPattern(const Record *record) { 1916 Pattern pat(record, &opMap); 1917 1918 // While generating the function body of the DAG matcher, it may depends on 1919 // other DAG matchers. To ensure the dependent matchers are ready, we compute 1920 // the topological order for all the DAGs and emit the DAG matchers in this 1921 // order. 
1922 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { 1923 ++refStats[node]; 1924 1925 if (refStats[node] != 1) 1926 return; 1927 1928 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) 1929 if (DagNode sibling = node.getArgAsNestedDag(i)) 1930 dfs(sibling); 1931 else { 1932 DagLeaf leaf = node.getArgAsLeaf(i); 1933 if (!leaf.isUnspecified()) 1934 constraints.insert(leaf); 1935 } 1936 1937 topologicalOrder.push_back(std::make_pair(node, record)); 1938 }; 1939 1940 dfs(pat.getSourcePattern()); 1941 } 1942 1943 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { 1944 if (leaf.isAttrMatcher()) { 1945 std::optional<StringRef> constraint = 1946 staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint()); 1947 assert(constraint && "attribute constraint was not uniqued"); 1948 return *constraint; 1949 } 1950 assert(leaf.isOperandMatcher()); 1951 return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); 1952 } 1953 1954 static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { 1955 emitSourceFileHeader("Rewriters", os, records); 1956 1957 auto patterns = records.getAllDerivedDefinitions("Pattern"); 1958 1959 // We put the map here because it can be shared among multiple patterns. 1960 RecordOperatorMap recordOpMap; 1961 1962 // Exam all the patterns and generate static matcher for the duplicated 1963 // DagNode. 
1964 StaticMatcherHelper staticMatcher(os, records, recordOpMap); 1965 for (const Record *p : patterns) 1966 staticMatcher.addPattern(p); 1967 staticMatcher.populateStaticConstraintFunctions(os); 1968 staticMatcher.populateStaticMatchers(os); 1969 1970 std::vector<std::string> rewriterNames; 1971 rewriterNames.reserve(patterns.size()); 1972 1973 std::string baseRewriterName = "GeneratedConvert"; 1974 int rewriterIndex = 0; 1975 1976 for (const Record *p : patterns) { 1977 std::string name; 1978 if (p->isAnonymous()) { 1979 // If no name is provided, ensure unique rewriter names simply by 1980 // appending unique suffix. 1981 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1982 } else { 1983 name = std::string(p->getName()); 1984 } 1985 LLVM_DEBUG(llvm::dbgs() 1986 << "=== start generating pattern '" << name << "' ===\n"); 1987 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name); 1988 LLVM_DEBUG(llvm::dbgs() 1989 << "=== done generating pattern '" << name << "' ===\n"); 1990 rewriterNames.push_back(std::move(name)); 1991 } 1992 1993 // Emit function to add the generated matchers to the pattern list. 1994 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" 1995 "::mlir::RewritePatternSet &patterns) {\n"; 1996 for (const auto &name : rewriterNames) { 1997 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 1998 } 1999 os << "}\n"; 2000 } 2001 2002 static mlir::GenRegistration 2003 genRewriters("gen-rewriters", "Generate pattern rewriters", 2004 [](const RecordKeeper &records, raw_ostream &os) { 2005 emitRewriters(records, os); 2006 return false; 2007 }); 2008