1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Support/IndentedOstream.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Pattern.h" 19 #include "mlir/TableGen/Predicate.h" 20 #include "mlir/TableGen/Type.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace mlir; 34 using namespace mlir::tblgen; 35 36 using llvm::formatv; 37 using llvm::Record; 38 using llvm::RecordKeeper; 39 40 #define DEBUG_TYPE "mlir-tblgen-rewritergen" 41 42 namespace llvm { 43 template <> 44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 46 raw_ostream &os, StringRef style) { 47 os << v.first << ":" << v.second; 48 } 49 }; 50 } // end namespace llvm 51 52 //===----------------------------------------------------------------------===// 53 // PatternEmitter 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 class PatternEmitter { 58 public: 59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 60 61 // Emits the mlir::RewritePattern struct named `rewriteName`. 62 void emit(StringRef rewriteName); 63 64 private: 65 // Emits the code for matching ops. 66 void emitMatchLogic(DagNode tree, StringRef opName); 67 68 // Emits the code for rewriting ops. 69 void emitRewriteLogic(); 70 71 //===--------------------------------------------------------------------===// 72 // Match utilities 73 //===--------------------------------------------------------------------===// 74 75 // Emits C++ statements for matching the DAG structure. 76 void emitMatch(DagNode tree, StringRef name, int depth); 77 78 // Emits C++ statements for matching using a native code call. 79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 80 81 // Emits C++ statements for matching the op constrained by the given DAG 82 // `tree` returning the op's variable name. 83 void emitOpMatch(DagNode tree, StringRef opName, int depth); 84 85 // Emits C++ statements for matching the `argIndex`-th argument of the given 86 // DAG `tree` as an operand. 87 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, 88 int depth); 89 90 // Emits C++ statements for matching the `argIndex`-th argument of the given 91 // DAG `tree` as an attribute. 92 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 93 int depth); 94 95 // Emits C++ for checking a match with a corresponding match failure 96 // diagnostic. 97 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 98 const llvm::formatv_object_base &failureFmt); 99 100 // Emits C++ for checking a match with a corresponding match failure 101 // diagnostics. 102 void emitMatchCheck(StringRef opName, const std::string &matchStr, 103 const std::string &failureStr); 104 105 //===--------------------------------------------------------------------===// 106 // Rewrite utilities 107 //===--------------------------------------------------------------------===// 108 109 // The entry point for handling a result pattern rooted at `resultTree`. This 110 // method dispatches to concrete handlers according to `resultTree`'s kind and 111 // returns a symbol representing the whole value pack. Callers are expected to 112 // further resolve the symbol according to the specific use case. 113 // 114 // `depth` is the nesting level of `resultTree`; 0 means top-level result 115 // pattern. For top-level result pattern, `resultIndex` indicates which result 116 // of the matched root op this pattern is intended to replace, which can be 117 // used to deduce the result type of the op generated from this result 118 // pattern. 119 std::string handleResultPattern(DagNode resultTree, int resultIndex, 120 int depth); 121 122 // Emits the C++ statement to replace the matched DAG with a value built via 123 // calling native C++ code. 124 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 125 126 // Returns the symbol of the old value serving as the replacement. 127 StringRef handleReplaceWithValue(DagNode tree); 128 129 // Returns the location value to use. 130 std::pair<bool, std::string> getLocation(DagNode tree); 131 132 // Returns the location value to use. 133 std::string handleLocationDirective(DagNode tree); 134 135 // Emits the C++ statement to build a new op out of the given DAG `tree` and 136 // returns the variable name that this op is assigned to. If the root op in 137 // DAG `tree` has a specified name, the created op will be assigned to a 138 // variable of the given name. Otherwise, a unique name will be used as the 139 // result value name. 140 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 141 142 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 143 144 // Emits a local variable for each value and attribute to be used for creating 145 // an op. 146 void createSeparateLocalVarsForOpArgs(DagNode node, 147 ChildNodeIndexNameMap &childNodeNames); 148 149 // Emits the concrete arguments used to call an op's builder. 150 void supplyValuesForOpArgs(DagNode node, 151 const ChildNodeIndexNameMap &childNodeNames, 152 int depth); 153 154 // Emits the local variables for holding all values as a whole and all named 155 // attributes as a whole to be used for creating an op. 156 void createAggregateLocalVarsForOpArgs( 157 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 158 159 // Returns the C++ expression to construct a constant attribute of the given 160 // `value` for the given attribute kind `attr`. 161 std::string handleConstantAttr(Attribute attr, StringRef value); 162 163 // Returns the C++ expression to build an argument from the given DAG `leaf`. 164 // `patArgName` is used to bound the argument to the source pattern. 165 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 166 167 //===--------------------------------------------------------------------===// 168 // General utilities 169 //===--------------------------------------------------------------------===// 170 171 // Collects all of the operations within the given dag tree. 172 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 173 174 // Returns a unique symbol for a local variable of the given `op`. 175 std::string getUniqueSymbol(const Operator *op); 176 177 //===--------------------------------------------------------------------===// 178 // Symbol utilities 179 //===--------------------------------------------------------------------===// 180 181 // Returns how many static values the given DAG `node` correspond to. 182 int getNodeValueCount(DagNode node); 183 184 private: 185 // Pattern instantiation location followed by the location of multiclass 186 // prototypes used. This is intended to be used as a whole to 187 // PrintFatalError() on errors. 188 ArrayRef<llvm::SMLoc> loc; 189 190 // Op's TableGen Record to wrapper object. 191 RecordOperatorMap *opMap; 192 193 // Handy wrapper for pattern being emitted. 194 Pattern pattern; 195 196 // Map for all bound symbols' info. 197 SymbolInfoMap symbolInfoMap; 198 199 // The next unused ID for newly created values. 200 unsigned nextValueId; 201 202 raw_indented_ostream os; 203 204 // Format contexts containing placeholder substitutions. 205 FmtContext fmtCtx; 206 207 // Number of op processed. 208 int opCounter = 0; 209 }; 210 } // end anonymous namespace 211 212 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 213 raw_ostream &os) 214 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 215 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 216 fmtCtx.withBuilder("rewriter"); 217 } 218 219 std::string PatternEmitter::handleConstantAttr(Attribute attr, 220 StringRef value) { 221 if (!attr.isConstBuildable()) 222 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 223 " does not have the 'constBuilderCall' field"); 224 225 // TODO: Verify the constants here 226 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 227 } 228 229 // Helper function to match patterns. 230 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 231 if (tree.isNativeCodeCall()) { 232 emitNativeCodeMatch(tree, name, depth); 233 return; 234 } 235 236 if (tree.isOperation()) { 237 emitOpMatch(tree, name, depth); 238 return; 239 } 240 241 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 242 } 243 244 // Helper function to match patterns. 245 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 246 int depth) { 247 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 248 LLVM_DEBUG(tree.print(llvm::dbgs())); 249 LLVM_DEBUG(llvm::dbgs() << '\n'); 250 251 // TODO(suderman): iterate through arguments, determine their types, output 252 // names. 253 SmallVector<std::string, 8> capture(8); 254 if (tree.getNumArgs() > 8) { 255 PrintFatalError(loc, 256 "unsupported NativeCodeCall matcher argument numbers: " + 257 Twine(tree.getNumArgs())); 258 } 259 260 raw_indented_ostream::DelimitedScope scope(os); 261 262 os << "if(!" << opName << ") return failure();\n"; 263 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 264 std::string argName = formatv("arg{0}_{1}", depth, i); 265 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 266 os << "Value " << argName << ";\n"; 267 } else { 268 auto leaf = tree.getArgAsLeaf(i); 269 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 270 os << "Attribute " << argName << ";\n"; 271 } else if (leaf.isOperandMatcher()) { 272 os << "Operation " << argName << ";\n"; 273 } 274 } 275 276 capture[i] = std::move(argName); 277 } 278 279 bool hasLocationDirective; 280 std::string locToUse; 281 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 282 283 auto fmt = tree.getNativeCodeTemplate(); 284 auto nativeCodeCall = std::string(tgfmt( 285 fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1], 286 capture[2], capture[3], capture[4], capture[5], capture[6], capture[7])); 287 288 os << "if (failed(" << nativeCodeCall << ")) return failure();\n"; 289 290 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 291 auto name = tree.getArgName(i); 292 if (!name.empty() && name != "_") { 293 os << formatv("{0} = {1};\n", name, capture[i]); 294 } 295 } 296 297 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 298 std::string argName = capture[i]; 299 300 // Handle nested DAG construct first 301 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 302 PrintFatalError( 303 loc, formatv("Matching nested tree in NativeCodecall not support for " 304 "{0} as arg {1}", 305 argName, i)); 306 } 307 308 DagLeaf leaf = tree.getArgAsLeaf(i); 309 auto constraint = leaf.getAsConstraint(); 310 311 auto self = formatv("{0}", argName); 312 emitMatchCheck( 313 opName, 314 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 315 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 316 "constraint: " 317 "'{2}'\"", 318 i, tree.getNativeCodeTemplate(), constraint.getSummary())); 319 } 320 321 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 322 } 323 324 // Helper function to match patterns. 325 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 326 Operator &op = tree.getDialectOp(opMap); 327 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 328 << op.getOperationName() << "' at depth " << depth 329 << '\n'); 330 331 std::string castedName = formatv("castedOp{0}", depth); 332 os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " 333 "(void){0};\n", 334 castedName, opName, op.getQualCppClassName()); 335 // Skip the operand matching at depth 0 as the pattern rewriter already does. 336 if (depth != 0) { 337 // Skip if there is no defining operation (e.g., arguments to function). 338 os << formatv("if (!{0}) return failure();\n", castedName); 339 } 340 if (tree.getNumArgs() != op.getNumArgs()) { 341 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 342 "pattern vs. {2} in definition", 343 op.getOperationName(), tree.getNumArgs(), 344 op.getNumArgs())); 345 } 346 347 // If the operand's name is set, set to that variable. 348 auto name = tree.getSymbol(); 349 if (!name.empty()) 350 os << formatv("{0} = {1};\n", name, castedName); 351 352 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) { 353 auto opArg = op.getArg(i); 354 std::string argName = formatv("op{0}", depth + 1); 355 356 // Handle nested DAG construct first 357 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 358 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 359 if (operand->isVariableLength()) { 360 auto error = formatv("use nested DAG construct to match op {0}'s " 361 "variadic operand #{1} unsupported now", 362 op.getOperationName(), i); 363 PrintFatalError(loc, error); 364 } 365 } 366 os << "{\n"; 367 368 // Attributes don't count for getODSOperands. 369 os.indent() << formatv( 370 "auto *{0} = " 371 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 372 argName, castedName, nextOperand++); 373 emitMatch(argTree, argName, depth + 1); 374 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); 375 os.unindent() << "}\n"; 376 continue; 377 } 378 379 // Next handle DAG leaf: operand or attribute 380 if (opArg.is<NamedTypeConstraint *>()) { 381 // emitOperandMatch's argument indexing counts attributes. 382 emitOperandMatch(tree, castedName, i, depth); 383 ++nextOperand; 384 } else if (opArg.is<NamedAttribute *>()) { 385 emitAttributeMatch(tree, opName, i, depth); 386 } else { 387 PrintFatalError(loc, "unhandled case when matching op"); 388 } 389 } 390 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 391 << op.getOperationName() << "' at depth " << depth 392 << '\n'); 393 } 394 395 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 396 int argIndex, int depth) { 397 Operator &op = tree.getDialectOp(opMap); 398 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 399 auto matcher = tree.getArgAsLeaf(argIndex); 400 401 // If a constraint is specified, we need to generate C++ statements to 402 // check the constraint. 403 if (!matcher.isUnspecified()) { 404 if (!matcher.isOperandMatcher()) { 405 PrintFatalError( 406 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 407 op.getOperationName(), argIndex + 1)); 408 } 409 410 // Only need to verify if the matcher's type is different from the one 411 // of op definition. 412 Constraint constraint = matcher.getAsConstraint(); 413 if (operand->constraint != constraint) { 414 if (operand->isVariableLength()) { 415 auto error = formatv( 416 "further constrain op {0}'s variadic operand #{1} unsupported now", 417 op.getOperationName(), argIndex); 418 PrintFatalError(loc, error); 419 } 420 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", 421 opName, argIndex); 422 emitMatchCheck( 423 opName, 424 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 425 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 426 "'{2}'\"", 427 operand - op.operand_begin(), op.getOperationName(), 428 constraint.getSummary())); 429 } 430 } 431 432 // Capture the value 433 auto name = tree.getArgName(argIndex); 434 // `$_` is a special symbol to ignore op argument matching. 435 if (!name.empty() && name != "_") { 436 // We need to subtract the number of attributes before this operand to get 437 // the index in the operand list. 438 auto numPrevAttrs = std::count_if( 439 op.arg_begin(), op.arg_begin() + argIndex, 440 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 441 442 auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); 443 os << formatv("{0} = {1}.getODSOperands({2});\n", 444 res->second.getVarName(name), opName, 445 argIndex - numPrevAttrs); 446 } 447 } 448 449 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 450 int argIndex, int depth) { 451 Operator &op = tree.getDialectOp(opMap); 452 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 453 const auto &attr = namedAttr->attr; 454 455 os << "{\n"; 456 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 457 "(void)tblgen_attr;\n", 458 opName, attr.getStorageType(), namedAttr->name); 459 460 // TODO: This should use getter method to avoid duplication. 461 if (attr.hasDefaultValue()) { 462 os << "if (!tblgen_attr) tblgen_attr = " 463 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 464 attr.getDefaultValue())) 465 << ";\n"; 466 } else if (attr.isOptional()) { 467 // For a missing attribute that is optional according to definition, we 468 // should just capture a mlir::Attribute() to signal the missing state. 469 // That is precisely what getAttr() returns on missing attributes. 470 } else { 471 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 472 formatv("\"expected op '{0}' to have attribute '{1}' " 473 "of type '{2}'\"", 474 op.getOperationName(), namedAttr->name, 475 attr.getStorageType())); 476 } 477 478 auto matcher = tree.getArgAsLeaf(argIndex); 479 if (!matcher.isUnspecified()) { 480 if (!matcher.isAttrMatcher()) { 481 PrintFatalError( 482 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 483 op.getOperationName(), argIndex + 1)); 484 } 485 486 // If a constraint is specified, we need to generate C++ statements to 487 // check the constraint. 488 emitMatchCheck( 489 opName, 490 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 491 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 492 "{2}\"", 493 op.getOperationName(), namedAttr->name, 494 matcher.getAsConstraint().getSummary())); 495 } 496 497 // Capture the value 498 auto name = tree.getArgName(argIndex); 499 // `$_` is a special symbol to ignore op argument matching. 500 if (!name.empty() && name != "_") { 501 os << formatv("{0} = tblgen_attr;\n", name); 502 } 503 504 os.unindent() << "}\n"; 505 } 506 507 void PatternEmitter::emitMatchCheck( 508 StringRef opName, const FmtObjectBase &matchFmt, 509 const llvm::formatv_object_base &failureFmt) { 510 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 511 } 512 513 void PatternEmitter::emitMatchCheck(StringRef opName, 514 const std::string &matchStr, 515 const std::string &failureStr) { 516 517 os << "if (!(" << matchStr << "))"; 518 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 519 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 520 << failureStr << ";\n});"; 521 } 522 523 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 524 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 525 int depth = 0; 526 emitMatch(tree, opName, depth); 527 528 for (auto &appliedConstraint : pattern.getConstraints()) { 529 auto &constraint = appliedConstraint.constraint; 530 auto &entities = appliedConstraint.entities; 531 532 auto condition = constraint.getConditionTemplate(); 533 if (isa<TypeConstraint>(constraint)) { 534 auto self = formatv("({0}.getType())", 535 symbolInfoMap.getValueAndRangeUse(entities.front())); 536 emitMatchCheck( 537 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 538 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 539 entities.front(), constraint.getSummary())); 540 541 } else if (isa<AttrConstraint>(constraint)) { 542 PrintFatalError( 543 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 544 } else { 545 // TODO: replace formatv arguments with the exact specified 546 // args. 547 if (entities.size() > 4) { 548 PrintFatalError(loc, "only support up to 4-entity constraints now"); 549 } 550 SmallVector<std::string, 4> names; 551 int i = 0; 552 for (int e = entities.size(); i < e; ++i) 553 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 554 std::string self = appliedConstraint.self; 555 if (!self.empty()) 556 self = symbolInfoMap.getValueAndRangeUse(self); 557 for (; i < 4; ++i) 558 names.push_back("<unused>"); 559 emitMatchCheck(opName, 560 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 561 names[1], names[2], names[3]), 562 formatv("\"entities '{0}' failed to satisfy constraint: " 563 "{1}\"", 564 llvm::join(entities, ", "), 565 constraint.getSummary())); 566 } 567 } 568 569 // Some of the operands could be bound to the same symbol name, we need 570 // to enforce equality constraint on those. 571 // TODO: we should be able to emit equality checks early 572 // and short circuit unnecessary work if vars are not equal. 573 for (auto symbolInfoIt = symbolInfoMap.begin(); 574 symbolInfoIt != symbolInfoMap.end();) { 575 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 576 auto startRange = range.first; 577 auto endRange = range.second; 578 579 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 580 for (++startRange; startRange != endRange; ++startRange) { 581 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 582 emitMatchCheck( 583 opName, 584 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 585 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 586 secondOperand)); 587 } 588 589 symbolInfoIt = endRange; 590 } 591 592 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 593 } 594 595 void PatternEmitter::collectOps(DagNode tree, 596 llvm::SmallPtrSetImpl<const Operator *> &ops) { 597 // Check if this tree is an operation. 598 if (tree.isOperation()) { 599 const Operator &op = tree.getDialectOp(opMap); 600 LLVM_DEBUG(llvm::dbgs() 601 << "found operation " << op.getOperationName() << '\n'); 602 ops.insert(&op); 603 } 604 605 // Recurse the arguments of the tree. 606 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 607 if (auto child = tree.getArgAsNestedDag(i)) 608 collectOps(child, ops); 609 } 610 611 void PatternEmitter::emit(StringRef rewriteName) { 612 // Get the DAG tree for the source pattern. 613 DagNode sourceTree = pattern.getSourcePattern(); 614 615 const Operator &rootOp = pattern.getSourceRootOp(); 616 auto rootName = rootOp.getOperationName(); 617 618 // Collect the set of result operations. 619 llvm::SmallPtrSet<const Operator *, 4> resultOps; 620 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 621 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 622 collectOps(pattern.getResultPattern(i), resultOps); 623 } 624 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 625 626 // Emit RewritePattern for Pattern. 627 auto locs = pattern.getLocation(); 628 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 629 make_range(locs.rbegin(), locs.rend())); 630 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 631 {0}(::mlir::MLIRContext *context) 632 : ::mlir::RewritePattern("{1}", {{)", 633 rewriteName, rootName); 634 // Sort result operators by name. 635 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 636 resultOps.end()); 637 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 638 return lhs->getOperationName() < rhs->getOperationName(); 639 }); 640 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 641 os << '"' << op->getOperationName() << '"'; 642 }); 643 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; 644 645 // Emit matchAndRewrite() function. 646 { 647 auto classScope = os.scope(); 648 os.reindent(R"( 649 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 650 ::mlir::PatternRewriter &rewriter) const override {)") 651 << '\n'; 652 { 653 auto functionScope = os.scope(); 654 655 // Register all symbols bound in the source pattern. 656 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 657 658 LLVM_DEBUG(llvm::dbgs() 659 << "start creating local variables for capturing matches\n"); 660 os << "// Variables for capturing values and attributes used while " 661 "creating ops\n"; 662 // Create local variables for storing the arguments and results bound 663 // to symbols. 664 for (const auto &symbolInfoPair : symbolInfoMap) { 665 const auto &symbol = symbolInfoPair.first; 666 const auto &info = symbolInfoPair.second; 667 668 os << info.getVarDecl(symbol); 669 } 670 // TODO: capture ops with consistent numbering so that it can be 671 // reused for fused loc. 672 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 673 pattern.getSourcePattern().getNumOps()); 674 LLVM_DEBUG(llvm::dbgs() 675 << "done creating local variables for capturing matches\n"); 676 677 os << "// Match\n"; 678 os << "tblgen_ops[0] = op0;\n"; 679 emitMatchLogic(sourceTree, "op0"); 680 681 os << "\n// Rewrite\n"; 682 emitRewriteLogic(); 683 684 os << "return ::mlir::success();\n"; 685 } 686 os << "};\n"; 687 } 688 os << "};\n\n"; 689 } 690 691 void PatternEmitter::emitRewriteLogic() { 692 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 693 const Operator &rootOp = pattern.getSourceRootOp(); 694 int numExpectedResults = rootOp.getNumResults(); 695 int numResultPatterns = pattern.getNumResultPatterns(); 696 697 // First register all symbols bound to ops generated in result patterns. 698 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 699 700 // Only the last N static values generated are used to replace the matched 701 // root N-result op. We need to calculate the starting index (of the results 702 // of the matched op) each result pattern is to replace. 703 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 704 // If we don't need to replace any value at all, set the replacement starting 705 // index as the number of result patterns so we skip all of them when trying 706 // to replace the matched op's results. 707 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 708 for (int i = numResultPatterns - 1; i >= 0; --i) { 709 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 710 offsets[i] = offsets[i + 1] - numValues; 711 if (offsets[i] == 0) { 712 if (replStartIndex == -1) 713 replStartIndex = i; 714 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 715 auto error = formatv( 716 "cannot use the same multi-result op '{0}' to generate both " 717 "auxiliary values and values to be used for replacing the matched op", 718 pattern.getResultPattern(i).getSymbol()); 719 PrintFatalError(loc, error); 720 } 721 } 722 723 if (offsets.front() > 0) { 724 const char error[] = "no enough values generated to replace the matched op"; 725 PrintFatalError(loc, error); 726 } 727 728 os << "auto odsLoc = rewriter.getFusedLoc({"; 729 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 730 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 731 } 732 os << "}); (void)odsLoc;\n"; 733 734 // Process auxiliary result patterns. 735 for (int i = 0; i < replStartIndex; ++i) { 736 DagNode resultTree = pattern.getResultPattern(i); 737 auto val = handleResultPattern(resultTree, offsets[i], 0); 738 // Normal op creation will be streamed to `os` by the above call; but 739 // NativeCodeCall will only be materialized to `os` if it is used. Here 740 // we are handling auxiliary patterns so we want the side effect even if 741 // NativeCodeCall is not replacing matched root op's results. 742 if (resultTree.isNativeCodeCall()) 743 os << val << ";\n"; 744 } 745 746 if (numExpectedResults == 0) { 747 assert(replStartIndex >= numResultPatterns && 748 "invalid auxiliary vs. replacement pattern division!"); 749 // No result to replace. Just erase the op. 750 os << "rewriter.eraseOp(op0);\n"; 751 } else { 752 // Process replacement result patterns. 753 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 754 for (int i = replStartIndex; i < numResultPatterns; ++i) { 755 DagNode resultTree = pattern.getResultPattern(i); 756 auto val = handleResultPattern(resultTree, offsets[i], 0); 757 os << "\n"; 758 // Resolve each symbol for all range use so that we can loop over them. 759 // We need an explicit cast to `SmallVector` to capture the cases where 760 // `{0}` resolves to an `Operation::result_range` as well as cases that 761 // are not iterable (e.g. vector that gets wrapped in additional braces by 762 // RewriterGen). 763 // TODO: Revisit the need for materializing a vector. 764 os << symbolInfoMap.getAllRangeUse( 765 val, 766 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 767 " tblgen_repl_values.push_back(v);\n}\n", 768 "\n"); 769 } 770 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 771 } 772 773 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 774 } 775 776 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 777 return std::string( 778 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 779 } 780 781 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 782 int resultIndex, int depth) { 783 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 784 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 785 LLVM_DEBUG(llvm::dbgs() << '\n'); 786 787 if (resultTree.isLocationDirective()) { 788 PrintFatalError(loc, 789 "location directive can only be used with op creation"); 790 } 791 792 if (resultTree.isNativeCodeCall()) { 793 auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); 794 symbolInfoMap.bindValue(symbol); 795 return symbol; 796 } 797 798 if (resultTree.isReplaceWithValue()) 799 return handleReplaceWithValue(resultTree).str(); 800 801 // Normal op creation. 802 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 803 if (resultTree.getSymbol().empty()) { 804 // This is an op not explicitly bound to a symbol in the rewrite rule. 805 // Register the auto-generated symbol for it. 806 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 807 } 808 return symbol; 809 } 810 811 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 812 assert(tree.isReplaceWithValue()); 813 814 if (tree.getNumArgs() != 1) { 815 PrintFatalError( 816 loc, "replaceWithValue directive must take exactly one argument"); 817 } 818 819 if (!tree.getSymbol().empty()) { 820 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 821 } 822 823 return tree.getArgName(0); 824 } 825 826 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 827 assert(tree.isLocationDirective()); 828 auto lookUpArgLoc = [this, &tree](int idx) { 829 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 830 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 831 }; 832 833 if (tree.getNumArgs() == 0) 834 llvm::PrintFatalError( 835 "At least one argument to location directive required"); 836 837 if (!tree.getSymbol().empty()) 838 PrintFatalError(loc, "cannot bind symbol to location"); 839 840 if (tree.getNumArgs() == 1) { 841 DagLeaf leaf = tree.getArgAsLeaf(0); 842 if (leaf.isStringAttr()) 843 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " 844 "rewriter.getContext())", 845 leaf.getStringAttr()) 846 .str(); 847 return lookUpArgLoc(0); 848 } 849 850 std::string ret; 851 llvm::raw_string_ostream os(ret); 852 std::string strAttr; 853 os << "rewriter.getFusedLoc({"; 854 bool first = true; 855 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 856 DagLeaf leaf = tree.getArgAsLeaf(i); 857 // Handle the optional string value. 858 if (leaf.isStringAttr()) { 859 if (!strAttr.empty()) 860 llvm::PrintFatalError("Only one string attribute may be specified"); 861 strAttr = leaf.getStringAttr(); 862 continue; 863 } 864 os << (first ? "" : ", ") << lookUpArgLoc(i); 865 first = false; 866 } 867 os << "}"; 868 if (!strAttr.empty()) { 869 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 870 } 871 os << ")"; 872 return os.str(); 873 } 874 875 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 876 StringRef patArgName) { 877 if (leaf.isStringAttr()) 878 PrintFatalError(loc, "raw string not supported as argument"); 879 if (leaf.isConstantAttr()) { 880 auto constAttr = leaf.getAsConstantAttr(); 881 return handleConstantAttr(constAttr.getAttribute(), 882 constAttr.getConstantValue()); 883 } 884 if (leaf.isEnumAttrCase()) { 885 auto enumCase = leaf.getAsEnumAttrCase(); 886 if (enumCase.isStrCase()) 887 return handleConstantAttr(enumCase, enumCase.getSymbol()); 888 // This is an enum case backed by an IntegerAttr. We need to get its value 889 // to build the constant. 890 std::string val = std::to_string(enumCase.getValue()); 891 return handleConstantAttr(enumCase, val); 892 } 893 894 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 895 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 896 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 897 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 898 << "' (via symbol ref)\n"); 899 return argName; 900 } 901 if (leaf.isNativeCodeCall()) { 902 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 903 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 904 << "' (via NativeCodeCall)\n"); 905 return std::string(repl); 906 } 907 PrintFatalError(loc, "unhandled case when rewriting op"); 908 } 909 910 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 911 int depth) { 912 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 913 LLVM_DEBUG(tree.print(llvm::dbgs())); 914 LLVM_DEBUG(llvm::dbgs() << '\n'); 915 916 auto fmt = tree.getNativeCodeTemplate(); 917 // TODO: replace formatv arguments with the exact specified args. 918 SmallVector<std::string, 8> attrs(8); 919 if (tree.getNumArgs() > 8) { 920 PrintFatalError(loc, 921 "unsupported NativeCodeCall replace argument numbers: " + 922 Twine(tree.getNumArgs())); 923 } 924 bool hasLocationDirective; 925 std::string locToUse; 926 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 927 928 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 929 if (tree.isNestedDagArg(i)) { 930 attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1); 931 } else { 932 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); 933 } 934 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 935 << " replacement: " << attrs[i] << "\n"); 936 } 937 return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], 938 attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], 939 attrs[6], attrs[7])); 940 } 941 942 int PatternEmitter::getNodeValueCount(DagNode node) { 943 if (node.isOperation()) { 944 // If the op is bound to a symbol in the rewrite rule, query its result 945 // count from the symbol info map. 946 auto symbol = node.getSymbol(); 947 if (!symbol.empty()) { 948 return symbolInfoMap.getStaticValueCount(symbol); 949 } 950 // Otherwise this is an unbound op; we will use all its results. 951 return pattern.getDialectOp(node).getNumResults(); 952 } 953 // TODO: This considers all NativeCodeCall as returning one 954 // value. Enhance if multi-value ones are needed. 955 return 1; 956 } 957 958 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 959 auto numPatArgs = tree.getNumArgs(); 960 961 if (numPatArgs != 0) { 962 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 963 if (lastArg.isLocationDirective()) { 964 return std::make_pair(true, handleLocationDirective(lastArg)); 965 } 966 } 967 968 // If no explicit location is given, use the default, all fused, location. 969 return std::make_pair(false, "odsLoc"); 970 } 971 972 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 973 int depth) { 974 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 975 LLVM_DEBUG(tree.print(llvm::dbgs())); 976 LLVM_DEBUG(llvm::dbgs() << '\n'); 977 978 Operator &resultOp = tree.getDialectOp(opMap); 979 auto numOpArgs = resultOp.getNumArgs(); 980 auto numPatArgs = tree.getNumArgs(); 981 982 bool hasLocationDirective; 983 std::string locToUse; 984 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 985 986 auto inPattern = numPatArgs - hasLocationDirective; 987 if (numOpArgs != inPattern) { 988 PrintFatalError(loc, 989 formatv("resultant op '{0}' argument number mismatch: " 990 "{1} in pattern vs. {2} in definition", 991 resultOp.getOperationName(), inPattern, numOpArgs)); 992 } 993 994 // A map to collect all nested DAG child nodes' names, with operand index as 995 // the key. This includes both bound and unbound child nodes. 996 ChildNodeIndexNameMap childNodeNames; 997 998 // First go through all the child nodes who are nested DAG constructs to 999 // create ops for them and remember the symbol names for them, so that we can 1000 // use the results in the current node. This happens in a recursive manner. 1001 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 1002 if (auto child = tree.getArgAsNestedDag(i)) 1003 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1004 } 1005 1006 // The name of the local variable holding this op. 1007 std::string valuePackName; 1008 // The symbol for holding the result of this pattern. Note that the result of 1009 // this pattern is not necessarily the same as the variable created by this 1010 // pattern because we can use `__N` suffix to refer only a specific result if 1011 // the generated op is a multi-result op. 1012 std::string resultValue; 1013 if (tree.getSymbol().empty()) { 1014 // No symbol is explicitly bound to this op in the pattern. Generate a 1015 // unique name. 1016 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1017 } else { 1018 resultValue = std::string(tree.getSymbol()); 1019 // Strip the index to get the name for the value pack and use it to name the 1020 // local variable for the op. 1021 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1022 } 1023 1024 // Create the local variable for this op. 1025 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1026 valuePackName); 1027 1028 // Right now ODS don't have general type inference support. Except a few 1029 // special cases listed below, DRR needs to supply types for all results 1030 // when building an op. 1031 bool isSameOperandsAndResultType = 1032 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1033 bool useFirstAttr = 1034 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1035 1036 if (isSameOperandsAndResultType || useFirstAttr) { 1037 // We know how to deduce the result type for ops with these traits and we've 1038 // generated builders taking aggregate parameters. Use those builders to 1039 // create the ops. 1040 1041 // First prepare local variables for op arguments used in builder call. 1042 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1043 1044 // Then create the op. 1045 os.scope("", "\n}\n").os << formatv( 1046 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1047 valuePackName, resultOp.getQualCppClassName(), locToUse); 1048 return resultValue; 1049 } 1050 1051 bool usePartialResults = valuePackName != resultValue; 1052 1053 if (usePartialResults || depth > 0 || resultIndex < 0) { 1054 // For these cases (broadcastable ops, op results used both as auxiliary 1055 // values and replacement values, ops in nested patterns, auxiliary ops), we 1056 // still need to supply the result types when building the op. But because 1057 // we don't generate a builder automatically with ODS for them, it's the 1058 // developer's responsibility to make sure such a builder (with result type 1059 // deduction ability) exists. We go through the separate-parameter builder 1060 // here given that it's easier for developers to write compared to 1061 // aggregate-parameter builders. 1062 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1063 1064 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1065 resultOp.getQualCppClassName(), locToUse); 1066 supplyValuesForOpArgs(tree, childNodeNames, depth); 1067 os << "\n );\n}\n"; 1068 return resultValue; 1069 } 1070 1071 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 1072 // generated from the source pattern root op. Then we can use the source 1073 // pattern's value types to determine the value type of the generated op 1074 // here. 1075 1076 // First prepare local variables for op arguments used in builder call. 1077 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1078 1079 // Then prepare the result types. We need to specify the types for all 1080 // results. 1081 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 1082 "(void)tblgen_types;\n"); 1083 int numResults = resultOp.getNumResults(); 1084 if (numResults != 0) { 1085 for (int i = 0; i < numResults; ++i) 1086 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1087 " tblgen_types.push_back(v.getType());\n}\n", 1088 resultIndex + i); 1089 } 1090 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1091 "tblgen_values, tblgen_attrs);\n", 1092 valuePackName, resultOp.getQualCppClassName(), locToUse); 1093 os.unindent() << "}\n"; 1094 return resultValue; 1095 } 1096 1097 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1098 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1099 Operator &resultOp = node.getDialectOp(opMap); 1100 1101 // Now prepare operands used for building this op: 1102 // * If the operand is non-variadic, we create a `Value` local variable. 1103 // * If the operand is variadic, we create a `SmallVector<Value>` local 1104 // variable. 1105 1106 int valueIndex = 0; // An index for uniquing local variable names. 1107 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1108 const auto *operand = 1109 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 1110 // We do not need special handling for attributes. 1111 if (!operand) 1112 continue; 1113 1114 raw_indented_ostream::DelimitedScope scope(os); 1115 std::string varName; 1116 if (operand->isVariadic()) { 1117 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1118 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); 1119 std::string range; 1120 if (node.isNestedDagArg(argIndex)) { 1121 range = childNodeNames[argIndex]; 1122 } else { 1123 range = std::string(node.getArgName(argIndex)); 1124 } 1125 // Resolve the symbol for all range use so that we have a uniform way of 1126 // capturing the values. 1127 range = symbolInfoMap.getValueAndRangeUse(range); 1128 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1129 varName); 1130 } else { 1131 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1132 os << formatv("::mlir::Value {0} = ", varName); 1133 if (node.isNestedDagArg(argIndex)) { 1134 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1135 } else { 1136 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1137 auto symbol = 1138 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1139 if (leaf.isNativeCodeCall()) { 1140 os << std::string( 1141 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1142 } else { 1143 os << symbol; 1144 } 1145 } 1146 os << ";\n"; 1147 } 1148 1149 // Update to use the newly created local variable for building the op later. 1150 childNodeNames[argIndex] = varName; 1151 } 1152 } 1153 1154 void PatternEmitter::supplyValuesForOpArgs( 1155 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1156 Operator &resultOp = node.getDialectOp(opMap); 1157 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1158 argIndex != numOpArgs; ++argIndex) { 1159 // Start each argument on its own line. 1160 os << ",\n "; 1161 1162 Argument opArg = resultOp.getArg(argIndex); 1163 // Handle the case of operand first. 1164 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1165 if (!operand->name.empty()) 1166 os << "/*" << operand->name << "=*/"; 1167 os << childNodeNames.lookup(argIndex); 1168 continue; 1169 } 1170 1171 // The argument in the op definition. 1172 auto opArgName = resultOp.getArgName(argIndex); 1173 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1174 if (!subTree.isNativeCodeCall()) 1175 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1176 "for creating attribute"); 1177 os << formatv("/*{0}=*/{1}", opArgName, 1178 handleReplaceWithNativeCodeCall(subTree, depth)); 1179 } else { 1180 auto leaf = node.getArgAsLeaf(argIndex); 1181 // The argument in the result DAG pattern. 1182 auto patArgName = node.getArgName(argIndex); 1183 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1184 // TODO: Refactor out into map to avoid recomputing these. 1185 if (!opArg.is<NamedAttribute *>()) 1186 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1187 if (!patArgName.empty()) 1188 os << "/*" << patArgName << "=*/"; 1189 } else { 1190 os << "/*" << opArgName << "=*/"; 1191 } 1192 os << handleOpArgument(leaf, patArgName); 1193 } 1194 } 1195 } 1196 1197 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1198 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1199 Operator &resultOp = node.getDialectOp(opMap); 1200 1201 auto scope = os.scope(); 1202 os << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1203 "tblgen_values; (void)tblgen_values;\n"); 1204 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1205 "tblgen_attrs; (void)tblgen_attrs;\n"); 1206 1207 const char *addAttrCmd = 1208 "if (auto tmpAttr = {1}) {\n" 1209 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " 1210 "tmpAttr);\n}\n"; 1211 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1212 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1213 // The argument in the op definition. 1214 auto opArgName = resultOp.getArgName(argIndex); 1215 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1216 if (!subTree.isNativeCodeCall()) 1217 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1218 "for creating attribute"); 1219 os << formatv(addAttrCmd, opArgName, 1220 handleReplaceWithNativeCodeCall(subTree, depth + 1)); 1221 } else { 1222 auto leaf = node.getArgAsLeaf(argIndex); 1223 // The argument in the result DAG pattern. 1224 auto patArgName = node.getArgName(argIndex); 1225 os << formatv(addAttrCmd, opArgName, 1226 handleOpArgument(leaf, patArgName)); 1227 } 1228 continue; 1229 } 1230 1231 const auto *operand = 1232 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1233 std::string varName; 1234 if (operand->isVariadic()) { 1235 std::string range; 1236 if (node.isNestedDagArg(argIndex)) { 1237 range = childNodeNames.lookup(argIndex); 1238 } else { 1239 range = std::string(node.getArgName(argIndex)); 1240 } 1241 // Resolve the symbol for all range use so that we have a uniform way of 1242 // capturing the values. 1243 range = symbolInfoMap.getValueAndRangeUse(range); 1244 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1245 range); 1246 } else { 1247 os << formatv("tblgen_values.push_back("); 1248 if (node.isNestedDagArg(argIndex)) { 1249 os << symbolInfoMap.getValueAndRangeUse( 1250 childNodeNames.lookup(argIndex)); 1251 } else { 1252 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1253 auto symbol = 1254 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1255 if (leaf.isNativeCodeCall()) { 1256 os << std::string( 1257 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1258 } else { 1259 os << symbol; 1260 } 1261 } 1262 os << ");\n"; 1263 } 1264 } 1265 } 1266 1267 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1268 emitSourceFileHeader("Rewriters", os); 1269 1270 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1271 auto numPatterns = patterns.size(); 1272 1273 // We put the map here because it can be shared among multiple patterns. 1274 RecordOperatorMap recordOpMap; 1275 1276 std::vector<std::string> rewriterNames; 1277 rewriterNames.reserve(numPatterns); 1278 1279 std::string baseRewriterName = "GeneratedConvert"; 1280 int rewriterIndex = 0; 1281 1282 for (Record *p : patterns) { 1283 std::string name; 1284 if (p->isAnonymous()) { 1285 // If no name is provided, ensure unique rewriter names simply by 1286 // appending unique suffix. 1287 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1288 } else { 1289 name = std::string(p->getName()); 1290 } 1291 LLVM_DEBUG(llvm::dbgs() 1292 << "=== start generating pattern '" << name << "' ===\n"); 1293 PatternEmitter(p, &recordOpMap, os).emit(name); 1294 LLVM_DEBUG(llvm::dbgs() 1295 << "=== done generating pattern '" << name << "' ===\n"); 1296 rewriterNames.push_back(std::move(name)); 1297 } 1298 1299 // Emit function to add the generated matchers to the pattern list. 1300 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext " 1301 "*context, ::mlir::OwningRewritePatternList &patterns) {\n"; 1302 for (const auto &name : rewriterNames) { 1303 os << " patterns.insert<" << name << ">(context);\n"; 1304 } 1305 os << "}\n"; 1306 } 1307 1308 static mlir::GenRegistration 1309 genRewriters("gen-rewriters", "Generate pattern rewriters", 1310 [](const RecordKeeper &records, raw_ostream &os) { 1311 emitRewriters(records, os); 1312 return false; 1313 }); 1314