1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Pattern.h" 19 #include "mlir/TableGen/Predicate.h" 20 #include "mlir/TableGen/Type.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::formatv; 37 using llvm::Record; 38 using llvm::RecordKeeper; 39 40 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 41 42 namespace llvm { 43 template <> 44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 46 raw_ostream &os, StringRef style) { 47 os << v.first << ":" << v.second; 48 } 49 }; 50 } // end namespace llvm 51 52 //===----------------------------------------------------------------------===// 53 // PatternEmitter 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 class PatternEmitter { 58 public: 59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 60 61 // Emits the mlir::RewritePattern struct named `rewriteName`. 62 void emit(StringRef rewriteName); 63 64 private: 65 // Emits the code for matching ops. 66 void emitMatchLogic(DagNode tree, StringRef opName); 67 68 // Emits the code for rewriting ops. 69 void emitRewriteLogic(); 70 71 //===--------------------------------------------------------------------===// 72 // Match utilities 73 //===--------------------------------------------------------------------===// 74 75 // Emits C++ statements for matching the DAG structure. 76 void emitMatch(DagNode tree, StringRef name, int depth); 77 78 // Emits C++ statements for matching using a native code call. 79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 80 81 // Emits C++ statements for matching the op constrained by the given DAG 82 // `tree` returning the op's variable name. 83 void emitOpMatch(DagNode tree, StringRef opName, int depth); 84 85 // Emits C++ statements for matching the `argIndex`-th argument of the given 86 // DAG `tree` as an operand. operandIndex is the index in the DAG excluding 87 // the preceding attributes. 88 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, 89 int operandIndex, int depth); 90 91 // Emits C++ statements for matching the `argIndex`-th argument of the given 92 // DAG `tree` as an attribute. 93 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 94 int depth); 95 96 // Emits C++ for checking a match with a corresponding match failure 97 // diagnostic. 98 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 99 const llvm::formatv_object_base &failureFmt); 100 101 // Emits C++ for checking a match with a corresponding match failure 102 // diagnostics. 103 void emitMatchCheck(StringRef opName, const std::string &matchStr, 104 const std::string &failureStr); 105 106 //===--------------------------------------------------------------------===// 107 // Rewrite utilities 108 //===--------------------------------------------------------------------===// 109 110 // The entry point for handling a result pattern rooted at `resultTree`. This 111 // method dispatches to concrete handlers according to `resultTree`'s kind and 112 // returns a symbol representing the whole value pack. Callers are expected to 113 // further resolve the symbol according to the specific use case. 114 // 115 // `depth` is the nesting level of `resultTree`; 0 means top-level result 116 // pattern. For top-level result pattern, `resultIndex` indicates which result 117 // of the matched root op this pattern is intended to replace, which can be 118 // used to deduce the result type of the op generated from this result 119 // pattern. 120 std::string handleResultPattern(DagNode resultTree, int resultIndex, 121 int depth); 122 123 // Emits the C++ statement to replace the matched DAG with a value built via 124 // calling native C++ code. 125 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 126 127 // Returns the symbol of the old value serving as the replacement. 128 StringRef handleReplaceWithValue(DagNode tree); 129 130 // Returns the location value to use. 131 std::pair<bool, std::string> getLocation(DagNode tree); 132 133 // Returns the location value to use. 134 std::string handleLocationDirective(DagNode tree); 135 136 // Emits the C++ statement to build a new op out of the given DAG `tree` and 137 // returns the variable name that this op is assigned to. If the root op in 138 // DAG `tree` has a specified name, the created op will be assigned to a 139 // variable of the given name. Otherwise, a unique name will be used as the 140 // result value name. 141 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 142 143 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 144 145 // Emits a local variable for each value and attribute to be used for creating 146 // an op. 147 void createSeparateLocalVarsForOpArgs(DagNode node, 148 ChildNodeIndexNameMap &childNodeNames); 149 150 // Emits the concrete arguments used to call an op's builder. 151 void supplyValuesForOpArgs(DagNode node, 152 const ChildNodeIndexNameMap &childNodeNames, 153 int depth); 154 155 // Emits the local variables for holding all values as a whole and all named 156 // attributes as a whole to be used for creating an op. 157 void createAggregateLocalVarsForOpArgs( 158 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 159 160 // Returns the C++ expression to construct a constant attribute of the given 161 // `value` for the given attribute kind `attr`. 162 std::string handleConstantAttr(Attribute attr, StringRef value); 163 164 // Returns the C++ expression to build an argument from the given DAG `leaf`. 165 // `patArgName` is used to bound the argument to the source pattern. 166 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 167 168 //===--------------------------------------------------------------------===// 169 // General utilities 170 //===--------------------------------------------------------------------===// 171 172 // Collects all of the operations within the given dag tree. 173 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 174 175 // Returns a unique symbol for a local variable of the given `op`. 176 std::string getUniqueSymbol(const Operator *op); 177 178 //===--------------------------------------------------------------------===// 179 // Symbol utilities 180 //===--------------------------------------------------------------------===// 181 182 // Returns how many static values the given DAG `node` correspond to. 183 int getNodeValueCount(DagNode node); 184 185 private: 186 // Pattern instantiation location followed by the location of multiclass 187 // prototypes used. This is intended to be used as a whole to 188 // PrintFatalError() on errors. 189 ArrayRef<llvm::SMLoc> loc; 190 191 // Op's TableGen Record to wrapper object. 192 RecordOperatorMap *opMap; 193 194 // Handy wrapper for pattern being emitted. 195 Pattern pattern; 196 197 // Map for all bound symbols' info. 198 SymbolInfoMap symbolInfoMap; 199 200 // The next unused ID for newly created values. 201 unsigned nextValueId; 202 203 raw_indented_ostream os; 204 205 // Format contexts containing placeholder substitutions. 206 FmtContext fmtCtx; 207 208 // Number of op processed. 209 int opCounter = 0; 210 }; 211 } // end anonymous namespace 212 213 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 214 raw_ostream &os) 215 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 216 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 217 fmtCtx.withBuilder("rewriter"); 218 } 219 220 std::string PatternEmitter::handleConstantAttr(Attribute attr, 221 StringRef value) { 222 if (!attr.isConstBuildable()) 223 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 224 " does not have the 'constBuilderCall' field"); 225 226 // TODO: Verify the constants here 227 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 228 } 229 230 // Helper function to match patterns. 231 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 232 if (tree.isNativeCodeCall()) { 233 emitNativeCodeMatch(tree, name, depth); 234 return; 235 } 236 237 if (tree.isOperation()) { 238 emitOpMatch(tree, name, depth); 239 return; 240 } 241 242 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 243 } 244 245 // Helper function to match patterns. 246 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 247 int depth) { 248 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 249 LLVM_DEBUG(tree.print(llvm::dbgs())); 250 LLVM_DEBUG(llvm::dbgs() << '\n'); 251 252 // TODO(suderman): iterate through arguments, determine their types, output 253 // names. 254 SmallVector<std::string, 8> capture; 255 256 raw_indented_ostream::DelimitedScope scope(os); 257 258 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 259 std::string argName = formatv("arg{0}_{1}", depth, i); 260 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 261 os << "Value " << argName << ";\n"; 262 } else { 263 auto leaf = tree.getArgAsLeaf(i); 264 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 265 os << "Attribute " << argName << ";\n"; 266 } else { 267 os << "Value " << argName << ";\n"; 268 } 269 } 270 271 capture.push_back(std::move(argName)); 272 } 273 274 bool hasLocationDirective; 275 std::string locToUse; 276 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 277 278 auto fmt = tree.getNativeCodeTemplate(); 279 if (fmt.count("$_self") != 1) 280 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 281 "passing the defining Operation"); 282 283 auto nativeCodeCall = std::string(tgfmt( 284 fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture)); 285 286 emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall), 287 formatv("\"{0} return failure\"", nativeCodeCall)); 288 289 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 290 auto name = tree.getArgName(i); 291 if (!name.empty() && name != "_") { 292 os << formatv("{0} = {1};\n", name, capture[i]); 293 } 294 } 295 296 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 297 std::string argName = capture[i]; 298 299 // Handle nested DAG construct first 300 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 301 PrintFatalError( 302 loc, formatv("Matching nested tree in NativeCodecall not support for " 303 "{0} as arg {1}", 304 argName, i)); 305 } 306 307 DagLeaf leaf = tree.getArgAsLeaf(i); 308 309 // The parameter for native function doesn't bind any constraints. 310 if (leaf.isUnspecified()) 311 continue; 312 313 auto constraint = leaf.getAsConstraint(); 314 315 std::string self; 316 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 317 self = argName; 318 else 319 self = formatv("{0}.getType()", argName); 320 emitMatchCheck( 321 opName, 322 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 323 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 324 "constraint: " 325 "'{2}'\"", 326 i, tree.getNativeCodeTemplate(), constraint.getSummary())); 327 } 328 329 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 330 } 331 332 // Helper function to match patterns. 333 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 334 Operator &op = tree.getDialectOp(opMap); 335 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 336 << op.getOperationName() << "' at depth " << depth 337 << '\n'); 338 339 std::string castedName = formatv("castedOp{0}", depth); 340 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " 341 "(void){0};\n", 342 castedName, opName, op.getQualCppClassName()); 343 344 // Skip the operand matching at depth 0 as the pattern rewriter already does. 345 if (depth != 0) 346 emitMatchCheck(opName, /*matchStr=*/castedName, 347 formatv("\"{0} is not {1} type\"", castedName, 348 op.getQualCppClassName())); 349 350 if (tree.getNumArgs() != op.getNumArgs()) 351 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 352 "pattern vs. {2} in definition", 353 op.getOperationName(), tree.getNumArgs(), 354 op.getNumArgs())); 355 356 // If the operand's name is set, set to that variable. 357 auto name = tree.getSymbol(); 358 if (!name.empty()) 359 os << formatv("{0} = {1};\n", name, castedName); 360 361 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) { 362 auto opArg = op.getArg(i); 363 std::string argName = formatv("op{0}", depth + 1); 364 365 // Handle nested DAG construct first 366 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 367 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 368 if (operand->isVariableLength()) { 369 auto error = formatv("use nested DAG construct to match op {0}'s " 370 "variadic operand #{1} unsupported now", 371 op.getOperationName(), i); 372 PrintFatalError(loc, error); 373 } 374 } 375 os << "{\n"; 376 377 // Attributes don't count for getODSOperands. 378 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 379 os.indent() << formatv( 380 "auto *{0} = " 381 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 382 argName, castedName, nextOperand); 383 // Null check of operand's definingOp 384 emitMatchCheck(castedName, /*matchStr=*/argName, 385 formatv("\"Operand {0} of {1} has null definingOp\"", 386 nextOperand++, castedName)); 387 emitMatch(argTree, argName, depth + 1); 388 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); 389 os.unindent() << "}\n"; 390 continue; 391 } 392 393 // Next handle DAG leaf: operand or attribute 394 if (opArg.is<NamedTypeConstraint *>()) { 395 // emitOperandMatch's argument indexing counts attributes. 396 emitOperandMatch(tree, castedName, i, nextOperand, depth); 397 ++nextOperand; 398 } else if (opArg.is<NamedAttribute *>()) { 399 emitAttributeMatch(tree, opName, i, depth); 400 } else { 401 PrintFatalError(loc, "unhandled case when matching op"); 402 } 403 } 404 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 405 << op.getOperationName() << "' at depth " << depth 406 << '\n'); 407 } 408 409 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 410 int argIndex, int operandIndex, 411 int depth) { 412 Operator &op = tree.getDialectOp(opMap); 413 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 414 auto matcher = tree.getArgAsLeaf(argIndex); 415 416 // If a constraint is specified, we need to generate C++ statements to 417 // check the constraint. 418 if (!matcher.isUnspecified()) { 419 if (!matcher.isOperandMatcher()) { 420 PrintFatalError( 421 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 422 op.getOperationName(), argIndex + 1)); 423 } 424 425 // Only need to verify if the matcher's type is different from the one 426 // of op definition. 427 Constraint constraint = matcher.getAsConstraint(); 428 if (operand->constraint != constraint) { 429 if (operand->isVariableLength()) { 430 auto error = formatv( 431 "further constrain op {0}'s variadic operand #{1} unsupported now", 432 op.getOperationName(), argIndex); 433 PrintFatalError(loc, error); 434 } 435 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", 436 opName, operandIndex); 437 emitMatchCheck( 438 opName, 439 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 440 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 441 "'{2}'\"", 442 operand - op.operand_begin(), op.getOperationName(), 443 constraint.getSummary())); 444 } 445 } 446 447 // Capture the value 448 auto name = tree.getArgName(argIndex); 449 // `$_` is a special symbol to ignore op argument matching. 450 if (!name.empty() && name != "_") { 451 // We need to subtract the number of attributes before this operand to get 452 // the index in the operand list. 453 auto numPrevAttrs = std::count_if( 454 op.arg_begin(), op.arg_begin() + argIndex, 455 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 456 457 auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex); 458 os << formatv("{0} = {1}.getODSOperands({2});\n", 459 res->second.getVarName(name), opName, 460 argIndex - numPrevAttrs); 461 } 462 } 463 464 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 465 int argIndex, int depth) { 466 Operator &op = tree.getDialectOp(opMap); 467 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 468 const auto &attr = namedAttr->attr; 469 470 os << "{\n"; 471 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 472 "(void)tblgen_attr;\n", 473 opName, attr.getStorageType(), namedAttr->name); 474 475 // TODO: This should use getter method to avoid duplication. 476 if (attr.hasDefaultValue()) { 477 os << "if (!tblgen_attr) tblgen_attr = " 478 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 479 attr.getDefaultValue())) 480 << ";\n"; 481 } else if (attr.isOptional()) { 482 // For a missing attribute that is optional according to definition, we 483 // should just capture a mlir::Attribute() to signal the missing state. 484 // That is precisely what getAttr() returns on missing attributes. 485 } else { 486 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 487 formatv("\"expected op '{0}' to have attribute '{1}' " 488 "of type '{2}'\"", 489 op.getOperationName(), namedAttr->name, 490 attr.getStorageType())); 491 } 492 493 auto matcher = tree.getArgAsLeaf(argIndex); 494 if (!matcher.isUnspecified()) { 495 if (!matcher.isAttrMatcher()) { 496 PrintFatalError( 497 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 498 op.getOperationName(), argIndex + 1)); 499 } 500 501 // If a constraint is specified, we need to generate C++ statements to 502 // check the constraint. 503 emitMatchCheck( 504 opName, 505 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 506 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 507 "{2}\"", 508 op.getOperationName(), namedAttr->name, 509 matcher.getAsConstraint().getSummary())); 510 } 511 512 // Capture the value 513 auto name = tree.getArgName(argIndex); 514 // `$_` is a special symbol to ignore op argument matching. 515 if (!name.empty() && name != "_") { 516 os << formatv("{0} = tblgen_attr;\n", name); 517 } 518 519 os.unindent() << "}\n"; 520 } 521 522 void PatternEmitter::emitMatchCheck( 523 StringRef opName, const FmtObjectBase &matchFmt, 524 const llvm::formatv_object_base &failureFmt) { 525 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 526 } 527 528 void PatternEmitter::emitMatchCheck(StringRef opName, 529 const std::string &matchStr, 530 const std::string &failureStr) { 531 532 os << "if (!(" << matchStr << "))"; 533 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 534 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 535 << failureStr << ";\n});"; 536 } 537 538 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 539 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 540 int depth = 0; 541 emitMatch(tree, opName, depth); 542 543 for (auto &appliedConstraint : pattern.getConstraints()) { 544 auto &constraint = appliedConstraint.constraint; 545 auto &entities = appliedConstraint.entities; 546 547 auto condition = constraint.getConditionTemplate(); 548 if (isa<TypeConstraint>(constraint)) { 549 auto self = formatv("({0}.getType())", 550 symbolInfoMap.getValueAndRangeUse(entities.front())); 551 emitMatchCheck( 552 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 553 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 554 entities.front(), constraint.getSummary())); 555 556 } else if (isa<AttrConstraint>(constraint)) { 557 PrintFatalError( 558 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 559 } else { 560 // TODO: replace formatv arguments with the exact specified 561 // args. 562 if (entities.size() > 4) { 563 PrintFatalError(loc, "only support up to 4-entity constraints now"); 564 } 565 SmallVector<std::string, 4> names; 566 int i = 0; 567 for (int e = entities.size(); i < e; ++i) 568 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 569 std::string self = appliedConstraint.self; 570 if (!self.empty()) 571 self = symbolInfoMap.getValueAndRangeUse(self); 572 for (; i < 4; ++i) 573 names.push_back("<unused>"); 574 emitMatchCheck(opName, 575 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 576 names[1], names[2], names[3]), 577 formatv("\"entities '{0}' failed to satisfy constraint: " 578 "{1}\"", 579 llvm::join(entities, ", "), 580 constraint.getSummary())); 581 } 582 } 583 584 // Some of the operands could be bound to the same symbol name, we need 585 // to enforce equality constraint on those. 586 // TODO: we should be able to emit equality checks early 587 // and short circuit unnecessary work if vars are not equal. 588 for (auto symbolInfoIt = symbolInfoMap.begin(); 589 symbolInfoIt != symbolInfoMap.end();) { 590 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 591 auto startRange = range.first; 592 auto endRange = range.second; 593 594 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 595 for (++startRange; startRange != endRange; ++startRange) { 596 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 597 emitMatchCheck( 598 opName, 599 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 600 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 601 secondOperand)); 602 } 603 604 symbolInfoIt = endRange; 605 } 606 607 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 608 } 609 610 void PatternEmitter::collectOps(DagNode tree, 611 llvm::SmallPtrSetImpl<const Operator *> &ops) { 612 // Check if this tree is an operation. 613 if (tree.isOperation()) { 614 const Operator &op = tree.getDialectOp(opMap); 615 LLVM_DEBUG(llvm::dbgs() 616 << "found operation " << op.getOperationName() << '\n'); 617 ops.insert(&op); 618 } 619 620 // Recurse the arguments of the tree. 621 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 622 if (auto child = tree.getArgAsNestedDag(i)) 623 collectOps(child, ops); 624 } 625 626 void PatternEmitter::emit(StringRef rewriteName) { 627 // Get the DAG tree for the source pattern. 628 DagNode sourceTree = pattern.getSourcePattern(); 629 630 const Operator &rootOp = pattern.getSourceRootOp(); 631 auto rootName = rootOp.getOperationName(); 632 633 // Collect the set of result operations. 634 llvm::SmallPtrSet<const Operator *, 4> resultOps; 635 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 636 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 637 collectOps(pattern.getResultPattern(i), resultOps); 638 } 639 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 640 641 // Emit RewritePattern for Pattern. 642 auto locs = pattern.getLocation(); 643 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 644 make_range(locs.rbegin(), locs.rend())); 645 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 646 {0}(::mlir::MLIRContext *context) 647 : ::mlir::RewritePattern("{1}", {2}, context, {{)", 648 rewriteName, rootName, pattern.getBenefit()); 649 // Sort result operators by name. 650 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 651 resultOps.end()); 652 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 653 return lhs->getOperationName() < rhs->getOperationName(); 654 }); 655 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 656 os << '"' << op->getOperationName() << '"'; 657 }); 658 os << "}) {}\n"; 659 660 // Emit matchAndRewrite() function. 661 { 662 auto classScope = os.scope(); 663 os.reindent(R"( 664 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 665 ::mlir::PatternRewriter &rewriter) const override {)") 666 << '\n'; 667 { 668 auto functionScope = os.scope(); 669 670 // Register all symbols bound in the source pattern. 671 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 672 673 LLVM_DEBUG(llvm::dbgs() 674 << "start creating local variables for capturing matches\n"); 675 os << "// Variables for capturing values and attributes used while " 676 "creating ops\n"; 677 // Create local variables for storing the arguments and results bound 678 // to symbols. 679 for (const auto &symbolInfoPair : symbolInfoMap) { 680 const auto &symbol = symbolInfoPair.first; 681 const auto &info = symbolInfoPair.second; 682 683 os << info.getVarDecl(symbol); 684 } 685 // TODO: capture ops with consistent numbering so that it can be 686 // reused for fused loc. 687 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 688 pattern.getSourcePattern().getNumOps()); 689 LLVM_DEBUG(llvm::dbgs() 690 << "done creating local variables for capturing matches\n"); 691 692 os << "// Match\n"; 693 os << "tblgen_ops[0] = op0;\n"; 694 emitMatchLogic(sourceTree, "op0"); 695 696 os << "\n// Rewrite\n"; 697 emitRewriteLogic(); 698 699 os << "return ::mlir::success();\n"; 700 } 701 os << "};\n"; 702 } 703 os << "};\n\n"; 704 } 705 706 void PatternEmitter::emitRewriteLogic() { 707 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 708 const Operator &rootOp = pattern.getSourceRootOp(); 709 int numExpectedResults = rootOp.getNumResults(); 710 int numResultPatterns = pattern.getNumResultPatterns(); 711 712 // First register all symbols bound to ops generated in result patterns. 713 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 714 715 // Only the last N static values generated are used to replace the matched 716 // root N-result op. We need to calculate the starting index (of the results 717 // of the matched op) each result pattern is to replace. 718 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 719 // If we don't need to replace any value at all, set the replacement starting 720 // index as the number of result patterns so we skip all of them when trying 721 // to replace the matched op's results. 722 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 723 for (int i = numResultPatterns - 1; i >= 0; --i) { 724 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 725 offsets[i] = offsets[i + 1] - numValues; 726 if (offsets[i] == 0) { 727 if (replStartIndex == -1) 728 replStartIndex = i; 729 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 730 auto error = formatv( 731 "cannot use the same multi-result op '{0}' to generate both " 732 "auxiliary values and values to be used for replacing the matched op", 733 pattern.getResultPattern(i).getSymbol()); 734 PrintFatalError(loc, error); 735 } 736 } 737 738 if (offsets.front() > 0) { 739 const char error[] = "no enough values generated to replace the matched op"; 740 PrintFatalError(loc, error); 741 } 742 743 os << "auto odsLoc = rewriter.getFusedLoc({"; 744 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 745 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 746 } 747 os << "}); (void)odsLoc;\n"; 748 749 // Process auxiliary result patterns. 750 for (int i = 0; i < replStartIndex; ++i) { 751 DagNode resultTree = pattern.getResultPattern(i); 752 auto val = handleResultPattern(resultTree, offsets[i], 0); 753 // Normal op creation will be streamed to `os` by the above call; but 754 // NativeCodeCall will only be materialized to `os` if it is used. Here 755 // we are handling auxiliary patterns so we want the side effect even if 756 // NativeCodeCall is not replacing matched root op's results. 757 if (resultTree.isNativeCodeCall()) 758 os << val << ";\n"; 759 } 760 761 if (numExpectedResults == 0) { 762 assert(replStartIndex >= numResultPatterns && 763 "invalid auxiliary vs. replacement pattern division!"); 764 // No result to replace. Just erase the op. 765 os << "rewriter.eraseOp(op0);\n"; 766 } else { 767 // Process replacement result patterns. 768 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 769 for (int i = replStartIndex; i < numResultPatterns; ++i) { 770 DagNode resultTree = pattern.getResultPattern(i); 771 auto val = handleResultPattern(resultTree, offsets[i], 0); 772 os << "\n"; 773 // Resolve each symbol for all range use so that we can loop over them. 774 // We need an explicit cast to `SmallVector` to capture the cases where 775 // `{0}` resolves to an `Operation::result_range` as well as cases that 776 // are not iterable (e.g. vector that gets wrapped in additional braces by 777 // RewriterGen). 778 // TODO: Revisit the need for materializing a vector. 779 os << symbolInfoMap.getAllRangeUse( 780 val, 781 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 782 " tblgen_repl_values.push_back(v);\n}\n", 783 "\n"); 784 } 785 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 786 } 787 788 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 789 } 790 791 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 792 return std::string( 793 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 794 } 795 796 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 797 int resultIndex, int depth) { 798 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 799 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 800 LLVM_DEBUG(llvm::dbgs() << '\n'); 801 802 if (resultTree.isLocationDirective()) { 803 PrintFatalError(loc, 804 "location directive can only be used with op creation"); 805 } 806 807 if (resultTree.isNativeCodeCall()) { 808 auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); 809 symbolInfoMap.bindValue(symbol); 810 return symbol; 811 } 812 813 if (resultTree.isReplaceWithValue()) 814 return handleReplaceWithValue(resultTree).str(); 815 816 // Normal op creation. 817 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 818 if (resultTree.getSymbol().empty()) { 819 // This is an op not explicitly bound to a symbol in the rewrite rule. 820 // Register the auto-generated symbol for it. 821 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 822 } 823 return symbol; 824 } 825 826 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 827 assert(tree.isReplaceWithValue()); 828 829 if (tree.getNumArgs() != 1) { 830 PrintFatalError( 831 loc, "replaceWithValue directive must take exactly one argument"); 832 } 833 834 if (!tree.getSymbol().empty()) { 835 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 836 } 837 838 return tree.getArgName(0); 839 } 840 841 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 842 assert(tree.isLocationDirective()); 843 auto lookUpArgLoc = [this, &tree](int idx) { 844 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 845 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 846 }; 847 848 if (tree.getNumArgs() == 0) 849 llvm::PrintFatalError( 850 "At least one argument to location directive required"); 851 852 if (!tree.getSymbol().empty()) 853 PrintFatalError(loc, "cannot bind symbol to location"); 854 855 if (tree.getNumArgs() == 1) { 856 DagLeaf leaf = tree.getArgAsLeaf(0); 857 if (leaf.isStringAttr()) 858 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))", 859 leaf.getStringAttr()) 860 .str(); 861 return lookUpArgLoc(0); 862 } 863 864 std::string ret; 865 llvm::raw_string_ostream os(ret); 866 std::string strAttr; 867 os << "rewriter.getFusedLoc({"; 868 bool first = true; 869 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 870 DagLeaf leaf = tree.getArgAsLeaf(i); 871 // Handle the optional string value. 872 if (leaf.isStringAttr()) { 873 if (!strAttr.empty()) 874 llvm::PrintFatalError("Only one string attribute may be specified"); 875 strAttr = leaf.getStringAttr(); 876 continue; 877 } 878 os << (first ? "" : ", ") << lookUpArgLoc(i); 879 first = false; 880 } 881 os << "}"; 882 if (!strAttr.empty()) { 883 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 884 } 885 os << ")"; 886 return os.str(); 887 } 888 889 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 890 StringRef patArgName) { 891 if (leaf.isStringAttr()) 892 PrintFatalError(loc, "raw string not supported as argument"); 893 if (leaf.isConstantAttr()) { 894 auto constAttr = leaf.getAsConstantAttr(); 895 return handleConstantAttr(constAttr.getAttribute(), 896 constAttr.getConstantValue()); 897 } 898 if (leaf.isEnumAttrCase()) { 899 auto enumCase = leaf.getAsEnumAttrCase(); 900 if (enumCase.isStrCase()) 901 return handleConstantAttr(enumCase, enumCase.getSymbol()); 902 // This is an enum case backed by an IntegerAttr. We need to get its value 903 // to build the constant. 904 std::string val = std::to_string(enumCase.getValue()); 905 return handleConstantAttr(enumCase, val); 906 } 907 908 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 909 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 910 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 911 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 912 << "' (via symbol ref)\n"); 913 return argName; 914 } 915 if (leaf.isNativeCodeCall()) { 916 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 917 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 918 << "' (via NativeCodeCall)\n"); 919 return std::string(repl); 920 } 921 PrintFatalError(loc, "unhandled case when rewriting op"); 922 } 923 924 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 925 int depth) { 926 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 927 LLVM_DEBUG(tree.print(llvm::dbgs())); 928 LLVM_DEBUG(llvm::dbgs() << '\n'); 929 930 auto fmt = tree.getNativeCodeTemplate(); 931 932 SmallVector<std::string, 16> attrs; 933 934 bool hasLocationDirective; 935 std::string locToUse; 936 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 937 938 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 939 if (tree.isNestedDagArg(i)) { 940 attrs.push_back( 941 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); 942 } else { 943 attrs.push_back( 944 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))); 945 } 946 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 947 << " replacement: " << attrs[i] << "\n"); 948 } 949 950 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs); 951 if (!tree.getSymbol().empty()) { 952 os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol); 953 symbol = tree.getSymbol().str(); 954 } 955 956 return symbol; 957 } 958 959 int PatternEmitter::getNodeValueCount(DagNode node) { 960 if (node.isOperation()) { 961 // If the op is bound to a symbol in the rewrite rule, query its result 962 // count from the symbol info map. 963 auto symbol = node.getSymbol(); 964 if (!symbol.empty()) { 965 return symbolInfoMap.getStaticValueCount(symbol); 966 } 967 // Otherwise this is an unbound op; we will use all its results. 968 return pattern.getDialectOp(node).getNumResults(); 969 } 970 // TODO: This considers all NativeCodeCall as returning one 971 // value. Enhance if multi-value ones are needed. 972 return 1; 973 } 974 975 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 976 auto numPatArgs = tree.getNumArgs(); 977 978 if (numPatArgs != 0) { 979 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 980 if (lastArg.isLocationDirective()) { 981 return std::make_pair(true, handleLocationDirective(lastArg)); 982 } 983 } 984 985 // If no explicit location is given, use the default, all fused, location. 986 return std::make_pair(false, "odsLoc"); 987 } 988 989 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 990 int depth) { 991 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 992 LLVM_DEBUG(tree.print(llvm::dbgs())); 993 LLVM_DEBUG(llvm::dbgs() << '\n'); 994 995 Operator &resultOp = tree.getDialectOp(opMap); 996 auto numOpArgs = resultOp.getNumArgs(); 997 auto numPatArgs = tree.getNumArgs(); 998 999 bool hasLocationDirective; 1000 std::string locToUse; 1001 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 1002 1003 auto inPattern = numPatArgs - hasLocationDirective; 1004 if (numOpArgs != inPattern) { 1005 PrintFatalError(loc, 1006 formatv("resultant op '{0}' argument number mismatch: " 1007 "{1} in pattern vs. {2} in definition", 1008 resultOp.getOperationName(), inPattern, numOpArgs)); 1009 } 1010 1011 // A map to collect all nested DAG child nodes' names, with operand index as 1012 // the key. This includes both bound and unbound child nodes. 1013 ChildNodeIndexNameMap childNodeNames; 1014 1015 // First go through all the child nodes who are nested DAG constructs to 1016 // create ops for them and remember the symbol names for them, so that we can 1017 // use the results in the current node. This happens in a recursive manner. 1018 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 1019 if (auto child = tree.getArgAsNestedDag(i)) 1020 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1021 } 1022 1023 // The name of the local variable holding this op. 1024 std::string valuePackName; 1025 // The symbol for holding the result of this pattern. Note that the result of 1026 // this pattern is not necessarily the same as the variable created by this 1027 // pattern because we can use `__N` suffix to refer only a specific result if 1028 // the generated op is a multi-result op. 1029 std::string resultValue; 1030 if (tree.getSymbol().empty()) { 1031 // No symbol is explicitly bound to this op in the pattern. Generate a 1032 // unique name. 1033 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1034 } else { 1035 resultValue = std::string(tree.getSymbol()); 1036 // Strip the index to get the name for the value pack and use it to name the 1037 // local variable for the op. 1038 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1039 } 1040 1041 // Create the local variable for this op. 1042 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1043 valuePackName); 1044 1045 // Right now ODS don't have general type inference support. Except a few 1046 // special cases listed below, DRR needs to supply types for all results 1047 // when building an op. 1048 bool isSameOperandsAndResultType = 1049 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1050 bool useFirstAttr = 1051 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1052 1053 if (isSameOperandsAndResultType || useFirstAttr) { 1054 // We know how to deduce the result type for ops with these traits and we've 1055 // generated builders taking aggregate parameters. Use those builders to 1056 // create the ops. 1057 1058 // First prepare local variables for op arguments used in builder call. 1059 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1060 1061 // Then create the op. 1062 os.scope("", "\n}\n").os << formatv( 1063 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1064 valuePackName, resultOp.getQualCppClassName(), locToUse); 1065 return resultValue; 1066 } 1067 1068 bool usePartialResults = valuePackName != resultValue; 1069 1070 if (usePartialResults || depth > 0 || resultIndex < 0) { 1071 // For these cases (broadcastable ops, op results used both as auxiliary 1072 // values and replacement values, ops in nested patterns, auxiliary ops), we 1073 // still need to supply the result types when building the op. But because 1074 // we don't generate a builder automatically with ODS for them, it's the 1075 // developer's responsibility to make sure such a builder (with result type 1076 // deduction ability) exists. We go through the separate-parameter builder 1077 // here given that it's easier for developers to write compared to 1078 // aggregate-parameter builders. 1079 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1080 1081 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1082 resultOp.getQualCppClassName(), locToUse); 1083 supplyValuesForOpArgs(tree, childNodeNames, depth); 1084 os << "\n );\n}\n"; 1085 return resultValue; 1086 } 1087 1088 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 1089 // generated from the source pattern root op. Then we can use the source 1090 // pattern's value types to determine the value type of the generated op 1091 // here. 1092 1093 // First prepare local variables for op arguments used in builder call. 1094 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1095 1096 // Then prepare the result types. We need to specify the types for all 1097 // results. 1098 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 1099 "(void)tblgen_types;\n"); 1100 int numResults = resultOp.getNumResults(); 1101 if (numResults != 0) { 1102 for (int i = 0; i < numResults; ++i) 1103 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1104 " tblgen_types.push_back(v.getType());\n}\n", 1105 resultIndex + i); 1106 } 1107 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1108 "tblgen_values, tblgen_attrs);\n", 1109 valuePackName, resultOp.getQualCppClassName(), locToUse); 1110 os.unindent() << "}\n"; 1111 return resultValue; 1112 } 1113 1114 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1115 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1116 Operator &resultOp = node.getDialectOp(opMap); 1117 1118 // Now prepare operands used for building this op: 1119 // * If the operand is non-variadic, we create a `Value` local variable. 1120 // * If the operand is variadic, we create a `SmallVector<Value>` local 1121 // variable. 1122 1123 int valueIndex = 0; // An index for uniquing local variable names. 1124 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1125 const auto *operand = 1126 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 1127 // We do not need special handling for attributes. 1128 if (!operand) 1129 continue; 1130 1131 raw_indented_ostream::DelimitedScope scope(os); 1132 std::string varName; 1133 if (operand->isVariadic()) { 1134 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1135 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); 1136 std::string range; 1137 if (node.isNestedDagArg(argIndex)) { 1138 range = childNodeNames[argIndex]; 1139 } else { 1140 range = std::string(node.getArgName(argIndex)); 1141 } 1142 // Resolve the symbol for all range use so that we have a uniform way of 1143 // capturing the values. 1144 range = symbolInfoMap.getValueAndRangeUse(range); 1145 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1146 varName); 1147 } else { 1148 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1149 os << formatv("::mlir::Value {0} = ", varName); 1150 if (node.isNestedDagArg(argIndex)) { 1151 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1152 } else { 1153 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1154 auto symbol = 1155 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1156 if (leaf.isNativeCodeCall()) { 1157 os << std::string( 1158 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1159 } else { 1160 os << symbol; 1161 } 1162 } 1163 os << ";\n"; 1164 } 1165 1166 // Update to use the newly created local variable for building the op later. 1167 childNodeNames[argIndex] = varName; 1168 } 1169 } 1170 1171 void PatternEmitter::supplyValuesForOpArgs( 1172 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1173 Operator &resultOp = node.getDialectOp(opMap); 1174 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1175 argIndex != numOpArgs; ++argIndex) { 1176 // Start each argument on its own line. 1177 os << ",\n "; 1178 1179 Argument opArg = resultOp.getArg(argIndex); 1180 // Handle the case of operand first. 1181 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1182 if (!operand->name.empty()) 1183 os << "/*" << operand->name << "=*/"; 1184 os << childNodeNames.lookup(argIndex); 1185 continue; 1186 } 1187 1188 // The argument in the op definition. 1189 auto opArgName = resultOp.getArgName(argIndex); 1190 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1191 if (!subTree.isNativeCodeCall()) 1192 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1193 "for creating attribute"); 1194 os << formatv("/*{0}=*/{1}", opArgName, 1195 handleReplaceWithNativeCodeCall(subTree, depth)); 1196 } else { 1197 auto leaf = node.getArgAsLeaf(argIndex); 1198 // The argument in the result DAG pattern. 1199 auto patArgName = node.getArgName(argIndex); 1200 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1201 // TODO: Refactor out into map to avoid recomputing these. 1202 if (!opArg.is<NamedAttribute *>()) 1203 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1204 if (!patArgName.empty()) 1205 os << "/*" << patArgName << "=*/"; 1206 } else { 1207 os << "/*" << opArgName << "=*/"; 1208 } 1209 os << handleOpArgument(leaf, patArgName); 1210 } 1211 } 1212 } 1213 1214 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1215 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1216 Operator &resultOp = node.getDialectOp(opMap); 1217 1218 auto scope = os.scope(); 1219 os << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1220 "tblgen_values; (void)tblgen_values;\n"); 1221 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1222 "tblgen_attrs; (void)tblgen_attrs;\n"); 1223 1224 const char *addAttrCmd = 1225 "if (auto tmpAttr = {1}) {\n" 1226 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " 1227 "tmpAttr);\n}\n"; 1228 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1229 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1230 // The argument in the op definition. 1231 auto opArgName = resultOp.getArgName(argIndex); 1232 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1233 if (!subTree.isNativeCodeCall()) 1234 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1235 "for creating attribute"); 1236 os << formatv(addAttrCmd, opArgName, 1237 handleReplaceWithNativeCodeCall(subTree, depth + 1)); 1238 } else { 1239 auto leaf = node.getArgAsLeaf(argIndex); 1240 // The argument in the result DAG pattern. 1241 auto patArgName = node.getArgName(argIndex); 1242 os << formatv(addAttrCmd, opArgName, 1243 handleOpArgument(leaf, patArgName)); 1244 } 1245 continue; 1246 } 1247 1248 const auto *operand = 1249 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1250 std::string varName; 1251 if (operand->isVariadic()) { 1252 std::string range; 1253 if (node.isNestedDagArg(argIndex)) { 1254 range = childNodeNames.lookup(argIndex); 1255 } else { 1256 range = std::string(node.getArgName(argIndex)); 1257 } 1258 // Resolve the symbol for all range use so that we have a uniform way of 1259 // capturing the values. 1260 range = symbolInfoMap.getValueAndRangeUse(range); 1261 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1262 range); 1263 } else { 1264 os << formatv("tblgen_values.push_back("); 1265 if (node.isNestedDagArg(argIndex)) { 1266 os << symbolInfoMap.getValueAndRangeUse( 1267 childNodeNames.lookup(argIndex)); 1268 } else { 1269 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1270 auto symbol = 1271 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1272 if (leaf.isNativeCodeCall()) { 1273 os << std::string( 1274 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1275 } else { 1276 os << symbol; 1277 } 1278 } 1279 os << ");\n"; 1280 } 1281 } 1282 } 1283 1284 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1285 emitSourceFileHeader("Rewriters", os); 1286 1287 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1288 auto numPatterns = patterns.size(); 1289 1290 // We put the map here because it can be shared among multiple patterns. 1291 RecordOperatorMap recordOpMap; 1292 1293 std::vector<std::string> rewriterNames; 1294 rewriterNames.reserve(numPatterns); 1295 1296 std::string baseRewriterName = "GeneratedConvert"; 1297 int rewriterIndex = 0; 1298 1299 for (Record *p : patterns) { 1300 std::string name; 1301 if (p->isAnonymous()) { 1302 // If no name is provided, ensure unique rewriter names simply by 1303 // appending unique suffix. 1304 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1305 } else { 1306 name = std::string(p->getName()); 1307 } 1308 LLVM_DEBUG(llvm::dbgs() 1309 << "=== start generating pattern '" << name << "' ===\n"); 1310 PatternEmitter(p, &recordOpMap, os).emit(name); 1311 LLVM_DEBUG(llvm::dbgs() 1312 << "=== done generating pattern '" << name << "' ===\n"); 1313 rewriterNames.push_back(std::move(name)); 1314 } 1315 1316 // Emit function to add the generated matchers to the pattern list. 1317 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" 1318 "::mlir::RewritePatternSet &patterns) {\n"; 1319 for (const auto &name : rewriterNames) { 1320 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 1321 } 1322 os << "}\n"; 1323 } 1324 1325 static mlir::GenRegistration 1326 genRewriters("gen-rewriters", "Generate pattern rewriters", 1327 [](const RecordKeeper &records, raw_ostream &os) { 1328 emitRewriters(records, os); 1329 return false; 1330 }); 1331