//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace mlir;
using namespace mlir::tblgen;

using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;

#define DEBUG_TYPE "mlir-tblgen-rewritergen"

namespace llvm {
// Teaches `formatv` how to print a Pattern::IdentifierLine: the pair is
// written as `<first>:<second>`. The `style` argument is accepted to satisfy
// the format_provider interface but is not consulted.
template <>
struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
                     raw_ostream &os, StringRef style) {
    os << v.first << ":" << v.second;
  }
};
} // end namespace llvm

//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//

namespace {
class PatternEmitter { 58 public: 59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); 60 61 // Emits the mlir::RewritePattern struct named `rewriteName`. 62 void emit(StringRef rewriteName); 63 64 private: 65 // Emits the code for matching ops. 66 void emitMatchLogic(DagNode tree, StringRef opName); 67 68 // Emits the code for rewriting ops. 69 void emitRewriteLogic(); 70 71 //===--------------------------------------------------------------------===// 72 // Match utilities 73 //===--------------------------------------------------------------------===// 74 75 // Emits C++ statements for matching the DAG structure. 76 void emitMatch(DagNode tree, StringRef name, int depth); 77 78 // Emits C++ statements for matching using a native code call. 79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 80 81 // Emits C++ statements for matching the op constrained by the given DAG 82 // `tree` returning the op's variable name. 83 void emitOpMatch(DagNode tree, StringRef opName, int depth); 84 85 // Emits C++ statements for matching the `argIndex`-th argument of the given 86 // DAG `tree` as an operand. operandIndex is the index in the DAG excluding 87 // the preceding attributes. 88 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, 89 int operandIndex, int depth); 90 91 // Emits C++ statements for matching the `argIndex`-th argument of the given 92 // DAG `tree` as an attribute. 93 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 94 int depth); 95 96 // Emits C++ for checking a match with a corresponding match failure 97 // diagnostic. 98 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 99 const llvm::formatv_object_base &failureFmt); 100 101 // Emits C++ for checking a match with a corresponding match failure 102 // diagnostics. 
103 void emitMatchCheck(StringRef opName, const std::string &matchStr, 104 const std::string &failureStr); 105 106 //===--------------------------------------------------------------------===// 107 // Rewrite utilities 108 //===--------------------------------------------------------------------===// 109 110 // The entry point for handling a result pattern rooted at `resultTree`. This 111 // method dispatches to concrete handlers according to `resultTree`'s kind and 112 // returns a symbol representing the whole value pack. Callers are expected to 113 // further resolve the symbol according to the specific use case. 114 // 115 // `depth` is the nesting level of `resultTree`; 0 means top-level result 116 // pattern. For top-level result pattern, `resultIndex` indicates which result 117 // of the matched root op this pattern is intended to replace, which can be 118 // used to deduce the result type of the op generated from this result 119 // pattern. 120 std::string handleResultPattern(DagNode resultTree, int resultIndex, 121 int depth); 122 123 // Emits the C++ statement to replace the matched DAG with a value built via 124 // calling native C++ code. 125 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 126 127 // Returns the symbol of the old value serving as the replacement. 128 StringRef handleReplaceWithValue(DagNode tree); 129 130 // Returns the location value to use. 131 std::pair<bool, std::string> getLocation(DagNode tree); 132 133 // Returns the location value to use. 134 std::string handleLocationDirective(DagNode tree); 135 136 // Emits the C++ statement to build a new op out of the given DAG `tree` and 137 // returns the variable name that this op is assigned to. If the root op in 138 // DAG `tree` has a specified name, the created op will be assigned to a 139 // variable of the given name. Otherwise, a unique name will be used as the 140 // result value name. 
141 std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 142 143 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 144 145 // Emits a local variable for each value and attribute to be used for creating 146 // an op. 147 void createSeparateLocalVarsForOpArgs(DagNode node, 148 ChildNodeIndexNameMap &childNodeNames); 149 150 // Emits the concrete arguments used to call an op's builder. 151 void supplyValuesForOpArgs(DagNode node, 152 const ChildNodeIndexNameMap &childNodeNames, 153 int depth); 154 155 // Emits the local variables for holding all values as a whole and all named 156 // attributes as a whole to be used for creating an op. 157 void createAggregateLocalVarsForOpArgs( 158 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 159 160 // Returns the C++ expression to construct a constant attribute of the given 161 // `value` for the given attribute kind `attr`. 162 std::string handleConstantAttr(Attribute attr, StringRef value); 163 164 // Returns the C++ expression to build an argument from the given DAG `leaf`. 165 // `patArgName` is used to bound the argument to the source pattern. 166 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 167 168 //===--------------------------------------------------------------------===// 169 // General utilities 170 //===--------------------------------------------------------------------===// 171 172 // Collects all of the operations within the given dag tree. 173 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 174 175 // Returns a unique symbol for a local variable of the given `op`. 176 std::string getUniqueSymbol(const Operator *op); 177 178 //===--------------------------------------------------------------------===// 179 // Symbol utilities 180 //===--------------------------------------------------------------------===// 181 182 // Returns how many static values the given DAG `node` correspond to. 
183 int getNodeValueCount(DagNode node); 184 185 private: 186 // Pattern instantiation location followed by the location of multiclass 187 // prototypes used. This is intended to be used as a whole to 188 // PrintFatalError() on errors. 189 ArrayRef<llvm::SMLoc> loc; 190 191 // Op's TableGen Record to wrapper object. 192 RecordOperatorMap *opMap; 193 194 // Handy wrapper for pattern being emitted. 195 Pattern pattern; 196 197 // Map for all bound symbols' info. 198 SymbolInfoMap symbolInfoMap; 199 200 // The next unused ID for newly created values. 201 unsigned nextValueId; 202 203 raw_indented_ostream os; 204 205 // Format contexts containing placeholder substitutions. 206 FmtContext fmtCtx; 207 208 // Number of op processed. 209 int opCounter = 0; 210 }; 211 } // end anonymous namespace 212 213 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, 214 raw_ostream &os) 215 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 216 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { 217 fmtCtx.withBuilder("rewriter"); 218 } 219 220 std::string PatternEmitter::handleConstantAttr(Attribute attr, 221 StringRef value) { 222 if (!attr.isConstBuildable()) 223 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 224 " does not have the 'constBuilderCall' field"); 225 226 // TODO: Verify the constants here 227 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 228 } 229 230 // Helper function to match patterns. 231 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 232 if (tree.isNativeCodeCall()) { 233 emitNativeCodeMatch(tree, name, depth); 234 return; 235 } 236 237 if (tree.isOperation()) { 238 emitOpMatch(tree, name, depth); 239 return; 240 } 241 242 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 243 } 244 245 // Helper function to match patterns. 
246 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 247 int depth) { 248 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 249 LLVM_DEBUG(tree.print(llvm::dbgs())); 250 LLVM_DEBUG(llvm::dbgs() << '\n'); 251 252 // TODO(suderman): iterate through arguments, determine their types, output 253 // names. 254 SmallVector<std::string, 8> capture; 255 256 raw_indented_ostream::DelimitedScope scope(os); 257 258 os << "if(!" << opName << ") return ::mlir::failure();\n"; 259 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 260 std::string argName = formatv("arg{0}_{1}", depth, i); 261 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 262 os << "Value " << argName << ";\n"; 263 } else { 264 auto leaf = tree.getArgAsLeaf(i); 265 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 266 os << "Attribute " << argName << ";\n"; 267 } else { 268 os << "Value " << argName << ";\n"; 269 } 270 } 271 272 capture.push_back(std::move(argName)); 273 } 274 275 bool hasLocationDirective; 276 std::string locToUse; 277 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 278 279 auto fmt = tree.getNativeCodeTemplate(); 280 if (fmt.count("$_self") != 1) { 281 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 282 "passing the defining Operation"); 283 } 284 285 auto nativeCodeCall = std::string(tgfmt( 286 fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture)); 287 288 os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n"; 289 290 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 291 auto name = tree.getArgName(i); 292 if (!name.empty() && name != "_") { 293 os << formatv("{0} = {1};\n", name, capture[i]); 294 } 295 } 296 297 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 298 std::string argName = capture[i]; 299 300 // Handle nested DAG construct first 301 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 302 PrintFatalError( 303 loc, formatv("Matching nested tree 
in NativeCodecall not support for " 304 "{0} as arg {1}", 305 argName, i)); 306 } 307 308 DagLeaf leaf = tree.getArgAsLeaf(i); 309 310 // The parameter for native function doesn't bind any constraints. 311 if (leaf.isUnspecified()) 312 continue; 313 314 auto constraint = leaf.getAsConstraint(); 315 316 std::string self; 317 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 318 self = argName; 319 else 320 self = formatv("{0}.getType()", argName); 321 emitMatchCheck( 322 opName, 323 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 324 formatv("\"operand {0} of native code call '{1}' failed to satisfy " 325 "constraint: " 326 "'{2}'\"", 327 i, tree.getNativeCodeTemplate(), constraint.getSummary())); 328 } 329 330 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 331 } 332 333 // Helper function to match patterns. 334 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 335 Operator &op = tree.getDialectOp(opMap); 336 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 337 << op.getOperationName() << "' at depth " << depth 338 << '\n'); 339 340 std::string castedName = formatv("castedOp{0}", depth); 341 os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " 342 "(void){0};\n", 343 castedName, opName, op.getQualCppClassName()); 344 // Skip the operand matching at depth 0 as the pattern rewriter already does. 345 if (depth != 0) { 346 // Skip if there is no defining operation (e.g., arguments to function). 347 os << formatv("if (!{0}) return ::mlir::failure();\n", castedName); 348 } 349 if (tree.getNumArgs() != op.getNumArgs()) { 350 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " 351 "pattern vs. {2} in definition", 352 op.getOperationName(), tree.getNumArgs(), 353 op.getNumArgs())); 354 } 355 356 // If the operand's name is set, set to that variable. 
357 auto name = tree.getSymbol(); 358 if (!name.empty()) 359 os << formatv("{0} = {1};\n", name, castedName); 360 361 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) { 362 auto opArg = op.getArg(i); 363 std::string argName = formatv("op{0}", depth + 1); 364 365 // Handle nested DAG construct first 366 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 367 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 368 if (operand->isVariableLength()) { 369 auto error = formatv("use nested DAG construct to match op {0}'s " 370 "variadic operand #{1} unsupported now", 371 op.getOperationName(), i); 372 PrintFatalError(loc, error); 373 } 374 } 375 os << "{\n"; 376 377 // Attributes don't count for getODSOperands. 378 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 379 os.indent() << formatv( 380 "auto *{0} = " 381 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 382 argName, castedName, nextOperand++); 383 emitMatch(argTree, argName, depth + 1); 384 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); 385 os.unindent() << "}\n"; 386 continue; 387 } 388 389 // Next handle DAG leaf: operand or attribute 390 if (opArg.is<NamedTypeConstraint *>()) { 391 // emitOperandMatch's argument indexing counts attributes. 
392 emitOperandMatch(tree, castedName, i, nextOperand, depth); 393 ++nextOperand; 394 } else if (opArg.is<NamedAttribute *>()) { 395 emitAttributeMatch(tree, opName, i, depth); 396 } else { 397 PrintFatalError(loc, "unhandled case when matching op"); 398 } 399 } 400 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 401 << op.getOperationName() << "' at depth " << depth 402 << '\n'); 403 } 404 405 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 406 int argIndex, int operandIndex, 407 int depth) { 408 Operator &op = tree.getDialectOp(opMap); 409 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); 410 auto matcher = tree.getArgAsLeaf(argIndex); 411 412 // If a constraint is specified, we need to generate C++ statements to 413 // check the constraint. 414 if (!matcher.isUnspecified()) { 415 if (!matcher.isOperandMatcher()) { 416 PrintFatalError( 417 loc, formatv("the {1}-th argument of op '{0}' should be an operand", 418 op.getOperationName(), argIndex + 1)); 419 } 420 421 // Only need to verify if the matcher's type is different from the one 422 // of op definition. 423 Constraint constraint = matcher.getAsConstraint(); 424 if (operand->constraint != constraint) { 425 if (operand->isVariableLength()) { 426 auto error = formatv( 427 "further constrain op {0}'s variadic operand #{1} unsupported now", 428 op.getOperationName(), argIndex); 429 PrintFatalError(loc, error); 430 } 431 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", 432 opName, operandIndex); 433 emitMatchCheck( 434 opName, 435 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), 436 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " 437 "'{2}'\"", 438 operand - op.operand_begin(), op.getOperationName(), 439 constraint.getSummary())); 440 } 441 } 442 443 // Capture the value 444 auto name = tree.getArgName(argIndex); 445 // `$_` is a special symbol to ignore op argument matching. 
446 if (!name.empty() && name != "_") { 447 // We need to subtract the number of attributes before this operand to get 448 // the index in the operand list. 449 auto numPrevAttrs = std::count_if( 450 op.arg_begin(), op.arg_begin() + argIndex, 451 [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); 452 453 auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); 454 os << formatv("{0} = {1}.getODSOperands({2});\n", 455 res->second.getVarName(name), opName, 456 argIndex - numPrevAttrs); 457 } 458 } 459 460 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 461 int argIndex, int depth) { 462 Operator &op = tree.getDialectOp(opMap); 463 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); 464 const auto &attr = namedAttr->attr; 465 466 os << "{\n"; 467 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 468 "(void)tblgen_attr;\n", 469 opName, attr.getStorageType(), namedAttr->name); 470 471 // TODO: This should use getter method to avoid duplication. 472 if (attr.hasDefaultValue()) { 473 os << "if (!tblgen_attr) tblgen_attr = " 474 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 475 attr.getDefaultValue())) 476 << ";\n"; 477 } else if (attr.isOptional()) { 478 // For a missing attribute that is optional according to definition, we 479 // should just capture a mlir::Attribute() to signal the missing state. 480 // That is precisely what getAttr() returns on missing attributes. 
481 } else { 482 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 483 formatv("\"expected op '{0}' to have attribute '{1}' " 484 "of type '{2}'\"", 485 op.getOperationName(), namedAttr->name, 486 attr.getStorageType())); 487 } 488 489 auto matcher = tree.getArgAsLeaf(argIndex); 490 if (!matcher.isUnspecified()) { 491 if (!matcher.isAttrMatcher()) { 492 PrintFatalError( 493 loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 494 op.getOperationName(), argIndex + 1)); 495 } 496 497 // If a constraint is specified, we need to generate C++ statements to 498 // check the constraint. 499 emitMatchCheck( 500 opName, 501 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), 502 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 503 "{2}\"", 504 op.getOperationName(), namedAttr->name, 505 matcher.getAsConstraint().getSummary())); 506 } 507 508 // Capture the value 509 auto name = tree.getArgName(argIndex); 510 // `$_` is a special symbol to ignore op argument matching. 
511 if (!name.empty() && name != "_") { 512 os << formatv("{0} = tblgen_attr;\n", name); 513 } 514 515 os.unindent() << "}\n"; 516 } 517 518 void PatternEmitter::emitMatchCheck( 519 StringRef opName, const FmtObjectBase &matchFmt, 520 const llvm::formatv_object_base &failureFmt) { 521 emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 522 } 523 524 void PatternEmitter::emitMatchCheck(StringRef opName, 525 const std::string &matchStr, 526 const std::string &failureStr) { 527 528 os << "if (!(" << matchStr << "))"; 529 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 530 << ", [&](::mlir::Diagnostic &diag) {\n diag << " 531 << failureStr << ";\n});"; 532 } 533 534 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 535 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 536 int depth = 0; 537 emitMatch(tree, opName, depth); 538 539 for (auto &appliedConstraint : pattern.getConstraints()) { 540 auto &constraint = appliedConstraint.constraint; 541 auto &entities = appliedConstraint.entities; 542 543 auto condition = constraint.getConditionTemplate(); 544 if (isa<TypeConstraint>(constraint)) { 545 auto self = formatv("({0}.getType())", 546 symbolInfoMap.getValueAndRangeUse(entities.front())); 547 emitMatchCheck( 548 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 549 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", 550 entities.front(), constraint.getSummary())); 551 552 } else if (isa<AttrConstraint>(constraint)) { 553 PrintFatalError( 554 loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 555 } else { 556 // TODO: replace formatv arguments with the exact specified 557 // args. 
558 if (entities.size() > 4) { 559 PrintFatalError(loc, "only support up to 4-entity constraints now"); 560 } 561 SmallVector<std::string, 4> names; 562 int i = 0; 563 for (int e = entities.size(); i < e; ++i) 564 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 565 std::string self = appliedConstraint.self; 566 if (!self.empty()) 567 self = symbolInfoMap.getValueAndRangeUse(self); 568 for (; i < 4; ++i) 569 names.push_back("<unused>"); 570 emitMatchCheck(opName, 571 tgfmt(condition, &fmtCtx.withSelf(self), names[0], 572 names[1], names[2], names[3]), 573 formatv("\"entities '{0}' failed to satisfy constraint: " 574 "{1}\"", 575 llvm::join(entities, ", "), 576 constraint.getSummary())); 577 } 578 } 579 580 // Some of the operands could be bound to the same symbol name, we need 581 // to enforce equality constraint on those. 582 // TODO: we should be able to emit equality checks early 583 // and short circuit unnecessary work if vars are not equal. 584 for (auto symbolInfoIt = symbolInfoMap.begin(); 585 symbolInfoIt != symbolInfoMap.end();) { 586 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 587 auto startRange = range.first; 588 auto endRange = range.second; 589 590 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 591 for (++startRange; startRange != endRange; ++startRange) { 592 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 593 emitMatchCheck( 594 opName, 595 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 596 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 597 secondOperand)); 598 } 599 600 symbolInfoIt = endRange; 601 } 602 603 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 604 } 605 606 void PatternEmitter::collectOps(DagNode tree, 607 llvm::SmallPtrSetImpl<const Operator *> &ops) { 608 // Check if this tree is an operation. 
609 if (tree.isOperation()) { 610 const Operator &op = tree.getDialectOp(opMap); 611 LLVM_DEBUG(llvm::dbgs() 612 << "found operation " << op.getOperationName() << '\n'); 613 ops.insert(&op); 614 } 615 616 // Recurse the arguments of the tree. 617 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 618 if (auto child = tree.getArgAsNestedDag(i)) 619 collectOps(child, ops); 620 } 621 622 void PatternEmitter::emit(StringRef rewriteName) { 623 // Get the DAG tree for the source pattern. 624 DagNode sourceTree = pattern.getSourcePattern(); 625 626 const Operator &rootOp = pattern.getSourceRootOp(); 627 auto rootName = rootOp.getOperationName(); 628 629 // Collect the set of result operations. 630 llvm::SmallPtrSet<const Operator *, 4> resultOps; 631 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 632 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 633 collectOps(pattern.getResultPattern(i), resultOps); 634 } 635 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 636 637 // Emit RewritePattern for Pattern. 638 auto locs = pattern.getLocation(); 639 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 640 make_range(locs.rbegin(), locs.rend())); 641 os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 642 {0}(::mlir::MLIRContext *context) 643 : ::mlir::RewritePattern("{1}", {2}, context, {{)", 644 rewriteName, rootName, pattern.getBenefit()); 645 // Sort result operators by name. 646 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 647 resultOps.end()); 648 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 649 return lhs->getOperationName() < rhs->getOperationName(); 650 }); 651 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 652 os << '"' << op->getOperationName() << '"'; 653 }); 654 os << "}) {}\n"; 655 656 // Emit matchAndRewrite() function. 
657 { 658 auto classScope = os.scope(); 659 os.reindent(R"( 660 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, 661 ::mlir::PatternRewriter &rewriter) const override {)") 662 << '\n'; 663 { 664 auto functionScope = os.scope(); 665 666 // Register all symbols bound in the source pattern. 667 pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 668 669 LLVM_DEBUG(llvm::dbgs() 670 << "start creating local variables for capturing matches\n"); 671 os << "// Variables for capturing values and attributes used while " 672 "creating ops\n"; 673 // Create local variables for storing the arguments and results bound 674 // to symbols. 675 for (const auto &symbolInfoPair : symbolInfoMap) { 676 const auto &symbol = symbolInfoPair.first; 677 const auto &info = symbolInfoPair.second; 678 679 os << info.getVarDecl(symbol); 680 } 681 // TODO: capture ops with consistent numbering so that it can be 682 // reused for fused loc. 683 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", 684 pattern.getSourcePattern().getNumOps()); 685 LLVM_DEBUG(llvm::dbgs() 686 << "done creating local variables for capturing matches\n"); 687 688 os << "// Match\n"; 689 os << "tblgen_ops[0] = op0;\n"; 690 emitMatchLogic(sourceTree, "op0"); 691 692 os << "\n// Rewrite\n"; 693 emitRewriteLogic(); 694 695 os << "return ::mlir::success();\n"; 696 } 697 os << "};\n"; 698 } 699 os << "};\n\n"; 700 } 701 702 void PatternEmitter::emitRewriteLogic() { 703 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 704 const Operator &rootOp = pattern.getSourceRootOp(); 705 int numExpectedResults = rootOp.getNumResults(); 706 int numResultPatterns = pattern.getNumResultPatterns(); 707 708 // First register all symbols bound to ops generated in result patterns. 709 pattern.collectResultPatternBoundSymbols(symbolInfoMap); 710 711 // Only the last N static values generated are used to replace the matched 712 // root N-result op. 
We need to calculate the starting index (of the results 713 // of the matched op) each result pattern is to replace. 714 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 715 // If we don't need to replace any value at all, set the replacement starting 716 // index as the number of result patterns so we skip all of them when trying 717 // to replace the matched op's results. 718 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 719 for (int i = numResultPatterns - 1; i >= 0; --i) { 720 auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 721 offsets[i] = offsets[i + 1] - numValues; 722 if (offsets[i] == 0) { 723 if (replStartIndex == -1) 724 replStartIndex = i; 725 } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 726 auto error = formatv( 727 "cannot use the same multi-result op '{0}' to generate both " 728 "auxiliary values and values to be used for replacing the matched op", 729 pattern.getResultPattern(i).getSymbol()); 730 PrintFatalError(loc, error); 731 } 732 } 733 734 if (offsets.front() > 0) { 735 const char error[] = "no enough values generated to replace the matched op"; 736 PrintFatalError(loc, error); 737 } 738 739 os << "auto odsLoc = rewriter.getFusedLoc({"; 740 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 741 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 742 } 743 os << "}); (void)odsLoc;\n"; 744 745 // Process auxiliary result patterns. 746 for (int i = 0; i < replStartIndex; ++i) { 747 DagNode resultTree = pattern.getResultPattern(i); 748 auto val = handleResultPattern(resultTree, offsets[i], 0); 749 // Normal op creation will be streamed to `os` by the above call; but 750 // NativeCodeCall will only be materialized to `os` if it is used. Here 751 // we are handling auxiliary patterns so we want the side effect even if 752 // NativeCodeCall is not replacing matched root op's results. 
753 if (resultTree.isNativeCodeCall()) 754 os << val << ";\n"; 755 } 756 757 if (numExpectedResults == 0) { 758 assert(replStartIndex >= numResultPatterns && 759 "invalid auxiliary vs. replacement pattern division!"); 760 // No result to replace. Just erase the op. 761 os << "rewriter.eraseOp(op0);\n"; 762 } else { 763 // Process replacement result patterns. 764 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 765 for (int i = replStartIndex; i < numResultPatterns; ++i) { 766 DagNode resultTree = pattern.getResultPattern(i); 767 auto val = handleResultPattern(resultTree, offsets[i], 0); 768 os << "\n"; 769 // Resolve each symbol for all range use so that we can loop over them. 770 // We need an explicit cast to `SmallVector` to capture the cases where 771 // `{0}` resolves to an `Operation::result_range` as well as cases that 772 // are not iterable (e.g. vector that gets wrapped in additional braces by 773 // RewriterGen). 774 // TODO: Revisit the need for materializing a vector. 
775 os << symbolInfoMap.getAllRangeUse( 776 val, 777 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 778 " tblgen_repl_values.push_back(v);\n}\n", 779 "\n"); 780 } 781 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 782 } 783 784 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 785 } 786 787 std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 788 return std::string( 789 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 790 } 791 792 std::string PatternEmitter::handleResultPattern(DagNode resultTree, 793 int resultIndex, int depth) { 794 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 795 LLVM_DEBUG(resultTree.print(llvm::dbgs())); 796 LLVM_DEBUG(llvm::dbgs() << '\n'); 797 798 if (resultTree.isLocationDirective()) { 799 PrintFatalError(loc, 800 "location directive can only be used with op creation"); 801 } 802 803 if (resultTree.isNativeCodeCall()) { 804 auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); 805 symbolInfoMap.bindValue(symbol); 806 return symbol; 807 } 808 809 if (resultTree.isReplaceWithValue()) 810 return handleReplaceWithValue(resultTree).str(); 811 812 // Normal op creation. 813 auto symbol = handleOpCreation(resultTree, resultIndex, depth); 814 if (resultTree.getSymbol().empty()) { 815 // This is an op not explicitly bound to a symbol in the rewrite rule. 816 // Register the auto-generated symbol for it. 
817 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 818 } 819 return symbol; 820 } 821 822 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 823 assert(tree.isReplaceWithValue()); 824 825 if (tree.getNumArgs() != 1) { 826 PrintFatalError( 827 loc, "replaceWithValue directive must take exactly one argument"); 828 } 829 830 if (!tree.getSymbol().empty()) { 831 PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 832 } 833 834 return tree.getArgName(0); 835 } 836 837 std::string PatternEmitter::handleLocationDirective(DagNode tree) { 838 assert(tree.isLocationDirective()); 839 auto lookUpArgLoc = [this, &tree](int idx) { 840 const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; 841 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); 842 }; 843 844 if (tree.getNumArgs() == 0) 845 llvm::PrintFatalError( 846 "At least one argument to location directive required"); 847 848 if (!tree.getSymbol().empty()) 849 PrintFatalError(loc, "cannot bind symbol to location"); 850 851 if (tree.getNumArgs() == 1) { 852 DagLeaf leaf = tree.getArgAsLeaf(0); 853 if (leaf.isStringAttr()) 854 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))", 855 leaf.getStringAttr()) 856 .str(); 857 return lookUpArgLoc(0); 858 } 859 860 std::string ret; 861 llvm::raw_string_ostream os(ret); 862 std::string strAttr; 863 os << "rewriter.getFusedLoc({"; 864 bool first = true; 865 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 866 DagLeaf leaf = tree.getArgAsLeaf(i); 867 // Handle the optional string value. 868 if (leaf.isStringAttr()) { 869 if (!strAttr.empty()) 870 llvm::PrintFatalError("Only one string attribute may be specified"); 871 strAttr = leaf.getStringAttr(); 872 continue; 873 } 874 os << (first ? 
"" : ", ") << lookUpArgLoc(i); 875 first = false; 876 } 877 os << "}"; 878 if (!strAttr.empty()) { 879 os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 880 } 881 os << ")"; 882 return os.str(); 883 } 884 885 std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 886 StringRef patArgName) { 887 if (leaf.isStringAttr()) 888 PrintFatalError(loc, "raw string not supported as argument"); 889 if (leaf.isConstantAttr()) { 890 auto constAttr = leaf.getAsConstantAttr(); 891 return handleConstantAttr(constAttr.getAttribute(), 892 constAttr.getConstantValue()); 893 } 894 if (leaf.isEnumAttrCase()) { 895 auto enumCase = leaf.getAsEnumAttrCase(); 896 if (enumCase.isStrCase()) 897 return handleConstantAttr(enumCase, enumCase.getSymbol()); 898 // This is an enum case backed by an IntegerAttr. We need to get its value 899 // to build the constant. 900 std::string val = std::to_string(enumCase.getValue()); 901 return handleConstantAttr(enumCase, val); 902 } 903 904 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 905 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 906 if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 907 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 908 << "' (via symbol ref)\n"); 909 return argName; 910 } 911 if (leaf.isNativeCodeCall()) { 912 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 913 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 914 << "' (via NativeCodeCall)\n"); 915 return std::string(repl); 916 } 917 PrintFatalError(loc, "unhandled case when rewriting op"); 918 } 919 920 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 921 int depth) { 922 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 923 LLVM_DEBUG(tree.print(llvm::dbgs())); 924 LLVM_DEBUG(llvm::dbgs() << '\n'); 925 926 auto fmt = tree.getNativeCodeTemplate(); 927 928 SmallVector<std::string, 16> attrs; 929 930 bool 
hasLocationDirective; 931 std::string locToUse; 932 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 933 934 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 935 if (tree.isNestedDagArg(i)) { 936 attrs.push_back( 937 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); 938 } else { 939 attrs.push_back( 940 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))); 941 } 942 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 943 << " replacement: " << attrs[i] << "\n"); 944 } 945 946 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs); 947 if (!tree.getSymbol().empty()) { 948 os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol); 949 symbol = tree.getSymbol().str(); 950 } 951 952 return symbol; 953 } 954 955 int PatternEmitter::getNodeValueCount(DagNode node) { 956 if (node.isOperation()) { 957 // If the op is bound to a symbol in the rewrite rule, query its result 958 // count from the symbol info map. 959 auto symbol = node.getSymbol(); 960 if (!symbol.empty()) { 961 return symbolInfoMap.getStaticValueCount(symbol); 962 } 963 // Otherwise this is an unbound op; we will use all its results. 964 return pattern.getDialectOp(node).getNumResults(); 965 } 966 // TODO: This considers all NativeCodeCall as returning one 967 // value. Enhance if multi-value ones are needed. 968 return 1; 969 } 970 971 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { 972 auto numPatArgs = tree.getNumArgs(); 973 974 if (numPatArgs != 0) { 975 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) 976 if (lastArg.isLocationDirective()) { 977 return std::make_pair(true, handleLocationDirective(lastArg)); 978 } 979 } 980 981 // If no explicit location is given, use the default, all fused, location. 
982 return std::make_pair(false, "odsLoc"); 983 } 984 985 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 986 int depth) { 987 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 988 LLVM_DEBUG(tree.print(llvm::dbgs())); 989 LLVM_DEBUG(llvm::dbgs() << '\n'); 990 991 Operator &resultOp = tree.getDialectOp(opMap); 992 auto numOpArgs = resultOp.getNumArgs(); 993 auto numPatArgs = tree.getNumArgs(); 994 995 bool hasLocationDirective; 996 std::string locToUse; 997 std::tie(hasLocationDirective, locToUse) = getLocation(tree); 998 999 auto inPattern = numPatArgs - hasLocationDirective; 1000 if (numOpArgs != inPattern) { 1001 PrintFatalError(loc, 1002 formatv("resultant op '{0}' argument number mismatch: " 1003 "{1} in pattern vs. {2} in definition", 1004 resultOp.getOperationName(), inPattern, numOpArgs)); 1005 } 1006 1007 // A map to collect all nested DAG child nodes' names, with operand index as 1008 // the key. This includes both bound and unbound child nodes. 1009 ChildNodeIndexNameMap childNodeNames; 1010 1011 // First go through all the child nodes who are nested DAG constructs to 1012 // create ops for them and remember the symbol names for them, so that we can 1013 // use the results in the current node. This happens in a recursive manner. 1014 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { 1015 if (auto child = tree.getArgAsNestedDag(i)) 1016 childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1017 } 1018 1019 // The name of the local variable holding this op. 1020 std::string valuePackName; 1021 // The symbol for holding the result of this pattern. Note that the result of 1022 // this pattern is not necessarily the same as the variable created by this 1023 // pattern because we can use `__N` suffix to refer only a specific result if 1024 // the generated op is a multi-result op. 
1025 std::string resultValue; 1026 if (tree.getSymbol().empty()) { 1027 // No symbol is explicitly bound to this op in the pattern. Generate a 1028 // unique name. 1029 valuePackName = resultValue = getUniqueSymbol(&resultOp); 1030 } else { 1031 resultValue = std::string(tree.getSymbol()); 1032 // Strip the index to get the name for the value pack and use it to name the 1033 // local variable for the op. 1034 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 1035 } 1036 1037 // Create the local variable for this op. 1038 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 1039 valuePackName); 1040 1041 // Right now ODS don't have general type inference support. Except a few 1042 // special cases listed below, DRR needs to supply types for all results 1043 // when building an op. 1044 bool isSameOperandsAndResultType = 1045 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 1046 bool useFirstAttr = 1047 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 1048 1049 if (isSameOperandsAndResultType || useFirstAttr) { 1050 // We know how to deduce the result type for ops with these traits and we've 1051 // generated builders taking aggregate parameters. Use those builders to 1052 // create the ops. 1053 1054 // First prepare local variables for op arguments used in builder call. 1055 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1056 1057 // Then create the op. 
1058 os.scope("", "\n}\n").os << formatv( 1059 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 1060 valuePackName, resultOp.getQualCppClassName(), locToUse); 1061 return resultValue; 1062 } 1063 1064 bool usePartialResults = valuePackName != resultValue; 1065 1066 if (usePartialResults || depth > 0 || resultIndex < 0) { 1067 // For these cases (broadcastable ops, op results used both as auxiliary 1068 // values and replacement values, ops in nested patterns, auxiliary ops), we 1069 // still need to supply the result types when building the op. But because 1070 // we don't generate a builder automatically with ODS for them, it's the 1071 // developer's responsibility to make sure such a builder (with result type 1072 // deduction ability) exists. We go through the separate-parameter builder 1073 // here given that it's easier for developers to write compared to 1074 // aggregate-parameter builders. 1075 createSeparateLocalVarsForOpArgs(tree, childNodeNames); 1076 1077 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 1078 resultOp.getQualCppClassName(), locToUse); 1079 supplyValuesForOpArgs(tree, childNodeNames, depth); 1080 os << "\n );\n}\n"; 1081 return resultValue; 1082 } 1083 1084 // If depth == 0 and resultIndex >= 0, it means we are replacing the values 1085 // generated from the source pattern root op. Then we can use the source 1086 // pattern's value types to determine the value type of the generated op 1087 // here. 1088 1089 // First prepare local variables for op arguments used in builder call. 1090 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 1091 1092 // Then prepare the result types. We need to specify the types for all 1093 // results. 
1094 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " 1095 "(void)tblgen_types;\n"); 1096 int numResults = resultOp.getNumResults(); 1097 if (numResults != 0) { 1098 for (int i = 0; i < numResults; ++i) 1099 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 1100 " tblgen_types.push_back(v.getType());\n}\n", 1101 resultIndex + i); 1102 } 1103 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 1104 "tblgen_values, tblgen_attrs);\n", 1105 valuePackName, resultOp.getQualCppClassName(), locToUse); 1106 os.unindent() << "}\n"; 1107 return resultValue; 1108 } 1109 1110 void PatternEmitter::createSeparateLocalVarsForOpArgs( 1111 DagNode node, ChildNodeIndexNameMap &childNodeNames) { 1112 Operator &resultOp = node.getDialectOp(opMap); 1113 1114 // Now prepare operands used for building this op: 1115 // * If the operand is non-variadic, we create a `Value` local variable. 1116 // * If the operand is variadic, we create a `SmallVector<Value>` local 1117 // variable. 1118 1119 int valueIndex = 0; // An index for uniquing local variable names. 1120 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1121 const auto *operand = 1122 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); 1123 // We do not need special handling for attributes. 1124 if (!operand) 1125 continue; 1126 1127 raw_indented_ostream::DelimitedScope scope(os); 1128 std::string varName; 1129 if (operand->isVariadic()) { 1130 varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 1131 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); 1132 std::string range; 1133 if (node.isNestedDagArg(argIndex)) { 1134 range = childNodeNames[argIndex]; 1135 } else { 1136 range = std::string(node.getArgName(argIndex)); 1137 } 1138 // Resolve the symbol for all range use so that we have a uniform way of 1139 // capturing the values. 
1140 range = symbolInfoMap.getValueAndRangeUse(range); 1141 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 1142 varName); 1143 } else { 1144 varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 1145 os << formatv("::mlir::Value {0} = ", varName); 1146 if (node.isNestedDagArg(argIndex)) { 1147 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 1148 } else { 1149 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1150 auto symbol = 1151 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1152 if (leaf.isNativeCodeCall()) { 1153 os << std::string( 1154 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1155 } else { 1156 os << symbol; 1157 } 1158 } 1159 os << ";\n"; 1160 } 1161 1162 // Update to use the newly created local variable for building the op later. 1163 childNodeNames[argIndex] = varName; 1164 } 1165 } 1166 1167 void PatternEmitter::supplyValuesForOpArgs( 1168 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1169 Operator &resultOp = node.getDialectOp(opMap); 1170 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 1171 argIndex != numOpArgs; ++argIndex) { 1172 // Start each argument on its own line. 1173 os << ",\n "; 1174 1175 Argument opArg = resultOp.getArg(argIndex); 1176 // Handle the case of operand first. 1177 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { 1178 if (!operand->name.empty()) 1179 os << "/*" << operand->name << "=*/"; 1180 os << childNodeNames.lookup(argIndex); 1181 continue; 1182 } 1183 1184 // The argument in the op definition. 
1185 auto opArgName = resultOp.getArgName(argIndex); 1186 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1187 if (!subTree.isNativeCodeCall()) 1188 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1189 "for creating attribute"); 1190 os << formatv("/*{0}=*/{1}", opArgName, 1191 handleReplaceWithNativeCodeCall(subTree, depth)); 1192 } else { 1193 auto leaf = node.getArgAsLeaf(argIndex); 1194 // The argument in the result DAG pattern. 1195 auto patArgName = node.getArgName(argIndex); 1196 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 1197 // TODO: Refactor out into map to avoid recomputing these. 1198 if (!opArg.is<NamedAttribute *>()) 1199 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1200 if (!patArgName.empty()) 1201 os << "/*" << patArgName << "=*/"; 1202 } else { 1203 os << "/*" << opArgName << "=*/"; 1204 } 1205 os << handleOpArgument(leaf, patArgName); 1206 } 1207 } 1208 } 1209 1210 void PatternEmitter::createAggregateLocalVarsForOpArgs( 1211 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 1212 Operator &resultOp = node.getDialectOp(opMap); 1213 1214 auto scope = os.scope(); 1215 os << formatv("::mlir::SmallVector<::mlir::Value, 4> " 1216 "tblgen_values; (void)tblgen_values;\n"); 1217 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " 1218 "tblgen_attrs; (void)tblgen_attrs;\n"); 1219 1220 const char *addAttrCmd = 1221 "if (auto tmpAttr = {1}) {\n" 1222 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " 1223 "tmpAttr);\n}\n"; 1224 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1225 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { 1226 // The argument in the op definition. 
1227 auto opArgName = resultOp.getArgName(argIndex); 1228 if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1229 if (!subTree.isNativeCodeCall()) 1230 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1231 "for creating attribute"); 1232 os << formatv(addAttrCmd, opArgName, 1233 handleReplaceWithNativeCodeCall(subTree, depth + 1)); 1234 } else { 1235 auto leaf = node.getArgAsLeaf(argIndex); 1236 // The argument in the result DAG pattern. 1237 auto patArgName = node.getArgName(argIndex); 1238 os << formatv(addAttrCmd, opArgName, 1239 handleOpArgument(leaf, patArgName)); 1240 } 1241 continue; 1242 } 1243 1244 const auto *operand = 1245 resultOp.getArg(argIndex).get<NamedTypeConstraint *>(); 1246 std::string varName; 1247 if (operand->isVariadic()) { 1248 std::string range; 1249 if (node.isNestedDagArg(argIndex)) { 1250 range = childNodeNames.lookup(argIndex); 1251 } else { 1252 range = std::string(node.getArgName(argIndex)); 1253 } 1254 // Resolve the symbol for all range use so that we have a uniform way of 1255 // capturing the values. 
1256 range = symbolInfoMap.getValueAndRangeUse(range); 1257 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 1258 range); 1259 } else { 1260 os << formatv("tblgen_values.push_back("); 1261 if (node.isNestedDagArg(argIndex)) { 1262 os << symbolInfoMap.getValueAndRangeUse( 1263 childNodeNames.lookup(argIndex)); 1264 } else { 1265 DagLeaf leaf = node.getArgAsLeaf(argIndex); 1266 auto symbol = 1267 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 1268 if (leaf.isNativeCodeCall()) { 1269 os << std::string( 1270 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 1271 } else { 1272 os << symbol; 1273 } 1274 } 1275 os << ");\n"; 1276 } 1277 } 1278 } 1279 1280 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 1281 emitSourceFileHeader("Rewriters", os); 1282 1283 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 1284 auto numPatterns = patterns.size(); 1285 1286 // We put the map here because it can be shared among multiple patterns. 1287 RecordOperatorMap recordOpMap; 1288 1289 std::vector<std::string> rewriterNames; 1290 rewriterNames.reserve(numPatterns); 1291 1292 std::string baseRewriterName = "GeneratedConvert"; 1293 int rewriterIndex = 0; 1294 1295 for (Record *p : patterns) { 1296 std::string name; 1297 if (p->isAnonymous()) { 1298 // If no name is provided, ensure unique rewriter names simply by 1299 // appending unique suffix. 1300 name = baseRewriterName + llvm::utostr(rewriterIndex++); 1301 } else { 1302 name = std::string(p->getName()); 1303 } 1304 LLVM_DEBUG(llvm::dbgs() 1305 << "=== start generating pattern '" << name << "' ===\n"); 1306 PatternEmitter(p, &recordOpMap, os).emit(name); 1307 LLVM_DEBUG(llvm::dbgs() 1308 << "=== done generating pattern '" << name << "' ===\n"); 1309 rewriterNames.push_back(std::move(name)); 1310 } 1311 1312 // Emit function to add the generated matchers to the pattern list. 
1313 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" 1314 "::mlir::RewritePatternSet &patterns) {\n"; 1315 for (const auto &name : rewriterNames) { 1316 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 1317 } 1318 os << "}\n"; 1319 } 1320 1321 static mlir::GenRegistration 1322 genRewriters("gen-rewriters", "Generate pattern rewriters", 1323 [](const RecordKeeper &records, raw_ostream &os) { 1324 emitRewriters(records, os); 1325 return false; 1326 }); 1327