1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements mlir::applyPatternsGreedily. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14 15 #include "mlir/Config/mlir-config.h" 16 #include "mlir/IR/Action.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Verifier.h" 19 #include "mlir/Interfaces/SideEffectInterfaces.h" 20 #include "mlir/Rewrite/PatternApplicator.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/BitVector.h" 24 #include "llvm/ADT/DenseMap.h" 25 #include "llvm/ADT/ScopeExit.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/ScopedPrinter.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 32 #include <random> 33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 34 35 using namespace mlir; 36 37 #define DEBUG_TYPE "greedy-rewriter" 38 39 namespace { 40 41 //===----------------------------------------------------------------------===// 42 // Debugging Infrastructure 43 //===----------------------------------------------------------------------===// 44 45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 46 /// A helper struct that performs various "expensive checks" to detect broken 47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is 48 /// broken if: 49 /// * IR does not verify after pattern application / folding. 50 /// * Pattern returns "failure" but the IR has changed. 51 /// * Pattern returns "success" but the IR has not changed. 52 /// 53 /// This struct stores finger prints of ops to determine whether the IR has 54 /// changed or not. 55 struct ExpensiveChecks : public RewriterBase::ForwardingListener { 56 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel) 57 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {} 58 59 /// Compute finger prints of the given op and its nested ops. 60 void computeFingerPrints(Operation *topLevel) { 61 this->topLevel = topLevel; 62 this->topLevelFingerPrint.emplace(topLevel); 63 topLevel->walk([&](Operation *op) { 64 fingerprints.try_emplace(op, op, /*includeNested=*/false); 65 }); 66 } 67 68 /// Clear all finger prints. 69 void clear() { 70 topLevel = nullptr; 71 topLevelFingerPrint.reset(); 72 fingerprints.clear(); 73 } 74 75 void notifyRewriteSuccess() { 76 if (!topLevel) 77 return; 78 79 // Make sure that the IR still verifies. 80 if (failed(verify(topLevel))) 81 llvm::report_fatal_error("IR failed to verify after pattern application"); 82 83 // Pattern application success => IR must have changed. 84 OperationFingerPrint afterFingerPrint(topLevel); 85 if (*topLevelFingerPrint == afterFingerPrint) { 86 // Note: Run "mlir-opt -debug" to see which pattern is broken. 87 llvm::report_fatal_error( 88 "pattern returned success but IR did not change"); 89 } 90 for (const auto &it : fingerprints) { 91 // Skip top-level op, its finger print is never invalidated. 92 if (it.first == topLevel) 93 continue; 94 // Note: Finger print computation may crash when an op was erased 95 // without notifying the rewriter. (Run with ASAN to see where the op was 96 // erased; the op was probably erased directly, bypassing the rewriter 97 // API.) Finger print computation does may not crash if a new op was 98 // created at the same memory location. (But then the finger print should 99 // have changed.) 100 if (it.second != 101 OperationFingerPrint(it.first, /*includeNested=*/false)) { 102 // Note: Run "mlir-opt -debug" to see which pattern is broken. 103 llvm::report_fatal_error("operation finger print changed"); 104 } 105 } 106 } 107 108 void notifyRewriteFailure() { 109 if (!topLevel) 110 return; 111 112 // Pattern application failure => IR must not have changed. 113 OperationFingerPrint afterFingerPrint(topLevel); 114 if (*topLevelFingerPrint != afterFingerPrint) { 115 // Note: Run "mlir-opt -debug" to see which pattern is broken. 116 llvm::report_fatal_error("pattern returned failure but IR did change"); 117 } 118 } 119 120 void notifyFoldingSuccess() { 121 if (!topLevel) 122 return; 123 124 // Make sure that the IR still verifies. 125 if (failed(verify(topLevel))) 126 llvm::report_fatal_error("IR failed to verify after folding"); 127 } 128 129 protected: 130 /// Invalidate the finger print of the given op, i.e., remove it from the map. 131 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); } 132 133 void notifyBlockErased(Block *block) override { 134 RewriterBase::ForwardingListener::notifyBlockErased(block); 135 136 // The block structure (number of blocks, types of block arguments, etc.) 137 // is part of the fingerprint of the parent op. 138 // TODO: The parent op fingerprint should also be invalidated when modifying 139 // the block arguments of a block, but we do not have a 140 // `notifyBlockModified` callback yet. 141 invalidateFingerPrint(block->getParentOp()); 142 } 143 144 void notifyOperationInserted(Operation *op, 145 OpBuilder::InsertPoint previous) override { 146 RewriterBase::ForwardingListener::notifyOperationInserted(op, previous); 147 invalidateFingerPrint(op->getParentOp()); 148 } 149 150 void notifyOperationModified(Operation *op) override { 151 RewriterBase::ForwardingListener::notifyOperationModified(op); 152 invalidateFingerPrint(op); 153 } 154 155 void notifyOperationErased(Operation *op) override { 156 RewriterBase::ForwardingListener::notifyOperationErased(op); 157 op->walk([this](Operation *op) { invalidateFingerPrint(op); }); 158 } 159 160 /// Operation finger prints to detect invalid pattern API usage. IR is checked 161 /// against these finger prints after pattern application to detect cases 162 /// where IR was modified directly, bypassing the rewriter API. 163 DenseMap<Operation *, OperationFingerPrint> fingerprints; 164 165 /// Top-level operation of the current greedy rewrite. 166 Operation *topLevel = nullptr; 167 168 /// Finger print of the top-level operation. 169 std::optional<OperationFingerPrint> topLevelFingerPrint; 170 }; 171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 172 173 #ifndef NDEBUG 174 static Operation *getDumpRootOp(Operation *op) { 175 // Dump the parent op so that materialized constants are visible. If the op 176 // is a top-level op, dump it directly. 177 if (Operation *parentOp = op->getParentOp()) 178 return parentOp; 179 return op; 180 } 181 static void logSuccessfulFolding(Operation *op) { 182 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n"; 183 op->dump(); 184 llvm::dbgs() << "\n\n"; 185 } 186 #endif // NDEBUG 187 188 //===----------------------------------------------------------------------===// 189 // Worklist 190 //===----------------------------------------------------------------------===// 191 192 /// A LIFO worklist of operations with efficient removal and set semantics. 193 /// 194 /// This class maintains a vector of operations and a mapping of operations to 195 /// positions in the vector, so that operations can be removed efficiently at 196 /// random. When an operation is removed, it is replaced with nullptr. Such 197 /// nullptr are skipped when pop'ing elements. 198 class Worklist { 199 public: 200 Worklist(); 201 202 /// Clear the worklist. 203 void clear(); 204 205 /// Return whether the worklist is empty. 206 bool empty() const; 207 208 /// Push an operation to the end of the worklist, unless the operation is 209 /// already on the worklist. 210 void push(Operation *op); 211 212 /// Pop the an operation from the end of the worklist. Only allowed on 213 /// non-empty worklists. 214 Operation *pop(); 215 216 /// Remove an operation from the worklist. 217 void remove(Operation *op); 218 219 /// Reverse the worklist. 220 void reverse(); 221 222 protected: 223 /// The worklist of operations. 224 std::vector<Operation *> list; 225 226 /// A mapping of operations to positions in `list`. 227 DenseMap<Operation *, unsigned> map; 228 }; 229 230 Worklist::Worklist() { list.reserve(64); } 231 232 void Worklist::clear() { 233 list.clear(); 234 map.clear(); 235 } 236 237 bool Worklist::empty() const { 238 // Skip all nullptr. 239 return !llvm::any_of(list, 240 [](Operation *op) { return static_cast<bool>(op); }); 241 } 242 243 void Worklist::push(Operation *op) { 244 assert(op && "cannot push nullptr to worklist"); 245 // Check to see if the worklist already contains this op. 246 if (!map.insert({op, list.size()}).second) 247 return; 248 list.push_back(op); 249 } 250 251 Operation *Worklist::pop() { 252 assert(!empty() && "cannot pop from empty worklist"); 253 // Skip and remove all trailing nullptr. 254 while (!list.back()) 255 list.pop_back(); 256 Operation *op = list.back(); 257 list.pop_back(); 258 map.erase(op); 259 // Cleanup: Remove all trailing nullptr. 260 while (!list.empty() && !list.back()) 261 list.pop_back(); 262 return op; 263 } 264 265 void Worklist::remove(Operation *op) { 266 assert(op && "cannot remove nullptr from worklist"); 267 auto it = map.find(op); 268 if (it != map.end()) { 269 assert(list[it->second] == op && "malformed worklist data structure"); 270 list[it->second] = nullptr; 271 map.erase(it); 272 } 273 } 274 275 void Worklist::reverse() { 276 std::reverse(list.begin(), list.end()); 277 for (size_t i = 0, e = list.size(); i != e; ++i) 278 map[list[i]] = i; 279 } 280 281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 282 /// A worklist that pops elements at a random position. This worklist is for 283 /// testing/debugging purposes only. It can be used to ensure that lowering 284 /// pipelines work correctly regardless of the order in which ops are processed 285 /// by the GreedyPatternRewriteDriver. 286 class RandomizedWorklist : public Worklist { 287 public: 288 RandomizedWorklist() : Worklist() { 289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED); 290 } 291 292 /// Pop a random non-empty op from the worklist. 293 Operation *pop() { 294 Operation *op = nullptr; 295 do { 296 assert(!list.empty() && "cannot pop from empty worklist"); 297 int64_t pos = generator() % list.size(); 298 op = list[pos]; 299 list.erase(list.begin() + pos); 300 for (int64_t i = pos, e = list.size(); i < e; ++i) 301 map[list[i]] = i; 302 map.erase(op); 303 } while (!op); 304 return op; 305 } 306 307 private: 308 std::minstd_rand0 generator; 309 }; 310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 311 312 //===----------------------------------------------------------------------===// 313 // GreedyPatternRewriteDriver 314 //===----------------------------------------------------------------------===// 315 316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 317 /// applies the locally optimal patterns. 318 /// 319 /// This abstract class manages the worklist and contains helper methods for 320 /// rewriting ops on the worklist. Derived classes specify how ops are added 321 /// to the worklist in the beginning. 322 class GreedyPatternRewriteDriver : public RewriterBase::Listener { 323 protected: 324 explicit GreedyPatternRewriteDriver(MLIRContext *ctx, 325 const FrozenRewritePatternSet &patterns, 326 const GreedyRewriteConfig &config); 327 328 /// Add the given operation to the worklist. 329 void addSingleOpToWorklist(Operation *op); 330 331 /// Add the given operation and its ancestors to the worklist. 332 void addToWorklist(Operation *op); 333 334 /// Notify the driver that the specified operation may have been modified 335 /// in-place. The operation is added to the worklist. 336 void notifyOperationModified(Operation *op) override; 337 338 /// Notify the driver that the specified operation was inserted. Update the 339 /// worklist as needed: The operation is enqueued depending on scope and 340 /// strict mode. 341 void notifyOperationInserted(Operation *op, 342 OpBuilder::InsertPoint previous) override; 343 344 /// Notify the driver that the specified operation was removed. Update the 345 /// worklist as needed: The operation and its children are removed from the 346 /// worklist. 347 void notifyOperationErased(Operation *op) override; 348 349 /// Notify the driver that the specified operation was replaced. Update the 350 /// worklist as needed: New users are added enqueued. 351 void notifyOperationReplaced(Operation *op, ValueRange replacement) override; 352 353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is 354 /// reached. Return `true` if any IR was changed. 355 bool processWorklist(); 356 357 /// The pattern rewriter that is used for making IR modifications and is 358 /// passed to rewrite patterns. 359 PatternRewriter rewriter; 360 361 /// The worklist for this transformation keeps track of the operations that 362 /// need to be (re)visited. 363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 364 RandomizedWorklist worklist; 365 #else 366 Worklist worklist; 367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED 368 369 /// Configuration information for how to simplify. 370 const GreedyRewriteConfig config; 371 372 /// The list of ops we are restricting our rewrites to. These include the 373 /// supplied set of ops as well as new ops created while rewriting those ops 374 /// depending on `strictMode`. This set is not maintained when 375 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. 376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; 377 378 private: 379 /// Look over the provided operands for any defining operations that should 380 /// be re-added to the worklist. This function should be called when an 381 /// operation is modified or removed, as it may trigger further 382 /// simplifications. 383 void addOperandsToWorklist(Operation *op); 384 385 /// Notify the driver that the given block was inserted. 386 void notifyBlockInserted(Block *block, Region *previous, 387 Region::iterator previousIt) override; 388 389 /// Notify the driver that the given block is about to be removed. 390 void notifyBlockErased(Block *block) override; 391 392 /// For debugging only: Notify the driver of a pattern match failure. 393 void 394 notifyMatchFailure(Location loc, 395 function_ref<void(Diagnostic &)> reasonCallback) override; 396 397 #ifndef NDEBUG 398 /// A logger used to emit information during the application process. 399 llvm::ScopedPrinter logger{llvm::dbgs()}; 400 #endif 401 402 /// The low-level pattern applicator. 403 PatternApplicator matcher; 404 405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 406 ExpensiveChecks expensiveChecks; 407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 408 }; 409 } // namespace 410 411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( 412 MLIRContext *ctx, const FrozenRewritePatternSet &patterns, 413 const GreedyRewriteConfig &config) 414 : rewriter(ctx), config(config), matcher(patterns) 415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 416 // clang-format off 417 , expensiveChecks( 418 /*driver=*/this, 419 /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr) 420 // clang-format on 421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 422 { 423 // Apply a simple cost model based solely on pattern benefit. 424 matcher.applyDefaultCostModel(); 425 426 // Set up listener. 427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 428 // Send IR notifications to the debug handler. This handler will then forward 429 // all notifications to this GreedyPatternRewriteDriver. 430 rewriter.setListener(&expensiveChecks); 431 #else 432 rewriter.setListener(this); 433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 434 } 435 436 bool GreedyPatternRewriteDriver::processWorklist() { 437 #ifndef NDEBUG 438 const char *logLineComment = 439 "//===-------------------------------------------===//\n"; 440 441 /// A utility function to log a process result for the given reason. 442 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) { 443 logger.unindent(); 444 logger.startLine() << "} -> " << result; 445 if (!msg.isTriviallyEmpty()) 446 logger.getOStream() << " : " << msg; 447 logger.getOStream() << "\n"; 448 }; 449 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) { 450 logResult(result, msg); 451 logger.startLine() << logLineComment; 452 }; 453 #endif 454 455 bool changed = false; 456 int64_t numRewrites = 0; 457 while (!worklist.empty() && 458 (numRewrites < config.maxNumRewrites || 459 config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { 460 auto *op = worklist.pop(); 461 462 LLVM_DEBUG({ 463 logger.getOStream() << "\n"; 464 logger.startLine() << logLineComment; 465 logger.startLine() << "Processing operation : '" << op->getName() << "'(" 466 << op << ") {\n"; 467 logger.indent(); 468 469 // If the operation has no regions, just print it here. 470 if (op->getNumRegions() == 0) { 471 op->print( 472 logger.startLine(), 473 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 474 logger.getOStream() << "\n\n"; 475 } 476 }); 477 478 // If the operation is trivially dead - remove it. 479 if (isOpTriviallyDead(op)) { 480 rewriter.eraseOp(op); 481 changed = true; 482 483 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); 484 continue; 485 } 486 487 // Try to fold this op. Do not fold constant ops. That would lead to an 488 // infinite folding loop, as every constant op would be folded to an 489 // Attribute and then immediately be rematerialized as a constant op, which 490 // is then put on the worklist. 491 if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) { 492 SmallVector<OpFoldResult> foldResults; 493 if (succeeded(op->fold(foldResults))) { 494 LLVM_DEBUG(logResultWithLine("success", "operation was folded")); 495 #ifndef NDEBUG 496 Operation *dumpRootOp = getDumpRootOp(op); 497 #endif // NDEBUG 498 if (foldResults.empty()) { 499 // Op was modified in-place. 500 notifyOperationModified(op); 501 changed = true; 502 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); 503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 504 expensiveChecks.notifyFoldingSuccess(); 505 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 506 continue; 507 } 508 509 // Op results can be replaced with `foldResults`. 510 assert(foldResults.size() == op->getNumResults() && 511 "folder produced incorrect number of results"); 512 OpBuilder::InsertionGuard g(rewriter); 513 rewriter.setInsertionPoint(op); 514 SmallVector<Value> replacements; 515 bool materializationSucceeded = true; 516 for (auto [ofr, resultType] : 517 llvm::zip_equal(foldResults, op->getResultTypes())) { 518 if (auto value = ofr.dyn_cast<Value>()) { 519 assert(value.getType() == resultType && 520 "folder produced value of incorrect type"); 521 replacements.push_back(value); 522 continue; 523 } 524 // Materialize Attributes as SSA values. 525 Operation *constOp = op->getDialect()->materializeConstant( 526 rewriter, cast<Attribute>(ofr), resultType, op->getLoc()); 527 528 if (!constOp) { 529 // If materialization fails, cleanup any operations generated for 530 // the previous results. 531 llvm::SmallDenseSet<Operation *> replacementOps; 532 for (Value replacement : replacements) { 533 assert(replacement.use_empty() && 534 "folder reused existing op for one result but constant " 535 "materialization failed for another result"); 536 replacementOps.insert(replacement.getDefiningOp()); 537 } 538 for (Operation *op : replacementOps) { 539 rewriter.eraseOp(op); 540 } 541 542 materializationSucceeded = false; 543 break; 544 } 545 546 assert(constOp->hasTrait<OpTrait::ConstantLike>() && 547 "materializeConstant produced op that is not a ConstantLike"); 548 assert(constOp->getResultTypes()[0] == resultType && 549 "materializeConstant produced incorrect result type"); 550 replacements.push_back(constOp->getResult(0)); 551 } 552 553 if (materializationSucceeded) { 554 rewriter.replaceOp(op, replacements); 555 changed = true; 556 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); 557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 558 expensiveChecks.notifyFoldingSuccess(); 559 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 560 continue; 561 } 562 } 563 } 564 565 // Try to match one of the patterns. The rewriter is automatically 566 // notified of any necessary changes, so there is nothing else to do 567 // here. 568 auto canApplyCallback = [&](const Pattern &pattern) { 569 LLVM_DEBUG({ 570 logger.getOStream() << "\n"; 571 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" 572 << op->getName() << " -> ("; 573 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); 574 logger.getOStream() << ")' {\n"; 575 logger.indent(); 576 }); 577 if (config.listener) 578 config.listener->notifyPatternBegin(pattern, op); 579 return true; 580 }; 581 function_ref<bool(const Pattern &)> canApply = canApplyCallback; 582 auto onFailureCallback = [&](const Pattern &pattern) { 583 LLVM_DEBUG(logResult("failure", "pattern failed to match")); 584 if (config.listener) 585 config.listener->notifyPatternEnd(pattern, failure()); 586 }; 587 function_ref<void(const Pattern &)> onFailure = onFailureCallback; 588 auto onSuccessCallback = [&](const Pattern &pattern) { 589 LLVM_DEBUG(logResult("success", "pattern applied successfully")); 590 if (config.listener) 591 config.listener->notifyPatternEnd(pattern, success()); 592 return success(); 593 }; 594 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback; 595 596 #ifdef NDEBUG 597 // Optimization: PatternApplicator callbacks are not needed when running in 598 // optimized mode and without a listener. 599 if (!config.listener) { 600 canApply = nullptr; 601 onFailure = nullptr; 602 onSuccess = nullptr; 603 } 604 #endif // NDEBUG 605 606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 607 if (config.scope) { 608 expensiveChecks.computeFingerPrints(config.scope->getParentOp()); 609 } 610 auto clearFingerprints = 611 llvm::make_scope_exit([&]() { expensiveChecks.clear(); }); 612 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 613 614 LogicalResult matchResult = 615 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess); 616 617 if (succeeded(matchResult)) { 618 LLVM_DEBUG(logResultWithLine("success", "pattern matched")); 619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 620 expensiveChecks.notifyRewriteSuccess(); 621 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 622 changed = true; 623 ++numRewrites; 624 } else { 625 LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); 626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 627 expensiveChecks.notifyRewriteFailure(); 628 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 629 } 630 } 631 632 return changed; 633 } 634 635 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { 636 assert(op && "expected valid op"); 637 // Gather potential ancestors while looking for a "scope" parent region. 638 SmallVector<Operation *, 8> ancestors; 639 Region *region = nullptr; 640 do { 641 ancestors.push_back(op); 642 region = op->getParentRegion(); 643 if (config.scope == region) { 644 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops. 645 for (Operation *op : ancestors) 646 addSingleOpToWorklist(op); 647 return; 648 } 649 if (region == nullptr) 650 return; 651 } while ((op = region->getParentOp())); 652 } 653 654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { 655 if (config.strictMode == GreedyRewriteStrictness::AnyOp || 656 strictModeFilteredOps.contains(op)) 657 worklist.push(op); 658 } 659 660 void GreedyPatternRewriteDriver::notifyBlockInserted( 661 Block *block, Region *previous, Region::iterator previousIt) { 662 if (config.listener) 663 config.listener->notifyBlockInserted(block, previous, previousIt); 664 } 665 666 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) { 667 if (config.listener) 668 config.listener->notifyBlockErased(block); 669 } 670 671 void GreedyPatternRewriteDriver::notifyOperationInserted( 672 Operation *op, OpBuilder::InsertPoint previous) { 673 LLVM_DEBUG({ 674 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op 675 << ")\n"; 676 }); 677 if (config.listener) 678 config.listener->notifyOperationInserted(op, previous); 679 if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) 680 strictModeFilteredOps.insert(op); 681 addToWorklist(op); 682 } 683 684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) { 685 LLVM_DEBUG({ 686 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op 687 << ")\n"; 688 }); 689 if (config.listener) 690 config.listener->notifyOperationModified(op); 691 addToWorklist(op); 692 } 693 694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) { 695 for (Value operand : op->getOperands()) { 696 // If this operand currently has at most 2 users, add its defining op to the 697 // worklist. Indeed, after the op is deleted, then the operand will have at 698 // most 1 user left. If it has 0 users left, it can be deleted too, 699 // and if it has 1 user left, there may be further canonicalization 700 // opportunities. 701 if (!operand) 702 continue; 703 704 auto *defOp = operand.getDefiningOp(); 705 if (!defOp) 706 continue; 707 708 Operation *otherUser = nullptr; 709 bool hasMoreThanTwoUses = false; 710 for (auto user : operand.getUsers()) { 711 if (user == op || user == otherUser) 712 continue; 713 if (!otherUser) { 714 otherUser = user; 715 continue; 716 } 717 hasMoreThanTwoUses = true; 718 break; 719 } 720 if (hasMoreThanTwoUses) 721 continue; 722 723 addToWorklist(defOp); 724 } 725 } 726 727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) { 728 LLVM_DEBUG({ 729 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op 730 << ")\n"; 731 }); 732 733 #ifndef NDEBUG 734 // Only ops that are within the configured scope are added to the worklist of 735 // the greedy pattern rewriter. Moreover, the parent op of the scope region is 736 // the part of the IR that is taken into account for the "expensive checks". 737 // A greedy pattern rewrite is not allowed to erase the parent op of the scope 738 // region, as that would break the worklist handling and the expensive checks. 739 if (config.scope && config.scope->getParentOp() == op) 740 llvm_unreachable( 741 "scope region must not be erased during greedy pattern rewrite"); 742 #endif // NDEBUG 743 744 if (config.listener) 745 config.listener->notifyOperationErased(op); 746 747 addOperandsToWorklist(op); 748 worklist.remove(op); 749 750 if (config.strictMode != GreedyRewriteStrictness::AnyOp) 751 strictModeFilteredOps.erase(op); 752 } 753 754 void GreedyPatternRewriteDriver::notifyOperationReplaced( 755 Operation *op, ValueRange replacement) { 756 LLVM_DEBUG({ 757 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op 758 << ")\n"; 759 }); 760 if (config.listener) 761 config.listener->notifyOperationReplaced(op, replacement); 762 } 763 764 void GreedyPatternRewriteDriver::notifyMatchFailure( 765 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 766 LLVM_DEBUG({ 767 Diagnostic diag(loc, DiagnosticSeverity::Remark); 768 reasonCallback(diag); 769 logger.startLine() << "** Match Failure : " << diag.str() << "\n"; 770 }); 771 if (config.listener) 772 config.listener->notifyMatchFailure(loc, reasonCallback); 773 } 774 775 //===----------------------------------------------------------------------===// 776 // RegionPatternRewriteDriver 777 //===----------------------------------------------------------------------===// 778 779 namespace { 780 /// This driver simplfies all ops in a region. 781 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { 782 public: 783 explicit RegionPatternRewriteDriver(MLIRContext *ctx, 784 const FrozenRewritePatternSet &patterns, 785 const GreedyRewriteConfig &config, 786 Region ®ions); 787 788 /// Simplify ops inside `region` and simplify the region itself. Return 789 /// success if the transformation converged. 790 LogicalResult simplify(bool *changed) &&; 791 792 private: 793 /// The region that is simplified. 794 Region ®ion; 795 }; 796 } // namespace 797 798 RegionPatternRewriteDriver::RegionPatternRewriteDriver( 799 MLIRContext *ctx, const FrozenRewritePatternSet &patterns, 800 const GreedyRewriteConfig &config, Region ®ion) 801 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { 802 // Populate strict mode ops. 803 if (config.strictMode != GreedyRewriteStrictness::AnyOp) { 804 region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); 805 } 806 } 807 808 namespace { 809 class GreedyPatternRewriteIteration 810 : public tracing::ActionImpl<GreedyPatternRewriteIteration> { 811 public: 812 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration) 813 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration) 814 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units), 815 iteration(iteration) {} 816 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration"; 817 void print(raw_ostream &os) const override { 818 os << "GreedyPatternRewriteIteration(" << iteration << ")"; 819 } 820 821 private: 822 int64_t iteration = 0; 823 }; 824 } // namespace 825 826 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { 827 bool continueRewrites = false; 828 int64_t iteration = 0; 829 MLIRContext *ctx = rewriter.getContext(); 830 do { 831 // Check if the iteration limit was reached. 832 if (++iteration > config.maxIterations && 833 config.maxIterations != GreedyRewriteConfig::kNoLimit) 834 break; 835 836 // New iteration: start with an empty worklist. 837 worklist.clear(); 838 839 // `OperationFolder` CSE's constant ops (and may move them into parents 840 // regions to enable more aggressive CSE'ing). 841 OperationFolder folder(ctx, this); 842 auto insertKnownConstant = [&](Operation *op) { 843 // Check for existing constants when populating the worklist. This avoids 844 // accidentally reversing the constant order during processing. 845 Attribute constValue; 846 if (matchPattern(op, m_Constant(&constValue))) 847 if (!folder.insertKnownConstant(op, constValue)) 848 return true; 849 return false; 850 }; 851 852 if (!config.useTopDownTraversal) { 853 // Add operations to the worklist in postorder. 854 region.walk([&](Operation *op) { 855 if (!config.cseConstants || !insertKnownConstant(op)) 856 addToWorklist(op); 857 }); 858 } else { 859 // Add all nested operations to the worklist in preorder. 860 region.walk<WalkOrder::PreOrder>([&](Operation *op) { 861 if (!config.cseConstants || !insertKnownConstant(op)) { 862 addToWorklist(op); 863 return WalkResult::advance(); 864 } 865 return WalkResult::skip(); 866 }); 867 868 // Reverse the list so our pop-back loop processes them in-order. 869 worklist.reverse(); 870 } 871 872 ctx->executeAction<GreedyPatternRewriteIteration>( 873 [&] { 874 continueRewrites = processWorklist(); 875 876 // After applying patterns, make sure that the CFG of each of the 877 // regions is kept up to date. 878 if (config.enableRegionSimplification != 879 GreedySimplifyRegionLevel::Disabled) { 880 continueRewrites |= succeeded(simplifyRegions( 881 rewriter, region, 882 /*mergeBlocks=*/config.enableRegionSimplification == 883 GreedySimplifyRegionLevel::Aggressive)); 884 } 885 }, 886 {®ion}, iteration); 887 } while (continueRewrites); 888 889 if (changed) 890 *changed = iteration > 1; 891 892 // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 893 return success(!continueRewrites); 894 } 895 896 LogicalResult 897 mlir::applyPatternsGreedily(Region ®ion, 898 const FrozenRewritePatternSet &patterns, 899 GreedyRewriteConfig config, bool *changed) { 900 // The top-level operation must be known to be isolated from above to 901 // prevent performing canonicalizations on operations defined at or above 902 // the region containing 'op'. 903 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 904 "patterns can only be applied to operations IsolatedFromAbove"); 905 906 // Set scope if not specified. 907 if (!config.scope) 908 config.scope = ®ion; 909 910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 911 if (failed(verify(config.scope->getParentOp()))) 912 llvm::report_fatal_error( 913 "greedy pattern rewriter input IR failed to verify"); 914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 915 916 // Start the pattern driver. 917 RegionPatternRewriteDriver driver(region.getContext(), patterns, config, 918 region); 919 LogicalResult converged = std::move(driver).simplify(changed); 920 LLVM_DEBUG(if (failed(converged)) { 921 llvm::dbgs() << "The pattern rewrite did not converge after scanning " 922 << config.maxIterations << " times\n"; 923 }); 924 return converged; 925 } 926 927 //===----------------------------------------------------------------------===// 928 // MultiOpPatternRewriteDriver 929 //===----------------------------------------------------------------------===// 930 931 namespace { 932 /// This driver simplfies a list of ops. 933 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { 934 public: 935 explicit MultiOpPatternRewriteDriver( 936 MLIRContext *ctx, const FrozenRewritePatternSet &patterns, 937 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, 938 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr); 939 940 /// Simplify `ops`. Return `success` if the transformation converged. 941 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&; 942 943 private: 944 void notifyOperationErased(Operation *op) override { 945 GreedyPatternRewriteDriver::notifyOperationErased(op); 946 if (survivingOps) 947 survivingOps->erase(op); 948 } 949 950 /// An optional set of ops that survived the rewrite. This set is populated 951 /// at the beginning of `simplifyLocally` with the inititally provided list 952 /// of ops. 953 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr; 954 }; 955 } // namespace 956 957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( 958 MLIRContext *ctx, const FrozenRewritePatternSet &patterns, 959 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, 960 llvm::SmallDenseSet<Operation *, 4> *survivingOps) 961 : GreedyPatternRewriteDriver(ctx, patterns, config), 962 survivingOps(survivingOps) { 963 if (config.strictMode != GreedyRewriteStrictness::AnyOp) 964 strictModeFilteredOps.insert(ops.begin(), ops.end()); 965 966 if (survivingOps) { 967 survivingOps->clear(); 968 survivingOps->insert(ops.begin(), ops.end()); 969 } 970 } 971 972 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops, 973 bool *changed) && { 974 // Populate the initial worklist. 975 for (Operation *op : ops) 976 addSingleOpToWorklist(op); 977 978 // Process ops on the worklist. 979 bool result = processWorklist(); 980 if (changed) 981 *changed = result; 982 983 return success(worklist.empty()); 984 } 985 986 /// Find the region that is the closest common ancestor of all given ops. 987 /// 988 /// Note: This function returns `nullptr` if there is a top-level op among the 989 /// given list of ops. 990 static Region *findCommonAncestor(ArrayRef<Operation *> ops) { 991 assert(!ops.empty() && "expected at least one op"); 992 // Fast path in case there is only one op. 993 if (ops.size() == 1) 994 return ops.front()->getParentRegion(); 995 996 Region *region = ops.front()->getParentRegion(); 997 ops = ops.drop_front(); 998 int sz = ops.size(); 999 llvm::BitVector remainingOps(sz, true); 1000 while (region) { 1001 int pos = -1; 1002 // Iterate over all remaining ops. 1003 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) { 1004 // Is this op contained in `region`? 1005 if (region->findAncestorOpInRegion(*ops[pos])) 1006 remainingOps.reset(pos); 1007 } 1008 if (remainingOps.none()) 1009 break; 1010 region = region->getParentRegion(); 1011 } 1012 return region; 1013 } 1014 1015 LogicalResult mlir::applyOpPatternsGreedily( 1016 ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns, 1017 GreedyRewriteConfig config, bool *changed, bool *allErased) { 1018 if (ops.empty()) { 1019 if (changed) 1020 *changed = false; 1021 if (allErased) 1022 *allErased = true; 1023 return success(); 1024 } 1025 1026 // Determine scope of rewrite. 1027 if (!config.scope) { 1028 // Compute scope if none was provided. The scope will remain `nullptr` if 1029 // there is a top-level op among `ops`. 1030 config.scope = findCommonAncestor(ops); 1031 } else { 1032 // If a scope was provided, make sure that all ops are in scope. 1033 #ifndef NDEBUG 1034 bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { 1035 return static_cast<bool>(config.scope->findAncestorOpInRegion(*op)); 1036 }); 1037 assert(allOpsInScope && "ops must be within the specified scope"); 1038 #endif // NDEBUG 1039 } 1040 1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 1042 if (config.scope && failed(verify(config.scope->getParentOp()))) 1043 llvm::report_fatal_error( 1044 "greedy pattern rewriter input IR failed to verify"); 1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 1046 1047 // Start the pattern driver. 1048 llvm::SmallDenseSet<Operation *, 4> surviving; 1049 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, 1050 config, ops, 1051 allErased ? &surviving : nullptr); 1052 LogicalResult converged = std::move(driver).simplify(ops, changed); 1053 if (allErased) 1054 *allErased = surviving.empty(); 1055 LLVM_DEBUG(if (failed(converged)) { 1056 llvm::dbgs() << "The pattern rewrite did not converge after " 1057 << config.maxNumRewrites << " rewrites"; 1058 }); 1059 return converged; 1060 } 1061